Initial commit: add .gitignore and README
This commit is contained in:
39
fusionagi/__init__.py
Normal file
39
fusionagi/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""FusionAGI: the world's most advanced agentic AGI orchestration framework."""
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi.core import EventBus, Orchestrator, StateManager
|
||||
from fusionagi.schemas import AgentMessageEnvelope, Task
|
||||
from fusionagi.self_improvement import (
|
||||
SelfCorrectionLoop,
|
||||
AutoRecommender,
|
||||
AutoTrainer,
|
||||
FusionAGILoop,
|
||||
)
|
||||
|
||||
|
||||
def __get_version() -> str:
    """Resolve the package version from installed distribution metadata.

    Falls back to a hard-coded default when the distribution is not
    installed (e.g. running from a plain source checkout).
    """
    try:
        from importlib.metadata import version

        resolved = version("fusionagi")
    except Exception:
        # Not installed as a distribution; report the development default.
        resolved = "0.1.0"
    return resolved
|
||||
|
||||
|
||||
__version__ = __get_version()
|
||||
|
||||
# Public API surface re-exported at the package top level.
__all__ = [
    "__version__",
    "logger",
    "EventBus",
    "Orchestrator",
    "StateManager",
    "Task",
    "AgentMessageEnvelope",
    "SelfCorrectionLoop",
    "AutoRecommender",
    "AutoTrainer",
    "FusionAGILoop",
]
|
||||
|
||||
# Interface layer is available via: from fusionagi.interfaces import ...
|
||||
5
fusionagi/_logger.py
Normal file
5
fusionagi/_logger.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Shared logger for the fusionagi package; no other fusionagi imports to avoid circular imports."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("fusionagi")
|
||||
13
fusionagi/_time.py
Normal file
13
fusionagi/_time.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Shared time utilities; timezone-aware UTC to avoid deprecated datetime.utcnow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utc_now() -> datetime:
    """Current moment as a timezone-aware UTC datetime (never naive).

    Preferred over the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(tz=timezone.utc)


def utc_now_iso() -> str:
    """Current UTC moment rendered as an ISO-8601 string."""
    return utc_now().isoformat()
|
||||
18
fusionagi/adapters/__init__.py
Normal file
18
fusionagi/adapters/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""LLM adapters: abstract interface and provider implementations.
|
||||
|
||||
NativeAdapter: Uses FusionAGI's internal reasoning—no external API calls.
|
||||
OpenAIAdapter is None when the openai package is not installed (pip install fusionagi[openai]).
|
||||
Use: from fusionagi.adapters import OpenAIAdapter; if OpenAIAdapter is not None: ...
|
||||
"""
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.adapters.stub_adapter import StubAdapter
|
||||
from fusionagi.adapters.cache import CachedAdapter
|
||||
from fusionagi.adapters.native_adapter import NativeAdapter
|
||||
|
||||
try:
|
||||
from fusionagi.adapters.openai_adapter import OpenAIAdapter
|
||||
except ImportError:
|
||||
OpenAIAdapter = None # type: ignore[misc, assignment]
|
||||
|
||||
__all__ = ["LLMAdapter", "StubAdapter", "CachedAdapter", "NativeAdapter", "OpenAIAdapter"]
|
||||
55
fusionagi/adapters/base.py
Normal file
55
fusionagi/adapters/base.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Abstract LLM adapter interface; model-agnostic for orchestrator and agents."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LLMAdapter(ABC):
|
||||
"""
|
||||
Abstract adapter for LLM completion.
|
||||
|
||||
Implementations should handle:
|
||||
- openai/ - OpenAI API (GPT-4, etc.)
|
||||
- anthropic/ - Anthropic API (Claude, etc.)
|
||||
- local/ - Local models (Ollama, etc.)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def complete(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Return completion text for the given messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys.
|
||||
**kwargs: Provider-specific options (e.g., temperature, max_tokens).
|
||||
|
||||
Returns:
|
||||
The model's response text.
|
||||
"""
|
||||
...
|
||||
|
||||
def complete_structured(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
schema: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Return structured (JSON) output.
|
||||
|
||||
Default implementation returns None; subclasses may override to use
|
||||
provider-specific JSON modes (e.g., OpenAI's response_format).
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys.
|
||||
schema: Optional JSON schema for response validation.
|
||||
**kwargs: Provider-specific options.
|
||||
|
||||
Returns:
|
||||
Parsed JSON response or None if not supported/parsing fails.
|
||||
"""
|
||||
return None
|
||||
115
fusionagi/adapters/cache.py
Normal file
115
fusionagi/adapters/cache.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Optional response cache for LLM adapter."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
|
||||
|
||||
class CachedAdapter(LLMAdapter):
    """LRU-caching decorator for another LLM adapter.

    Responses from complete() and complete_structured() are memoized
    under a SHA-256 digest of the request (messages + kwargs), with
    text and structured replies kept in separate LRU caches. Hit and
    miss counters are exposed through get_stats() for monitoring.
    """

    def __init__(self, adapter: LLMAdapter, max_entries: int = 100) -> None:
        """Wrap *adapter*, holding at most *max_entries* per cache.

        Args:
            adapter: The underlying LLM adapter to delegate to.
            max_entries: Capacity of each cache before LRU eviction.
        """
        self._adapter = adapter
        self._cache: OrderedDict[str, str] = OrderedDict()
        self._structured_cache: OrderedDict[str, Any] = OrderedDict()
        self._max_entries = max_entries
        self._hits = 0
        self._misses = 0

    def _key(self, messages: list[dict[str, str]], kwargs: dict[str, Any], prefix: str = "") -> str:
        """Digest (prefix, messages, kwargs) into a stable cache key."""
        blob = json.dumps(
            {"prefix": prefix, "messages": messages, "kwargs": kwargs},
            sort_keys=True,
            default=str,
        )
        return hashlib.sha256(blob.encode()).hexdigest()

    def _evict_if_needed(self, cache: OrderedDict[str, Any]) -> None:
        """Pop oldest entries until there is room for one more item."""
        while cache and len(cache) >= self._max_entries:
            cache.popitem(last=False)

    def _get_and_touch(self, cache: OrderedDict[str, Any], key: str) -> Any:
        """Return cache[key] after marking it most recently used."""
        cache.move_to_end(key)
        return cache[key]

    def complete(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
        """Return a cached text completion, delegating on a miss."""
        key = self._key(messages, kwargs, prefix="complete")
        if key in self._cache:
            self._hits += 1
            return self._get_and_touch(self._cache, key)

        self._misses += 1
        text = self._adapter.complete(messages, **kwargs)
        self._evict_if_needed(self._cache)
        self._cache[key] = text
        return text

    def complete_structured(
        self,
        messages: list[dict[str, str]],
        schema: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Return a cached structured completion, delegating on a miss.

        Structured replies live in their own cache; None results are
        never cached, so a later call can retry the provider.
        """
        key = self._key(messages, {**kwargs, "_schema": schema}, prefix="structured")
        if key in self._structured_cache:
            self._hits += 1
            return self._get_and_touch(self._structured_cache, key)

        self._misses += 1
        result = self._adapter.complete_structured(messages, schema=schema, **kwargs)
        if result is not None:
            self._evict_if_needed(self._structured_cache)
            self._structured_cache[key] = result
        return result

    def get_stats(self) -> dict[str, Any]:
        """Return hit/miss counters and current cache sizes."""
        total = self._hits + self._misses
        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": self._hits / total if total > 0 else 0.0,
            "text_cache_size": len(self._cache),
            "structured_cache_size": len(self._structured_cache),
            "max_entries": self._max_entries,
        }

    def clear_cache(self) -> None:
        """Drop all cached responses and reset the counters."""
        self._cache.clear()
        self._structured_cache.clear()
        self._hits = 0
        self._misses = 0
|
||||
101
fusionagi/adapters/native_adapter.py
Normal file
101
fusionagi/adapters/native_adapter.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Native adapter: implements LLMAdapter using FusionAGI's internal reasoning.
|
||||
|
||||
No external API calls. Used for synthesis (e.g. Witness compose) when operating
|
||||
in fully native AGI mode.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
|
||||
|
||||
def _synthesize_from_messages(messages: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
Synthesize narrative from message content using native logic only.
|
||||
Extracts head summaries and agreed claims, produces coherent narrative.
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
content_parts: list[str] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str) and content.strip():
|
||||
content_parts.append(content)
|
||||
|
||||
if not content_parts:
|
||||
return ""
|
||||
|
||||
full_content = "\n".join(content_parts)
|
||||
|
||||
# Extract "User asked:" for context
|
||||
user_prompt = ""
|
||||
if "User asked:" in full_content:
|
||||
idx = full_content.index("User asked:") + len("User asked:")
|
||||
end = full_content.find("\n\n", idx)
|
||||
user_prompt = full_content[idx:end if end > 0 else None].strip()
|
||||
|
||||
narrative_parts: list[str] = []
|
||||
|
||||
if user_prompt:
|
||||
truncated = user_prompt[:120] + ("..." if len(user_prompt) > 120 else "")
|
||||
narrative_parts.append(f"Regarding your question: {truncated}")
|
||||
|
||||
# Extract head summaries
|
||||
if "Head summaries:" in full_content:
|
||||
start = full_content.index("Head summaries:") + len("Head summaries:")
|
||||
end = full_content.find("\n\nAgreed claims:", start)
|
||||
if end < 0:
|
||||
end = full_content.find("Agreed claims:", start)
|
||||
if end < 0:
|
||||
end = len(full_content)
|
||||
summaries = full_content[start:end].strip()
|
||||
for line in summaries.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("-") and ":" in line:
|
||||
narrative_parts.append(line[1:].strip())
|
||||
|
||||
# Extract agreed claims as key points
|
||||
if "Agreed claims:" in full_content:
|
||||
start = full_content.index("Agreed claims:") + len("Agreed claims:")
|
||||
rest = full_content[start:].strip()
|
||||
claims_section = rest.split("\n\nDisputed:")[0].split("\n\n")[0]
|
||||
claim_lines = [ln.strip()[1:].strip() for ln in claims_section.split("\n") if ln.strip().startswith("-")]
|
||||
for c in claim_lines[:5]:
|
||||
if " (confidence:" in c:
|
||||
c = c.split(" (confidence:")[0].strip()
|
||||
if c:
|
||||
narrative_parts.append(c)
|
||||
|
||||
if not narrative_parts:
|
||||
paragraphs = [p.strip() for p in full_content.split("\n\n") if len(p.strip()) > 20]
|
||||
narrative_parts = paragraphs[:5] if paragraphs else [full_content[:500]]
|
||||
|
||||
return "\n\n".join(narrative_parts)
|
||||
|
||||
|
||||
class NativeAdapter(LLMAdapter):
    """LLM adapter backed entirely by FusionAGI's native synthesis.

    complete() builds a narrative from the message contents without any
    external calls; complete_structured() is intentionally unsupported
    (NativeReasoningProvider handles structured head output).
    """

    def complete(
        self,
        messages: list[dict[str, str]],
        **kwargs: Any,
    ) -> str:
        """Produce a narrative synthesized from *messages* natively."""
        return _synthesize_from_messages(messages)

    def complete_structured(
        self,
        messages: list[dict[str, str]],
        schema: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Always None: structured output is not supported natively."""
        return None
|
||||
261
fusionagi/adapters/openai_adapter.py
Normal file
261
fusionagi/adapters/openai_adapter.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""OpenAI LLM adapter with error handling and retry logic."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class OpenAIAdapterError(Exception):
    """Root of the OpenAI adapter exception hierarchy."""


class OpenAIRateLimitError(OpenAIAdapterError):
    """The OpenAI API reported a rate limit."""


class OpenAIAuthenticationError(OpenAIAdapterError):
    """The OpenAI API rejected the supplied credentials."""
|
||||
|
||||
|
||||
class OpenAIAdapter(LLMAdapter):
    """OpenAI chat-completion adapter with retries and error mapping.

    Requires the ``openai`` package and an API key (constructor
    argument or the OPENAI_API_KEY environment variable).

    Behaviour:
    - Transient failures (rate limit, connection, 5xx, timeout) are
      retried with exponential backoff.
    - Provider exceptions are re-raised as adapter exceptions.
    - complete_structured() uses the provider's JSON mode.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_multiplier: float = 2.0,
        max_retry_delay: float = 30.0,
        **client_kwargs: Any,
    ) -> None:
        """Store configuration; the client itself is created lazily.

        Args:
            model: Default model name (e.g., "gpt-4o-mini", "gpt-4o").
            api_key: API key; None means use the OPENAI_API_KEY env var.
            max_retries: Retry attempts for transient errors.
            retry_delay: Initial backoff in seconds.
            retry_multiplier: Backoff growth factor.
            max_retry_delay: Upper bound on the backoff, in seconds.
            **client_kwargs: Extra arguments for the OpenAI client.
        """
        self._model = model
        self._api_key = api_key
        self._max_retries = max_retries
        self._retry_delay = retry_delay
        self._retry_multiplier = retry_multiplier
        self._max_retry_delay = max_retry_delay
        self._client_kwargs = client_kwargs
        self._client: Any = None
        self._openai_module: Any = None

    def _get_client(self) -> Any:
        """Create the OpenAI client on first use (import deferred)."""
        if self._client is None:
            try:
                import openai

                self._openai_module = openai
                self._client = openai.OpenAI(api_key=self._api_key, **self._client_kwargs)
            except ImportError as e:
                raise ImportError("Install with: pip install fusionagi[openai]") from e
        return self._client

    def _is_retryable_error(self, error: Exception) -> bool:
        """True when *error* is transient and worth retrying."""
        module = self._openai_module
        if module is None:
            return False
        # Rate limits, dropped connections, server-side errors, and
        # timeouts are all transient from the caller's point of view.
        transient = (
            "RateLimitError",
            "APIConnectionError",
            "InternalServerError",
            "APITimeoutError",
        )
        for name in transient:
            exc_type = getattr(module, name, None)
            if exc_type is not None and isinstance(error, exc_type):
                return True
        return False

    def _classify_error(self, error: Exception) -> Exception:
        """Map a provider exception onto the adapter hierarchy."""
        module = self._openai_module
        if module is not None:
            rate_limit = getattr(module, "RateLimitError", None)
            if rate_limit is not None and isinstance(error, rate_limit):
                return OpenAIRateLimitError(str(error))
            auth = getattr(module, "AuthenticationError", None)
            if auth is not None and isinstance(error, auth):
                return OpenAIAuthenticationError(str(error))
        return OpenAIAdapterError(str(error))

    def complete(
        self,
        messages: list[dict[str, str]],
        **kwargs: Any,
    ) -> str:
        """Run a chat completion, retrying transient failures.

        Args:
            messages: Chat messages ('role' / 'content' dicts).
            **kwargs: Extra API options (e.g., temperature, model).

        Returns:
            The assistant's text, or "" on an empty response.

        Raises:
            ValueError: If a message is malformed.
            OpenAIAuthenticationError: On credential failure (no retry).
            OpenAIRateLimitError: If rate limited after all retries.
            OpenAIAdapterError: For other API errors after all retries.
        """
        if not messages:
            logger.warning("OpenAI complete called with empty messages")
            return ""

        # Fail fast on malformed input before touching the network.
        for index, entry in enumerate(messages):
            if not isinstance(entry, dict):
                raise ValueError(f"Message {index} must be a dict, got {type(entry).__name__}")
            if "role" not in entry:
                raise ValueError(f"Message {index} missing 'role' key")
            if "content" not in entry:
                raise ValueError(f"Message {index} missing 'content' key")

        client = self._get_client()
        model = kwargs.get("model", self._model)
        call_kwargs = {**kwargs, "model": model}

        failure: Exception | None = None
        backoff = self._retry_delay

        for attempt in range(self._max_retries + 1):
            try:
                resp = client.chat.completions.create(
                    messages=messages,
                    **call_kwargs,
                )
                choice = resp.choices[0] if resp.choices else None
                if choice and choice.message and choice.message.content:
                    return choice.message.content
                logger.debug("OpenAI empty response", extra={"model": model, "attempt": attempt})
                return ""

            except Exception as e:
                failure = e

                # Credential problems will never succeed on retry.
                if self._openai_module and hasattr(self._openai_module, "AuthenticationError"):
                    if isinstance(e, self._openai_module.AuthenticationError):
                        logger.error("OpenAI authentication failed", extra={"error": str(e)})
                        raise OpenAIAuthenticationError(str(e)) from e

                if not self._is_retryable_error(e):
                    logger.error(
                        "OpenAI non-retryable error",
                        extra={"error": str(e), "error_type": type(e).__name__},
                    )
                    raise self._classify_error(e) from e

                # Transient failure: back off before the next attempt.
                if attempt < self._max_retries:
                    logger.warning(
                        "OpenAI retryable error, retrying",
                        extra={
                            "error": str(e),
                            "attempt": attempt + 1,
                            "max_retries": self._max_retries,
                            "delay": backoff,
                        },
                    )
                    time.sleep(backoff)
                    backoff = min(backoff * self._retry_multiplier, self._max_retry_delay)

        logger.error(
            "OpenAI all retries exhausted",
            extra={"error": str(failure), "attempts": self._max_retries + 1},
        )
        raise self._classify_error(failure) from failure

    def complete_structured(
        self,
        messages: list[dict[str, str]],
        schema: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Run a completion in JSON mode and parse the result.

        Args:
            messages: Chat messages ('role' / 'content' dicts).
            schema: Optional JSON schema appended as a system hint.
            **kwargs: Extra API options.

        Returns:
            The parsed JSON value, or None when parsing fails.
        """
        import json

        # Force the provider's JSON output mode.
        call_kwargs = {**kwargs, "response_format": {"type": "json_object"}}

        if schema and messages:
            schema_hint = f"\n\nRespond with JSON matching this schema: {json.dumps(schema)}"
            if messages[0].get("role") == "system":
                head = {**messages[0], "content": messages[0]["content"] + schema_hint}
                messages = [head, *messages[1:]]
            else:
                preamble = {
                    "role": "system",
                    "content": f"You must respond with valid JSON.{schema_hint}",
                }
                messages = [preamble, *messages]

        raw = self.complete(messages, **call_kwargs)
        if not raw:
            return None

        try:
            return json.loads(raw)
        except json.JSONDecodeError as e:
            logger.warning(
                "OpenAI JSON parse failed",
                extra={"error": str(e), "raw_response": raw[:200]},
            )
            return None
|
||||
67
fusionagi/adapters/stub_adapter.py
Normal file
67
fusionagi/adapters/stub_adapter.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Stub LLM adapter for tests; returns fixed responses."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
|
||||
|
||||
class StubAdapter(LLMAdapter):
    """Canned-response adapter for tests; never makes network calls.

    Both plain-text and structured (JSON) replies are configurable,
    and may be swapped mid-test via the set_* helpers.
    """

    def __init__(
        self,
        response: str = "Stub response",
        structured_response: dict[str, Any] | list[Any] | None = None,
    ) -> None:
        """Remember the canned text and structured replies.

        Args:
            response: Fixed text reply for complete().
            structured_response: Fixed reply for complete_structured().
        """
        self._response = response
        self._structured_response = structured_response

    def complete(
        self,
        messages: list[dict[str, str]],
        **kwargs: Any,
    ) -> str:
        """Return the canned text reply, ignoring the input."""
        return self._response

    def complete_structured(
        self,
        messages: list[dict[str, str]],
        schema: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Return the canned structured reply.

        When none was configured, fall back to parsing the text reply
        as JSON; None when that fails too.
        """
        if self._structured_response is not None:
            return self._structured_response

        try:
            return json.loads(self._response)
        except json.JSONDecodeError:
            return None

    def set_response(self, response: str) -> None:
        """Replace the canned text reply (handy between test steps)."""
        self._response = response

    def set_structured_response(self, response: dict[str, Any] | list[Any] | None) -> None:
        """Replace the canned structured reply (handy between test steps)."""
        self._structured_response = response
|
||||
21
fusionagi/agents/__init__.py
Normal file
21
fusionagi/agents/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Agents: base, planner, reasoner, executor, critic, adversarial reviewer, head, witness. See fusionagi.multi_agent for Supervisor, Coordinator, Pool."""
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.agents.planner import PlannerAgent
|
||||
from fusionagi.agents.reasoner import ReasonerAgent
|
||||
from fusionagi.agents.executor import ExecutorAgent
|
||||
from fusionagi.agents.critic import CriticAgent
|
||||
from fusionagi.agents.adversarial_reviewer import AdversarialReviewerAgent
|
||||
from fusionagi.agents.head_agent import HeadAgent
|
||||
from fusionagi.agents.witness_agent import WitnessAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent",
|
||||
"PlannerAgent",
|
||||
"ReasonerAgent",
|
||||
"ExecutorAgent",
|
||||
"CriticAgent",
|
||||
"AdversarialReviewerAgent",
|
||||
"HeadAgent",
|
||||
"WitnessAgent",
|
||||
]
|
||||
15
fusionagi/agents/adversarial_reviewer.py
Normal file
15
fusionagi/agents/adversarial_reviewer.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.schemas.messages import AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
import json
|
||||
|
||||
class AdversarialReviewerAgent(BaseAgent):
    """Adversarial reviewer: surfaces errors and risks in proposed work.

    Listens for 'review_request' messages and replies with
    'review_ready' carrying a list of issues and a risk level.
    Documented and annotated to match the sibling agents (e.g. CriticAgent).
    """

    def __init__(self, identity: str = "adversarial_reviewer", adapter=None) -> None:
        """Set up identity/role scopes.

        Args:
            identity: Agent identifier used in message envelopes.
            adapter: Optional LLM adapter for deeper reviews
                (currently unused by handle_message).
        """
        super().__init__(
            identity=identity,
            role="Adversarial Reviewer",
            objective="Find errors and risks",
            memory_access=True,
            tool_permissions=[],
        )
        self._adapter = adapter

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On 'review_request', reply 'review_ready'; otherwise ignore.

        Issues are echoed from the request payload when present; a
        placeholder suggestion is returned otherwise. Risk level is a
        fixed "medium" in this non-LLM implementation.
        """
        if envelope.message.intent != "review_request":
            return None
        payload = envelope.message.payload
        issues = payload.get("issues", ["Enable LLM for detailed review"])
        return envelope.create_response(
            "review_ready",
            payload={"issues": issues, "risk_level": "medium"},
        )
|
||||
29
fusionagi/agents/base_agent.py
Normal file
29
fusionagi/agents/base_agent.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Base agent interface: identity, role, objective, memory/tool scope, handle_message."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.messages import AgentMessageEnvelope
|
||||
|
||||
|
||||
class BaseAgent(ABC):
    """Common base for all agents.

    Captures the agent's identity, role, and objective, plus its memory
    and tool scopes; concrete agents implement handle_message().
    """

    def __init__(
        self,
        identity: str,
        role: str,
        objective: str,
        memory_access: bool | str = True,
        tool_permissions: list[str] | str | None = None,
    ) -> None:
        """Record the agent's identity and access scopes.

        Args:
            identity: Unique agent identifier.
            role: Human-readable role name.
            objective: What the agent is trying to achieve.
            memory_access: Memory scope flag or named scope.
            tool_permissions: Allowed tools; None means no tools.
        """
        self.identity = identity
        self.role = role
        self.objective = objective
        self.memory_access = memory_access
        self.tool_permissions = [] if tool_permissions is None else tool_permissions

    @abstractmethod
    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Process an incoming envelope; return a reply or None."""
        ...
|
||||
95
fusionagi/agents/critic.py
Normal file
95
fusionagi/agents/critic.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Critic / Evaluator agent: evaluates task outcome, error analysis, suggested improvements."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class CriticAgent(BaseAgent):
    """Critic / evaluator: scores task outcomes and suggests fixes.

    Listens for 'evaluate_request' and replies with 'evaluation_ready'
    containing a success flag, score, error analysis, and suggestions.
    """

    def __init__(
        self,
        identity: str = "critic",
        adapter: LLMAdapter | None = None,
    ) -> None:
        """Set up identity/role; *adapter* enables LLM-backed evaluation."""
        super().__init__(
            identity=identity,
            role="Critic",
            objective="Evaluate outcomes and suggest improvements",
            memory_access=True,
            tool_permissions=[],
        )
        self._adapter = adapter

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On evaluate_request, reply evaluation_ready with the verdict."""
        if envelope.message.intent != "evaluate_request":
            return None
        logger.info(
            "Critic handle_message",
            extra={"recipient": self.identity, "intent": envelope.message.intent},
        )
        payload = envelope.message.payload
        outcome = payload.get("outcome", "unknown")
        if self._adapter:
            evaluation = self._evaluate_with_llm(
                outcome,
                payload.get("trace", []),
                payload.get("plan"),
            )
        else:
            # No LLM available: grade purely on the reported outcome.
            evaluation = {
                "success": outcome == "completed",
                "score": 1.0 if outcome == "completed" else 0.0,
                "error_analysis": [],
                "suggestions": ["Enable LLM for detailed evaluation"],
            }
        logger.info(
            "Critic response",
            extra={"recipient": self.identity, "response_intent": "evaluation_ready"},
        )
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=envelope.message.sender,
                intent="evaluation_ready",
                payload={"evaluation": evaluation},
            ),
            task_id=envelope.task_id,
            correlation_id=envelope.correlation_id,
        )

    def _evaluate_with_llm(
        self,
        outcome: str,
        trace: list[dict[str, Any]],
        plan: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Ask the adapter for a JSON verdict; fall back on parse errors."""
        context = f"Outcome: {outcome}\nTrace (last 5): {json.dumps(trace[-5:], default=str)}\n"
        if plan:
            context += f"Plan: {json.dumps(plan.get('steps', [])[:5], default=str)}"
        messages = [
            {"role": "system", "content": "You evaluate task execution. Output JSON: {\"success\": bool, \"score\": 0-1, \"error_analysis\": [], \"suggestions\": []}. Output only JSON."},
            {"role": "user", "content": context},
        ]
        try:
            raw = self._adapter.complete(messages)
            # Strip markdown code fences the model may wrap around JSON.
            for fence in ("```json", "```"):
                if raw.strip().startswith(fence):
                    raw = raw.strip()[len(fence):].strip()
                    if raw.endswith("```"):
                        raw = raw[:-3].strip()
            return json.loads(raw)
        except Exception:
            logger.exception("Critic evaluation parse failed, using fallback")
            return {
                "success": outcome == "completed",
                "score": 0.5,
                "error_analysis": ["Evaluation parse failed"],
                "suggestions": [],
            }
|
||||
236
fusionagi/agents/executor.py
Normal file
236
fusionagi/agents/executor.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Executor agent: receives execute_step, invokes tool via safe runner, returns step_done/step_failed."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi.schemas.plan import Plan
|
||||
from fusionagi.planning import get_step
|
||||
from fusionagi.tools.registry import ToolRegistry
|
||||
from fusionagi.tools.runner import run_tool
|
||||
from fusionagi._logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.state_manager import StateManager
|
||||
from fusionagi.governance.guardrails import Guardrails
|
||||
from fusionagi.governance.rate_limiter import RateLimiter
|
||||
from fusionagi.governance.access_control import AccessControl
|
||||
from fusionagi.governance.override import OverrideHooks
|
||||
from fusionagi.memory.episodic import EpisodicMemory
|
||||
|
||||
|
||||
class ExecutorAgent(BaseAgent):
    """
    Executes steps: maps a plan step to a tool call, runs it via the safe
    runner, and emits step_done or step_failed back to the sender.

    Supports full governance integration (every collaborator is optional):
    - Guardrails: pre/post checks for tool invocations
    - RateLimiter: limits tool invocation rate per agent/tool pair
    - AccessControl: policy-based tool access control
    - OverrideHooks: human-in-the-loop gate for high-risk operations
    - EpisodicMemory: records step outcomes for learning

    Fix over the original: failure envelopes now propagate the request's
    correlation_id, matching the success (step_done) path. Previously only
    step_done carried it, so callers could not correlate failures.
    """

    def __init__(
        self,
        identity: str = "executor",
        registry: ToolRegistry | None = None,
        state_manager: StateManager | None = None,
        guardrails: Guardrails | None = None,
        rate_limiter: RateLimiter | None = None,
        access_control: AccessControl | None = None,
        override_hooks: OverrideHooks | None = None,
        episodic_memory: EpisodicMemory | None = None,
    ) -> None:
        """
        Initialize the executor agent.

        Args:
            identity: Agent identifier.
            registry: Tool registry for tool lookup (a fresh one if None).
            state_manager: State manager for trace storage.
            guardrails: Guardrails for pre/post checks.
            rate_limiter: Rate limiter for tool invocation throttling.
            access_control: Access control for policy-based tool access.
            override_hooks: Override hooks for human-in-the-loop.
            episodic_memory: Episodic memory for recording step outcomes.
        """
        super().__init__(
            identity=identity,
            role="Executor",
            objective="Execute plan steps via tools",
            memory_access=False,
            # "*" defers filtering to the registry / access-control layers below.
            tool_permissions=["*"],
        )
        self._registry = registry or ToolRegistry()
        self._state = state_manager
        self._guardrails = guardrails
        self._rate_limiter = rate_limiter
        self._access_control = access_control
        self._override_hooks = override_hooks
        self._episodic_memory = episodic_memory

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On execute_step, run tool and return step_done or step_failed."""
        if envelope.message.intent != "execute_step":
            return None
        logger.info(
            "Executor handle_message",
            extra={"recipient": self.identity, "intent": envelope.message.intent},
        )
        payload = envelope.message.payload
        task_id = envelope.task_id
        # Propagated on every response — including failures (the fix).
        corr_id = envelope.correlation_id
        sender = envelope.message.sender
        step_id = payload.get("step_id")
        plan_dict = payload.get("plan")
        if not step_id or not plan_dict:
            return self._fail(
                task_id, sender, step_id or "?", "missing step_id or plan",
                correlation_id=corr_id,
            )
        plan = Plan.from_dict(plan_dict)
        step = get_step(plan, step_id)
        if not step:
            return self._fail(task_id, sender, step_id, "step not found", correlation_id=corr_id)
        # Step-level tool info wins; payload values are a fallback.
        tool_name = step.tool_name or payload.get("tool_name")
        tool_args = step.tool_args or payload.get("tool_args", {})
        if not tool_name:
            return self._fail(task_id, sender, step_id, "no tool_name", correlation_id=corr_id)
        tool = self._registry.get(tool_name)
        if not tool:
            return self._fail(
                task_id, sender, step_id, f"tool not found: {tool_name}",
                correlation_id=corr_id,
            )

        # Check tool registry permissions
        if not self._registry.allowed_for(tool_name, self.tool_permissions):
            return self._fail(task_id, sender, step_id, "permission denied", correlation_id=corr_id)

        # Check access control policy
        if self._access_control is not None:
            if not self._access_control.allowed(self.identity, tool_name, task_id):
                logger.info(
                    "Executor access_control denied",
                    extra={"tool_name": tool_name, "agent_id": self.identity, "task_id": task_id},
                )
                return self._fail(
                    task_id, sender, step_id, "access control denied",
                    correlation_id=corr_id,
                )

        # Check rate limiter (keyed per agent/tool pair)
        if self._rate_limiter is not None:
            rate_key = f"{self.identity}:{tool_name}"
            allowed, reason = self._rate_limiter.allow(rate_key)
            if not allowed:
                logger.info(
                    "Executor rate_limiter denied",
                    extra={"tool_name": tool_name, "key": rate_key, "reason": reason},
                )
                return self._fail(task_id, sender, step_id, reason, correlation_id=corr_id)

        # Check guardrails pre-check; may sanitize the arguments in place.
        if self._guardrails is not None:
            pre_result = self._guardrails.pre_check(tool_name, tool_args)
            logger.info(
                "Executor guardrail pre_check",
                extra={"tool_name": tool_name, "allowed": pre_result.allowed},
            )
            if not pre_result.allowed:
                return self._fail(
                    task_id, sender, step_id,
                    pre_result.error_message or "Guardrails pre-check failed",
                    correlation_id=corr_id,
                )
            if pre_result.sanitized_args is not None:
                tool_args = pre_result.sanitized_args

        # Check override hooks for high-risk operations.
        # NOTE(review): tool.manufacturing presumably flags side-effectful /
        # high-risk tools — confirm against the Tool schema.
        if self._override_hooks is not None and tool.manufacturing:
            proceed = self._override_hooks.fire(
                "tool_execution",
                {"tool_name": tool_name, "args": tool_args, "task_id": task_id, "step_id": step_id},
            )
            if not proceed:
                logger.info(
                    "Executor override_hooks blocked",
                    extra={"tool_name": tool_name, "step_id": step_id},
                )
                return self._fail(
                    task_id, sender, step_id,
                    "Override hook blocked execution",
                    correlation_id=corr_id,
                )

        # Execute the tool via the safe runner
        result, log_entry = run_tool(tool, tool_args)
        logger.info(
            "Executor tool run",
            extra={"tool_name": tool_name, "step_id": step_id, "error": log_entry.get("error")},
        )

        # Guardrails post-check: only when the run itself succeeded; a failed
        # post-check is recorded as an error so it follows the failure path.
        if self._guardrails is not None and not log_entry.get("error"):
            post_ok, post_reason = self._guardrails.post_check(tool_name, result)
            if not post_ok:
                log_entry["error"] = f"Post-check failed: {post_reason}"
                log_entry["post_check_failed"] = True
                logger.info(
                    "Executor guardrail post_check failed",
                    extra={"tool_name": tool_name, "reason": post_reason},
                )

        # Record trace in state manager
        if self._state:
            self._state.append_trace(task_id or "", log_entry)

        # Record step outcome in episodic memory for learning
        if self._episodic_memory:
            self._episodic_memory.append(
                task_id=task_id or "",
                event={
                    "type": "step_execution",
                    "step_id": step_id,
                    "tool_name": tool_name,
                    "success": not log_entry.get("error"),
                    "duration_seconds": log_entry.get("duration_seconds"),
                },
            )

        if log_entry.get("error"):
            return self._fail(
                task_id, sender, step_id,
                log_entry["error"],
                log_entry=log_entry,
                correlation_id=corr_id,
            )
        logger.info(
            "Executor response",
            extra={"recipient": self.identity, "response_intent": "step_done"},
        )
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=sender,
                intent="step_done",
                payload={
                    "step_id": step_id,
                    "result": result,
                    "log_entry": log_entry,
                },
            ),
            task_id=task_id,
            correlation_id=corr_id,
        )

    def _fail(
        self,
        task_id: str | None,
        recipient: str,
        step_id: str,
        error: str,
        log_entry: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ) -> AgentMessageEnvelope:
        """Build a step_failed envelope for the given step and error.

        Args:
            task_id: Task the step belongs to (may be None).
            recipient: Original sender to reply to.
            step_id: Failed step identifier.
            error: Human-readable failure reason.
            log_entry: Optional runner log entry for diagnostics.
            correlation_id: Correlation id of the request, forwarded when
                present so failures correlate just like step_done responses.
        """
        # Only pass correlation_id through when we actually have one, so the
        # envelope's own default (whatever it is) applies otherwise.
        env_kwargs: dict[str, Any] = {"task_id": task_id}
        if correlation_id is not None:
            env_kwargs["correlation_id"] = correlation_id
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=recipient,
                intent="step_failed",
                payload={
                    "step_id": step_id,
                    "error": error,
                    "log_entry": log_entry or {},
                },
            ),
            **env_kwargs,
        )
|
||||
232
fusionagi/agents/head_agent.py
Normal file
232
fusionagi/agents/head_agent.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Dvādaśa head agent base: structured output via LLM or native reasoning."""
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim, HeadRisk
|
||||
from fusionagi.schemas.grounding import Citation
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@runtime_checkable
class ReasoningProvider(Protocol):
    """Protocol for native reasoning: produce HeadOutput without external APIs.

    Marked @runtime_checkable so implementations can be validated with
    isinstance() at wiring time; note that runtime_checkable only verifies
    method presence, not the signature.
    """

    def produce_head_output(self, head_id: HeadId, prompt: str) -> HeadOutput:
        """Produce structured HeadOutput for the given head and prompt."""
        ...
||||
|
||||
|
||||
def _head_output_json_schema() -> dict[str, Any]:
    """Build the JSON schema used for LLM structured generation of HeadOutput.

    The schema mirrors the HeadOutput model: a required head_id and summary,
    plus optional claims (each with evidence citations and assumptions),
    risks, questions, recommended actions and tone guidance.
    """
    citation_schema = {
        "type": "object",
        "properties": {
            "source_id": {"type": "string"},
            "excerpt": {"type": "string"},
            "confidence": {"type": "number"},
        },
    }
    claim_schema = {
        "type": "object",
        "properties": {
            "claim_text": {"type": "string"},
            "confidence": {"type": "number", "minimum": 0, "maximum": 1},
            "evidence": {"type": "array", "items": citation_schema},
            "assumptions": {"type": "array", "items": {"type": "string"}},
        },
    }
    risk_schema = {
        "type": "object",
        "properties": {
            "description": {"type": "string"},
            "severity": {"type": "string"},
        },
    }
    return {
        "type": "object",
        "required": ["head_id", "summary"],
        "properties": {
            "head_id": {
                "type": "string",
                # Witness is excluded: it arbitrates head outputs rather than
                # producing one itself.
                "enum": [h.value for h in HeadId if h != HeadId.WITNESS],
            },
            "summary": {"type": "string"},
            "claims": {"type": "array", "items": claim_schema},
            "risks": {"type": "array", "items": risk_schema},
            "questions": {"type": "array", "items": {"type": "string"}},
            "recommended_actions": {"type": "array", "items": {"type": "string"}},
            "tone_guidance": {"type": "string"},
        },
    }
|
||||
|
||||
|
||||
class HeadAgent(BaseAgent):
    """
    Dvādaśa head agent: produces structured HeadOutput from a user prompt.

    Resolution order in _produce_output:
    1. Native reasoning provider (no external APIs), if configured.
    2. LLM adapter via complete_structured with the HeadOutput JSON schema.
    3. Deterministic fallback output when both are absent or fail.

    Fix over the original: the adapter call is now guarded — an adapter
    exception previously escaped _produce_output and crashed handle_message,
    even though the class already has a fallback path for every other
    failure mode (provider error, missing adapter, invalid/unparseable output).
    """

    def __init__(
        self,
        head_id: HeadId,
        role: str,
        objective: str,
        system_prompt: str,
        adapter: LLMAdapter | None = None,
        tool_permissions: list[str] | None = None,
        reasoning_provider: "ReasoningProvider | None" = None,
    ) -> None:
        """
        Initialize a content head agent.

        Args:
            head_id: Which of the content heads this agent embodies
                (Witness is rejected; use WitnessAgent for that).
            role: Human-readable role name.
            objective: One-line objective for the head.
            system_prompt: System prompt used for LLM-backed reasoning.
            adapter: Optional LLM adapter for structured generation.
            tool_permissions: Tool names this head may invoke (default none).
            reasoning_provider: Optional native (API-free) reasoning provider,
                preferred over the adapter when present.

        Raises:
            ValueError: If head_id is HeadId.WITNESS.
        """
        if head_id == HeadId.WITNESS:
            raise ValueError("HeadAgent is for content heads only; use WitnessAgent for Witness")
        super().__init__(
            identity=head_id.value,
            role=role,
            objective=objective,
            memory_access=True,
            tool_permissions=tool_permissions or [],
        )
        self._head_id = head_id
        self._system_prompt = system_prompt
        self._adapter = adapter
        self._reasoning_provider = reasoning_provider

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On head_request, produce HeadOutput and return a head_output envelope.

        Returns None for any other intent; returns a head_failed response
        when no output could be produced at all.
        """
        if envelope.message.intent != "head_request":
            return None

        payload = envelope.message.payload or {}
        user_prompt = payload.get("prompt", "")

        logger.info(
            "HeadAgent handle_message",
            extra={"head_id": self._head_id.value, "intent": envelope.message.intent},
        )

        output = self._produce_output(user_prompt)
        if output is None:
            return envelope.create_response(
                "head_failed",
                payload={"error": "Failed to produce head output", "head_id": self._head_id.value},
            )

        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=envelope.message.sender,
                intent="head_output",
                payload={"head_output": output.model_dump()},
            ),
            task_id=envelope.task_id,
            correlation_id=envelope.correlation_id,
        )

    def _produce_output(self, user_prompt: str) -> HeadOutput | None:
        """Produce HeadOutput via native reasoning or the LLM adapter.

        Returns None only when adapter output was produced but could not be
        parsed into a HeadOutput (see _parse_output).
        """
        # Prefer native reasoning when available (no external APIs)
        if self._reasoning_provider is not None:
            try:
                return self._reasoning_provider.produce_head_output(
                    self._head_id, user_prompt or "(No prompt provided)"
                )
            except Exception as e:
                logger.warning(
                    "Native reasoning failed, falling back",
                    extra={"head_id": self._head_id.value, "error": str(e)},
                )

        if not self._adapter:
            return self._fallback_output(user_prompt)

        messages = [
            {"role": "system", "content": self._system_prompt},
            {"role": "user", "content": user_prompt or "(No prompt provided)"},
        ]

        # Guard the adapter call: network/LLM failures degrade to the
        # deterministic fallback instead of crashing handle_message.
        try:
            raw = self._adapter.complete_structured(
                messages,
                schema=_head_output_json_schema(),
                temperature=0.3,
            )
        except Exception as e:
            logger.warning(
                "HeadAgent adapter call failed, using fallback",
                extra={"head_id": self._head_id.value, "error": str(e)},
            )
            return self._fallback_output(user_prompt)

        if not isinstance(raw, dict):
            logger.warning(
                "HeadAgent structured output invalid",
                extra={"head_id": self._head_id.value, "raw_type": type(raw).__name__},
            )
            return self._fallback_output(user_prompt)

        return self._parse_output(raw)

    def _parse_output(self, raw: dict[str, Any]) -> HeadOutput | None:
        """Parse a raw schema-shaped dict into HeadOutput; None on failure.

        Missing fields get conservative defaults (empty lists, 0.5 confidence,
        "No summary") so partially-filled LLM output still parses.
        """
        try:
            claims = []
            for c in raw.get("claims", []):
                evidence = [
                    Citation(
                        source_id=e.get("source_id", ""),
                        excerpt=e.get("excerpt", ""),
                        confidence=e.get("confidence", 1.0),
                    )
                    for e in c.get("evidence", [])
                ]
                claims.append(
                    HeadClaim(
                        claim_text=c.get("claim_text", ""),
                        confidence=float(c.get("confidence", 0.5)),
                        evidence=evidence,
                        assumptions=c.get("assumptions", []),
                    )
                )

            risks = [
                HeadRisk(
                    description=r.get("description", ""),
                    severity=r.get("severity", "medium"),
                )
                for r in raw.get("risks", [])
            ]

            return HeadOutput(
                head_id=self._head_id,
                summary=raw.get("summary", "No summary"),
                claims=claims,
                risks=risks,
                questions=raw.get("questions", []),
                recommended_actions=raw.get("recommended_actions", []),
                tone_guidance=raw.get("tone_guidance", ""),
            )
        except Exception as e:
            logger.exception(
                "HeadAgent parse_output failed",
                extra={"head_id": self._head_id.value, "error": str(e)},
            )
            return None

    def _fallback_output(self, user_prompt: str) -> HeadOutput:
        """Deterministic output when no reasoning provider/adapter is usable.

        Carries zero-confidence content plus a high-severity risk so the
        Witness can see this head produced no real analysis.
        """
        return HeadOutput(
            head_id=self._head_id,
            summary=f"{self.role} head: Unable to produce structured analysis for this prompt.",
            claims=[
                HeadClaim(
                    claim_text="Analysis requires reasoning provider or LLM adapter.",
                    confidence=0.0,
                    evidence=[],
                    assumptions=[],
                ),
            ],
            risks=[HeadRisk(description="No reasoning provider or adapter configured", severity="high")],
            questions=[],
            recommended_actions=["Configure NativeReasoningProvider or an LLM adapter for this head"],
            tone_guidance="",
        )
|
||||
104
fusionagi/agents/heads/__init__.py
Normal file
104
fusionagi/agents/heads/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Dvādaśa content head agents: Logic, Research, Systems, Strategy, etc."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.agents.head_agent import HeadAgent
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.reasoning.native import NativeReasoningProvider
|
||||
from fusionagi.schemas.head import HeadId
|
||||
from fusionagi.prompts.heads import get_head_prompt
|
||||
|
||||
|
||||
def create_head_agent(
    head_id: HeadId,
    adapter: LLMAdapter | None = None,
    tool_permissions: list[str] | None = None,
    reasoning_provider: NativeReasoningProvider | None = None,
    use_native_reasoning: bool = True,
) -> HeadAgent:
    """Create a HeadAgent for the given head_id.

    When adapter is None and use_native_reasoning is True, uses
    NativeReasoningProvider for independent symbolic reasoning (no external
    LLM calls).

    Raises:
        ValueError: If head_id is HeadId.WITNESS.
    """
    if head_id == HeadId.WITNESS:
        raise ValueError("Use WitnessAgent for Witness; HeadAgent is for content heads only")

    # Instantiate a native provider only when the caller supplied neither a
    # provider nor an adapter, and did not opt out of native reasoning.
    active_provider = reasoning_provider
    if active_provider is None and adapter is None and use_native_reasoning:
        active_provider = NativeReasoningProvider()

    role_map: dict[HeadId, tuple[str, str]] = {
        HeadId.LOGIC: ("Logic", "Correctness, contradictions, formal checks"),
        HeadId.RESEARCH: ("Research", "Retrieval, source quality, citations"),
        HeadId.SYSTEMS: ("Systems", "Architecture, dependencies, scalability"),
        HeadId.STRATEGY: ("Strategy", "Roadmap, prioritization, tradeoffs"),
        HeadId.PRODUCT: ("Product/UX", "Interaction design, user flows"),
        HeadId.SECURITY: ("Security", "Threats, auth, secrets, abuse vectors"),
        HeadId.SAFETY: ("Safety/Ethics", "Policy alignment, harmful content prevention"),
        HeadId.RELIABILITY: ("Reliability", "SLOs, failover, load testing, observability"),
        HeadId.COST: ("Cost/Performance", "Token budgets, caching, model routing"),
        HeadId.DATA: ("Data/Memory", "Schemas, privacy, retention, personalization"),
        HeadId.DEVEX: ("DevEx", "CI/CD, testing strategy, local tooling"),
    }
    # Unknown heads fall back to a generic role derived from the enum value.
    default_pair = (head_id.value.title(), "Provide analysis from your perspective.")
    role, objective = role_map.get(head_id, default_pair)

    return HeadAgent(
        head_id=head_id,
        role=role,
        objective=objective,
        system_prompt=get_head_prompt(head_id),
        adapter=adapter,
        tool_permissions=tool_permissions,
        reasoning_provider=active_provider,
    )
|
||||
|
||||
|
||||
# Heads that may call tools (Research, Systems, Security, Data)
TOOL_ENABLED_HEADS: list[HeadId] = [
    HeadId.RESEARCH,
    HeadId.SYSTEMS,
    HeadId.SECURITY,
    HeadId.DATA,
]

# Default tool-permission sets for the tool-enabled heads above; the key set
# should stay in sync with TOOL_ENABLED_HEADS. Used by
# create_all_content_heads when no per-head override is supplied.
DEFAULT_TOOL_PERMISSIONS: dict[HeadId, list[str]] = {
    HeadId.RESEARCH: ["search", "docs"],
    HeadId.SYSTEMS: ["search", "db"],
    HeadId.SECURITY: ["search", "code_runner"],
    HeadId.DATA: ["db", "docs"],
}
|
||||
|
||||
|
||||
def create_all_content_heads(
    adapter: LLMAdapter | None = None,
    tool_permissions_by_head: dict[HeadId, list[str]] | None = None,
    reasoning_provider: NativeReasoningProvider | None = None,
    use_native_reasoning: bool = True,
) -> dict[HeadId, HeadAgent]:
    """Create all 11 content head agents. Tool-enabled heads get default permissions.

    When adapter is None, uses native reasoning by default (no external LLM).
    """
    overrides = tool_permissions_by_head or {}
    heads: dict[HeadId, HeadAgent] = {}
    for hid in HeadId:
        if hid == HeadId.WITNESS:
            # Witness is a meta-controller, not a content head.
            continue
        # Caller override wins; otherwise the module default (None for
        # non-tool-enabled heads).
        permissions = overrides.get(hid) or DEFAULT_TOOL_PERMISSIONS.get(hid)
        heads[hid] = create_head_agent(
            hid,
            adapter,
            permissions,
            reasoning_provider=reasoning_provider,
            use_native_reasoning=use_native_reasoning,
        )
    return heads
|
||||
|
||||
|
||||
# Public factory API of fusionagi.agents.heads.
# NOTE(review): TOOL_ENABLED_HEADS and DEFAULT_TOOL_PERMISSIONS are defined
# here but not exported — confirm that is intentional.
__all__ = [
    "create_head_agent",
    "create_all_content_heads",
]
|
||||
112
fusionagi/agents/planner.py
Normal file
112
fusionagi/agents/planner.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Planner agent: decomposes goals into plan graph; uses LLM adapter when provided."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
|
||||
PLAN_REQUEST_SYSTEM = """You are a planner. Given a goal and optional constraints, output a JSON object with this exact structure:
|
||||
{"steps": [{"id": "step_1", "description": "...", "dependencies": []}, ...], "fallback_paths": []}
|
||||
Each step has: id (string), description (string), dependencies (list of step ids that must complete first).
|
||||
Output only valid JSON, no markdown or extra text."""
|
||||
|
||||
|
||||
class PlannerAgent(BaseAgent):
    """Planner: responds to plan_request with a plan; uses the LLM adapter
    when set, else a fixed default plan.

    Fix over the original: the default plan is now deep-copied before being
    returned. _get_plan previously returned the shared class attribute
    directly, so any caller mutating the returned plan (e.g. annotating
    steps) would corrupt the default for every subsequent task.
    """

    # Shared template only — always deep-copied before being handed out.
    DEFAULT_PLAN = {
        "steps": [
            {"id": "step_1", "description": "Analyze goal", "dependencies": []},
            {"id": "step_2", "description": "Execute primary action", "dependencies": ["step_1"]},
            {"id": "step_3", "description": "Verify result", "dependencies": ["step_2"]},
        ],
        "fallback_paths": [],
    }

    def __init__(
        self,
        identity: str = "planner",
        adapter: LLMAdapter | None = None,
    ) -> None:
        """
        Initialize the planner agent.

        Args:
            identity: Agent identifier.
            adapter: Optional LLM adapter; without it the default plan is used.
        """
        super().__init__(
            identity=identity,
            role="Planner",
            objective="Decompose goals into executable steps",
            memory_access=True,
            tool_permissions=[],
        )
        self._adapter = adapter

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On plan_request, return plan_ready with plan from adapter or default."""
        if envelope.message.intent != "plan_request":
            return None
        logger.info(
            "Planner handle_message",
            extra={"recipient": self.identity, "intent": envelope.message.intent},
        )
        goal = envelope.message.payload.get("goal", "")
        constraints = envelope.message.payload.get("constraints", [])
        plan_dict = self._get_plan(goal, constraints)
        logger.info(
            "Planner response",
            extra={"recipient": self.identity, "response_intent": "plan_ready"},
        )
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=envelope.message.sender,
                intent="plan_ready",
                payload={"plan": plan_dict},
            ),
            task_id=envelope.task_id,
            correlation_id=envelope.correlation_id,
        )

    def _get_plan(self, goal: str, constraints: list[str]) -> dict[str, Any]:
        """Produce a plan dict: use the adapter if available and parsing
        succeeds, else a fresh copy of the default plan.
        """
        # Local import: only the fallback path needs it.
        from copy import deepcopy

        if not self._adapter or not goal:
            # Deep copy so callers can mutate the plan without corrupting
            # the shared DEFAULT_PLAN class attribute.
            return deepcopy(self.DEFAULT_PLAN)
        user_content = f"Goal: {goal}\n"
        if constraints:
            user_content += "Constraints: " + ", ".join(constraints) + "\n"
        user_content += "Output the plan as JSON only."
        messages = [
            {"role": "system", "content": PLAN_REQUEST_SYSTEM},
            {"role": "user", "content": user_content},
        ]
        try:
            raw = self._adapter.complete(messages)
            plan_dict = self._parse_plan_response(raw)
            # Require at least one step; an empty plan is treated as a failure.
            if plan_dict and plan_dict.get("steps"):
                return plan_dict
        except Exception:
            logger.exception(
                "Planner adapter or parse failed, using default plan",
                extra={"intent": "plan_request"},
            )
        return deepcopy(self.DEFAULT_PLAN)

    def _parse_plan_response(self, raw: str) -> dict[str, Any] | None:
        """Extract a JSON plan from the raw response (handles code fences).

        Returns None when no JSON object can be decoded.
        """
        raw = raw.strip()
        # Strip a leading ```json / ``` fence and a trailing ``` if present.
        for start in ("```json", "```"):
            if raw.startswith(start):
                raw = raw[len(start) :].strip()
        if raw.endswith("```"):
            raw = raw[:-3].strip()
        # First try the outermost {...} span, then the whole string.
        match = re.search(r"\{[\s\S]*\}", raw)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError as e:
                logger.debug("Planner JSON parse failed (match)", extra={"error": str(e)})
        try:
            return json.loads(raw)
        except json.JSONDecodeError as e:
            logger.debug("Planner JSON parse failed (raw)", extra={"error": str(e)})
        return None
|
||||
226
fusionagi/agents/reasoner.py
Normal file
226
fusionagi/agents/reasoner.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Reasoner agent: reasons over step/subgoal + context; outputs recommendation via CoT.
|
||||
|
||||
The Reasoner agent:
|
||||
- Processes reason_request messages
|
||||
- Uses Chain-of-Thought or Tree-of-Thought reasoning
|
||||
- Integrates with WorkingMemory for context
|
||||
- Records reasoning traces to EpisodicMemory
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi.reasoning import run_chain_of_thought
|
||||
from fusionagi._logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.memory.working import WorkingMemory
|
||||
from fusionagi.memory.episodic import EpisodicMemory
|
||||
|
||||
|
||||
class ReasonerAgent(BaseAgent):
    """
    Reasoner agent: runs Chain-of-Thought reasoning and returns recommendations.

    Features:
    - LLM-powered reasoning via CoT
    - WorkingMemory integration for context enrichment
    - EpisodicMemory integration for trace recording
    - Confidence scoring

    Fix over the original: _calculate_confidence had an unreachable
    ``step_count == 0`` branch (``not trace`` already returns 0.5 for an
    empty trace, so the contradictory 0.3 return could never fire); the dead
    branch is removed with no behavior change.
    """

    def __init__(
        self,
        identity: str = "reasoner",
        adapter: LLMAdapter | None = None,
        working_memory: WorkingMemory | None = None,
        episodic_memory: EpisodicMemory | None = None,
    ) -> None:
        """
        Initialize the Reasoner agent.

        Args:
            identity: Agent identifier.
            adapter: LLM adapter for reasoning.
            working_memory: Working memory for context retrieval.
            episodic_memory: Episodic memory for trace recording.
        """
        super().__init__(
            identity=identity,
            role="Reasoner",
            objective="Reason over steps and recommend next actions",
            memory_access=True,
            tool_permissions=[],
        )
        self._adapter = adapter
        self._working_memory = working_memory
        self._episodic_memory = episodic_memory

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On reason_request, run CoT and return recommendation_ready."""
        if envelope.message.intent != "reason_request":
            return None

        logger.info(
            "Reasoner handle_message",
            extra={"recipient": self.identity, "intent": envelope.message.intent},
        )

        payload = envelope.message.payload
        task_id = envelope.task_id or ""
        step_id = payload.get("step_id")
        subgoal = payload.get("subgoal", "")
        context = payload.get("context", "")

        # Enrich context with working memory if available
        enriched_context = self._enrich_context(task_id, context)

        query = subgoal or f"Consider step: {step_id}. What should we do next?"

        if not self._adapter:
            return self._respond_without_llm(envelope, step_id)

        # Run chain-of-thought reasoning
        response, trace = run_chain_of_thought(
            self._adapter,
            query,
            context=enriched_context or None,
        )

        # Calculate confidence based on trace quality
        confidence = self._calculate_confidence(trace)

        # Store reasoning in working memory (response truncated to 500 chars)
        if self._working_memory and task_id:
            self._working_memory.append(
                task_id,
                "reasoning_history",
                {
                    "step_id": step_id,
                    "query": query,
                    "response": response[:500] if response else "",
                    "confidence": confidence,
                },
            )

        # Record lightweight stats (not the full text) to episodic memory
        if self._episodic_memory and task_id:
            self._episodic_memory.append(
                task_id=task_id,
                event={
                    "type": "reasoning",
                    "step_id": step_id,
                    "query": query,
                    "response_length": len(response) if response else 0,
                    "trace_length": len(trace),
                    "confidence": confidence,
                },
                event_type="reasoning_complete",
            )

        logger.info(
            "Reasoner response",
            extra={
                "recipient": self.identity,
                "response_intent": "recommendation_ready",
                "confidence": confidence,
            },
        )

        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=envelope.message.sender,
                intent="recommendation_ready",
                payload={
                    "step_id": step_id,
                    "recommendation": response,
                    "trace": trace,
                    "confidence": confidence,
                },
                confidence=confidence,
            ),
            task_id=task_id,
            correlation_id=envelope.correlation_id,
        )

    def _enrich_context(self, task_id: str, base_context: str) -> str:
        """Enrich context with working-memory data (summary + recent reasoning)."""
        if not self._working_memory or not task_id:
            return base_context

        # Get context summary from working memory
        context_summary = self._working_memory.get_context_summary(task_id, max_items=5)

        if not context_summary:
            return base_context

        # Last three reasoning entries, oldest first
        reasoning_history = self._working_memory.get_list(task_id, "reasoning_history")
        recent_reasoning = reasoning_history[-3:] if reasoning_history else []

        enriched_parts = [base_context] if base_context else []

        if context_summary:
            # Truncate to keep the prompt bounded
            enriched_parts.append(f"\nWorking memory context: {json.dumps(context_summary, default=str)[:500]}")

        if recent_reasoning:
            recent_summaries = [
                f"- Step {r.get('step_id', '?')}: {r.get('response', '')[:100]}"
                for r in recent_reasoning
            ]
            enriched_parts.append("\nRecent reasoning:\n" + "\n".join(recent_summaries))

        return "\n".join(enriched_parts)

    def _calculate_confidence(self, trace: list[dict[str, Any]]) -> float:
        """Calculate a heuristic confidence score from the reasoning trace.

        More reasoning steps -> more thorough -> higher confidence, with
        diminishing returns. An empty or missing trace yields the neutral 0.5.
        """
        if not trace:
            return 0.5  # Default confidence without trace

        step_count = len(trace)
        if step_count == 1:
            return 0.5
        if step_count == 2:
            return 0.7
        if step_count <= 4:
            return 0.8
        return 0.9

    def _respond_without_llm(
        self,
        envelope: AgentMessageEnvelope,
        step_id: str | None,
    ) -> AgentMessageEnvelope:
        """Generate a neutral-confidence response when no LLM is available."""
        logger.info(
            "Reasoner response (no adapter)",
            extra={"recipient": self.identity, "response_intent": "recommendation_ready"},
        )
        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=envelope.message.sender,
                intent="recommendation_ready",
                payload={
                    "step_id": step_id,
                    "recommendation": "Proceed with execution (no LLM available for reasoning).",
                    "trace": [],
                    "confidence": 0.5,
                },
                confidence=0.5,
            ),
            task_id=envelope.task_id,
            correlation_id=envelope.correlation_id,
        )
|
||||
219
fusionagi/agents/witness_agent.py
Normal file
219
fusionagi/agents/witness_agent.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Witness agent: meta-controller that arbitrates head outputs and produces final response."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
|
||||
# Approx 4 chars/token; limit context to ~6k tokens (~24k chars) to avoid overflow
|
||||
DEFAULT_MAX_CONTEXT_CHARS = 24_000
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput
|
||||
from fusionagi.schemas.witness import (
|
||||
AgreementMap,
|
||||
TransparencyReport,
|
||||
FinalResponse,
|
||||
)
|
||||
from fusionagi.multi_agent.consensus_engine import run_consensus
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
WITNESS_COMPOSE_SYSTEM = """You are the Witness meta-controller in a 12-headed multi-agent system.
|
||||
You receive structured outputs from specialist heads (Logic, Research, Strategy, Security, etc.).
|
||||
Your job: produce a clear, coherent final answer that synthesizes the head contributions.
|
||||
Use the agreed claims. Acknowledge disputes if any. Be concise and actionable.
|
||||
Output only the final narrative text, no JSON or meta-commentary."""
|
||||
|
||||
|
||||
class WitnessAgent(BaseAgent):
    """
    Witness: consumes HeadOutput from content heads, runs consensus, composes FinalResponse.

    Flow: handle_message -> _produce_response (consensus + transparency) ->
    _compose_final_answer (LLM via adapter, or plain concatenation fallback).
    """

    def __init__(
        self,
        adapter: LLMAdapter | None = None,
        max_context_chars: int = DEFAULT_MAX_CONTEXT_CHARS,
    ) -> None:
        """
        Args:
            adapter: Optional LLM adapter used to compose the final narrative.
            max_context_chars: Cap on the LLM context built from head outputs.
        """
        super().__init__(
            identity=HeadId.WITNESS.value,
            role="Witness",
            objective="Arbitrate head outputs, resolve conflicts, produce final narrative",
            memory_access=True,
            tool_permissions=[],
        )
        self._adapter = adapter
        self._max_context_chars = max_context_chars

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """On witness_request, produce FinalResponse from head outputs.

        Returns None for any other intent; returns a witness_failed envelope
        when no final response could be produced.
        """
        if envelope.message.intent != "witness_request":
            return None

        payload = envelope.message.payload or {}
        head_outputs_data = payload.get("head_outputs", [])
        user_prompt = payload.get("prompt", "")

        # Accept both serialized dicts and already-validated HeadOutput models;
        # invalid dicts are logged and skipped rather than failing the request.
        head_outputs: list[HeadOutput] = []
        for h in head_outputs_data:
            if isinstance(h, dict):
                try:
                    head_outputs.append(HeadOutput.model_validate(h))
                except Exception as e:
                    logger.warning("Witness: skip invalid HeadOutput", extra={"error": str(e)})
            elif isinstance(h, HeadOutput):
                head_outputs.append(h)

        logger.info(
            "Witness handle_message",
            extra={"head_count": len(head_outputs), "intent": envelope.message.intent},
        )

        response = self._produce_response(head_outputs, user_prompt)
        if response is None:
            return envelope.create_response(
                "witness_failed",
                payload={"error": "Failed to produce final response"},
            )

        return AgentMessageEnvelope(
            message=AgentMessage(
                sender=self.identity,
                recipient=envelope.message.sender,
                intent="witness_output",
                payload={"final_response": response.model_dump()},
            ),
            task_id=envelope.task_id,
            correlation_id=envelope.correlation_id,
        )

    def _produce_response(
        self,
        head_outputs: list[HeadOutput],
        user_prompt: str,
    ) -> FinalResponse | None:
        """Run consensus and compose final answer."""
        agreement_map = run_consensus(head_outputs)

        # Per-head contribution summary (top 3 claims, each capped at 80 chars),
        # reused in both the transparency report and the final response.
        head_contributions: list[dict[str, Any]] = []
        for out in head_outputs:
            key_claims = [c.claim_text[:80] + "..." if len(c.claim_text) > 80 else c.claim_text for c in out.claims[:3]]
            head_contributions.append({
                "head_id": out.head_id.value,
                "summary": out.summary,
                "key_claims": key_claims,
            })

        safety_report = self._build_safety_report(head_outputs)

        transparency = TransparencyReport(
            head_contributions=head_contributions,
            agreement_map=agreement_map,
            safety_report=safety_report,
            confidence_score=agreement_map.confidence_score,
        )

        final_answer = self._compose_final_answer(
            head_outputs=head_outputs,
            agreement_map=agreement_map,
            user_prompt=user_prompt,
        )

        return FinalResponse(
            final_answer=final_answer,
            transparency_report=transparency,
            head_contributions=head_contributions,
            confidence_score=agreement_map.confidence_score,
        )

    def _build_safety_report(self, head_outputs: list[HeadOutput]) -> str:
        """Summarize safety-relevant findings from Safety head and risks.

        Priority: Safety-head summaries win; otherwise up to 5 high/critical
        risks from any head; otherwise an all-clear message.
        """
        safety_summaries = []
        all_risks: list[str] = []
        for out in head_outputs:
            if out.head_id == HeadId.SAFETY and out.summary:
                safety_summaries.append(out.summary)
            for r in out.risks:
                if r.severity in ("high", "critical"):
                    all_risks.append(f"[{out.head_id.value}] {r.description}")
        if safety_summaries:
            return " ".join(safety_summaries)
        if all_risks:
            return "Risks identified: " + "; ".join(all_risks[:5])
        return "No significant safety concerns raised."

    def _truncate(self, text: str, max_len: int) -> str:
        """Truncate text with ellipsis if over max_len."""
        if len(text) <= max_len:
            return text
        return text[: max_len - 3] + "..."

    def _build_compose_context(
        self,
        head_outputs: list[HeadOutput],
        agreement_map: AgreementMap,
        user_prompt: str,
    ) -> str:
        """Build truncated context for LLM to avoid token overflow."""
        max_chars = self._max_context_chars
        # Per-section budgets scale down with max_chars and with head count.
        prompt_limit = min(800, max_chars // 4)
        summary_limit = min(300, max_chars // (len(head_outputs) * 2) if head_outputs else 300)
        claim_limit = min(150, max_chars // 20)

        user_trunc = self._truncate(user_prompt, prompt_limit)
        context = f"User asked: {user_trunc}\n\n"
        context += "Head summaries:\n"
        for out in head_outputs:
            s = self._truncate(out.summary or "", summary_limit)
            context += f"- {out.head_id.value}: {s}\n"
        context += "\nAgreed claims:\n"
        for c in agreement_map.agreed_claims[:10]:
            claim = self._truncate(c.get("claim_text", ""), claim_limit)
            context += f"- {claim} (confidence: {c.get('confidence', 0)})\n"
        if agreement_map.disputed_claims:
            context += "\nDisputed:\n"
            for c in agreement_map.disputed_claims[:5]:
                claim = self._truncate(c.get("claim_text", ""), claim_limit)
                context += f"- {claim}\n"

        # Final hard cap in case the per-section budgets still overflowed.
        if len(context) > max_chars:
            context = context[: max_chars - 20] + "\n...[truncated]"
        return context

    def _compose_final_answer(
        self,
        head_outputs: list[HeadOutput],
        agreement_map: AgreementMap,
        user_prompt: str,
    ) -> str:
        """Compose narrative from head outputs and consensus.

        Uses the LLM adapter when available; an empty result or any adapter
        failure falls back to the non-LLM concatenation so the Witness always
        returns text.
        """
        if not self._adapter:
            return self._fallback_compose(head_outputs, agreement_map)

        context = self._build_compose_context(head_outputs, agreement_map, user_prompt)

        messages = [
            {"role": "system", "content": WITNESS_COMPOSE_SYSTEM},
            {"role": "user", "content": context},
        ]
        try:
            result = self._adapter.complete(messages, temperature=0.3)
            return result.strip() if result else self._fallback_compose(head_outputs, agreement_map)
        except Exception as e:
            logger.exception("Witness compose failed", extra={"error": str(e)})
            return self._fallback_compose(head_outputs, agreement_map)

    def _fallback_compose(
        self,
        head_outputs: list[HeadOutput],
        agreement_map: AgreementMap,
    ) -> str:
        """Simple concatenation when no adapter."""
        # At most 5 head summaries plus up to 5 agreed claims (60 chars each).
        parts = []
        for out in head_outputs[:5]:
            parts.append(f"[{out.head_id.value}] {out.summary}")
        if agreement_map.agreed_claims:
            parts.append("Key points: " + "; ".join(
                c.get("claim_text", "")[:60] for c in agreement_map.agreed_claims[:5]
            ))
        return "\n\n".join(parts) if parts else "No head outputs available."
|
||||
5
fusionagi/api/__init__.py
Normal file
5
fusionagi/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""FusionAGI API: FastAPI gateway for Dvādaśa sessions and prompts."""
|
||||
|
||||
from fusionagi.api.app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
63
fusionagi/api/app.py
Normal file
63
fusionagi/api/app.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""FastAPI application factory for FusionAGI Dvādaśa API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.api.dependencies import SessionStore, default_orchestrator, set_app_state
|
||||
from fusionagi.api.routes import router as api_router
|
||||
|
||||
|
||||
def create_app(
    adapter: Any = None,
    cors_origins: list[str] | None = None,
) -> Any:
    """Create FastAPI app with Dvādaśa routes.

    Args:
        adapter: Optional LLMAdapter for head/Witness LLM calls.
        cors_origins: Optional list of CORS allowed origins (e.g. ["*"] or ["https://example.com"]).
            If None, no CORS middleware is added.

    Returns:
        A configured ``fastapi.FastAPI`` instance (typed ``Any`` so importing
        this module does not require fastapi).

    Raises:
        ImportError: If fastapi is not installed.
    """
    try:
        from fastapi import FastAPI
    except ImportError as e:
        raise ImportError("Install with: pip install fusionagi[api]") from e

    app = FastAPI(
        title="FusionAGI Dvādaśa API",
        description="12-headed multi-agent orchestration API",
        version="0.1.0",
    )
    # Expose the adapter both on app.state (per-app) and as the module-level
    # default used by lazy init in dependencies.ensure_initialized.
    app.state.llm_adapter = adapter
    from fusionagi.api.dependencies import set_default_adapter
    set_default_adapter(adapter)

    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers — consider migrating; behavior kept as-is here.
    @app.on_event("startup")
    async def startup():
        """Initialize orchestrator and session store."""
        # Guard so repeated startup events (e.g. several TestClient contexts)
        # don't rebuild the orchestrator.
        if getattr(app.state, "_dvadasa_ready", False):
            return
        adapter_inner = getattr(app.state, "llm_adapter", None)
        orch, bus = default_orchestrator(adapter_inner)
        store = SessionStore()
        set_app_state(orch, bus, store)
        app.state._dvadasa_ready = True

    app.include_router(api_router, prefix="/v1", tags=["dvadasa"])

    if cors_origins is not None:
        try:
            from fastapi.middleware.cors import CORSMiddleware
            app.add_middleware(
                CORSMiddleware,
                allow_origins=cors_origins,
                allow_methods=["*"],
                allow_headers=["*"],
            )
        except ImportError:
            pass  # CORS optional

    return app
|
||||
|
||||
|
||||
# Default app instance for uvicorn/gunicorn
|
||||
app = create_app()
|
||||
183
fusionagi/api/dependencies.py
Normal file
183
fusionagi/api/dependencies.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""API dependencies: orchestrator, session store, guardrails."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from fusionagi import Orchestrator, EventBus, StateManager
|
||||
from fusionagi.agents import WitnessAgent
|
||||
from fusionagi.agents.heads import create_all_content_heads
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.adapters.native_adapter import NativeAdapter
|
||||
from fusionagi.schemas.head import HeadId
|
||||
from fusionagi.governance import SafetyPipeline, AuditLog
|
||||
|
||||
|
||||
def _get_reasoning_provider() -> Any:
|
||||
"""Return reasoning provider based on SUPER_BIG_BRAIN_ENABLED env."""
|
||||
if os.environ.get("SUPER_BIG_BRAIN_ENABLED", "false").lower() in ("true", "1", "yes"):
|
||||
from fusionagi.core.super_big_brain import SuperBigBrainReasoningProvider
|
||||
from fusionagi.memory import SemanticGraphMemory
|
||||
return SuperBigBrainReasoningProvider(semantic_graph=SemanticGraphMemory())
|
||||
return None
|
||||
|
||||
# App state populated by lifespan or lazy init
|
||||
_app_state: dict[str, Any] = {}
|
||||
_default_adapter: Any = None
|
||||
|
||||
|
||||
def set_default_adapter(adapter: Any) -> None:
    """Register the module-level default adapter used by ensure_initialized."""
    global _default_adapter
    _default_adapter = adapter
|
||||
|
||||
|
||||
def default_orchestrator(adapter: LLMAdapter | None = None) -> tuple[Orchestrator, Any]:
    """Create default Orchestrator with Dvādaśa heads and Witness registered.

    When adapter is None, uses native reasoning throughout: heads use NativeReasoningProvider,
    Witness uses NativeAdapter for synthesis. No external LLM calls.

    Args:
        adapter: Optional LLM adapter shared by the heads and the Witness.

    Returns:
        Tuple of (orchestrator, event_bus); the bus is also wired into the
        orchestrator's constructor.
    """
    bus = EventBus()
    state = StateManager()
    orch = Orchestrator(event_bus=bus, state_manager=state)

    # Heads: use native or Super Big Brain reasoning when no adapter
    reasoning_provider = _get_reasoning_provider()
    heads = create_all_content_heads(
        adapter=adapter,
        reasoning_provider=reasoning_provider,
        use_native_reasoning=reasoning_provider is None,
    )
    for hid, agent in heads.items():
        orch.register_agent(hid.value, agent)

    # Witness: use NativeAdapter when no adapter for native synthesis
    witness_adapter = adapter if adapter is not None else NativeAdapter()
    orch.register_agent(HeadId.WITNESS.value, WitnessAgent(adapter=witness_adapter))

    return orch, bus
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""In-memory session store for API sessions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def create(self, session_id: str, user_id: str | None = None) -> dict[str, Any]:
|
||||
sess = {"session_id": session_id, "user_id": user_id, "history": []}
|
||||
self._sessions[session_id] = sess
|
||||
return sess
|
||||
|
||||
def get(self, session_id: str) -> dict[str, Any] | None:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def append_history(self, session_id: str, entry: dict[str, Any]) -> None:
|
||||
sess = self._sessions.get(session_id)
|
||||
if sess:
|
||||
sess.setdefault("history", []).append(entry)
|
||||
|
||||
|
||||
def get_orchestrator() -> Any:
    """Return the shared Orchestrator, or None before initialization."""
    return _app_state.get("orchestrator")


def get_event_bus() -> Any:
    """Return the shared EventBus, or None before initialization."""
    return _app_state.get("event_bus")


def get_session_store() -> SessionStore | None:
    """Return the shared SessionStore, or None before initialization."""
    return _app_state.get("session_store")


def get_safety_pipeline() -> Any:
    """Return the shared SafetyPipeline, or None before initialization."""
    return _app_state.get("safety_pipeline")


def get_telemetry_tracer() -> Any:
    """Return the telemetry tracer, or None when telemetry is unavailable."""
    return _app_state.get("telemetry_tracer")
|
||||
|
||||
|
||||
def set_app_state(orchestrator: Any, event_bus: Any, session_store: SessionStore) -> None:
    """Install core services into module-level app state.

    Also creates the SafetyPipeline once (kept across re-initialization) and
    attaches a telemetry tracer subscribed to the event bus when the telemetry
    extra is available.
    """
    _app_state["orchestrator"] = orchestrator
    _app_state["event_bus"] = event_bus
    _app_state["session_store"] = session_store
    if "safety_pipeline" not in _app_state:
        _app_state["safety_pipeline"] = SafetyPipeline(audit_log=AuditLog())
    try:
        from fusionagi.telemetry import TelemetryTracer, set_tracer
        tracer = TelemetryTracer()
        tracer.subscribe(event_bus)
        set_tracer(tracer)
        _app_state["telemetry_tracer"] = tracer
    except Exception:
        # Telemetry is best-effort: a missing extra or wiring failure is
        # deliberately non-fatal.
        pass
|
||||
|
||||
|
||||
def ensure_initialized(adapter: Any = None) -> None:
    """Lazy init: ensure orchestrator and store exist (for TestClient).

    No-op when already initialized. Falls back to the module default adapter
    registered via set_default_adapter when none is passed.
    """
    if _app_state.get("orchestrator") is not None:
        return
    adj = adapter if adapter is not None else _default_adapter
    orch, bus = default_orchestrator(adj)
    set_app_state(orch, bus, SessionStore())
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIBridgeConfig:
|
||||
"""Configuration for OpenAI-compatible API bridge."""
|
||||
|
||||
model_id: str
|
||||
auth_enabled: bool
|
||||
api_key: str | None
|
||||
timeout_per_head: float
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "OpenAIBridgeConfig":
|
||||
"""Load config from environment variables."""
|
||||
auth = os.environ.get("OPENAI_BRIDGE_AUTH", "disabled").lower()
|
||||
auth_enabled = auth not in ("disabled", "false", "0", "no")
|
||||
return cls(
|
||||
model_id=os.environ.get("OPENAI_BRIDGE_MODEL_ID", "fusionagi-dvadasa"),
|
||||
auth_enabled=auth_enabled,
|
||||
api_key=os.environ.get("OPENAI_BRIDGE_API_KEY") if auth_enabled else None,
|
||||
timeout_per_head=float(os.environ.get("OPENAI_BRIDGE_TIMEOUT_PER_HEAD", "60")),
|
||||
)
|
||||
|
||||
|
||||
def get_openai_bridge_config() -> OpenAIBridgeConfig:
    """Return OpenAI bridge config from app state or env.

    An app-state override (key "openai_bridge_config") wins; otherwise the
    config is re-read from the environment on every call.
    """
    cfg = _app_state.get("openai_bridge_config")
    if cfg is not None:
        return cfg
    return OpenAIBridgeConfig.from_env()
|
||||
|
||||
|
||||
def verify_openai_bridge_auth(authorization: str | None) -> None:
    """
    Verify OpenAI bridge auth. Raises HTTPException(401) if auth enabled and invalid.
    Call from route dependencies.

    Args:
        authorization: Raw ``Authorization`` header value ("Bearer <key>") or None.
    """
    # Fix: the original wrapped everything (import included) in one try whose
    # `except HTTPException` clause would itself raise NameError when the
    # fastapi import failed. Guard the import and the config load separately
    # so auth-check bugs are no longer silently swallowed.
    try:
        from fastapi import HTTPException
    except ImportError:
        return  # fastapi not installed: nothing to enforce

    try:
        cfg = get_openai_bridge_config()
    except Exception:
        # Fail open on config errors, preserving the original best-effort
        # behavior for misconfigured environments.
        return

    if not cfg.auth_enabled:
        return
    if not cfg.api_key:
        return  # Auth enabled but no key configured: allow (misconfig)
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Missing or invalid Authorization header", "type": "authentication_error"}},
        )
    token = authorization[7:].strip()
    if token != cfg.api_key:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Invalid API key", "type": "authentication_error"}},
        )
|
||||
13
fusionagi/api/openai_compat/__init__.py
Normal file
13
fusionagi/api/openai_compat/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""OpenAI-compatible API bridge for Cursor Composer and other OpenAI API consumers."""
|
||||
|
||||
from fusionagi.api.openai_compat.translators import (
|
||||
messages_to_prompt,
|
||||
estimate_usage,
|
||||
final_response_to_openai,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"messages_to_prompt",
|
||||
"estimate_usage",
|
||||
"final_response_to_openai",
|
||||
]
|
||||
146
fusionagi/api/openai_compat/translators.py
Normal file
146
fusionagi/api/openai_compat/translators.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Translators for OpenAI API request/response format to FusionAGI."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.witness import FinalResponse
|
||||
|
||||
|
||||
def _extract_content(msg: dict[str, Any]) -> str:
|
||||
"""Extract text content from a message. Handles string or array content parts."""
|
||||
content = msg.get("content")
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
parts.append(part.get("text", "") or "")
|
||||
elif isinstance(part, str):
|
||||
parts.append(part)
|
||||
return "\n".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def messages_to_prompt(messages: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
Translate OpenAI messages array to a single prompt string for Dvādaśa.
|
||||
|
||||
Format:
|
||||
[System]: {system_content}
|
||||
[User]: {user_msg_1}
|
||||
[Assistant]: {assistant_msg_1}
|
||||
[User]: {user_msg_2}
|
||||
...
|
||||
[User]: {last_user_message} <- primary goal for submit_task
|
||||
|
||||
Tool result messages (role: "tool") are appended as "Tool {name} returned: {content}".
|
||||
Falls back to last non-system message if no explicit user turn.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
|
||||
Returns:
|
||||
Single prompt string for orch.submit_task / run_dvadasa.
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
system_content = ""
|
||||
last_user_content = ""
|
||||
|
||||
for msg in messages:
|
||||
role = (msg.get("role") or "user").lower()
|
||||
content = _extract_content(msg)
|
||||
|
||||
if role == "system":
|
||||
system_content = content
|
||||
elif role == "user":
|
||||
last_user_content = content
|
||||
if system_content and not parts:
|
||||
parts.append(f"[System]: {system_content}")
|
||||
parts.append(f"[User]: {content}")
|
||||
elif role == "assistant":
|
||||
if system_content and not parts:
|
||||
parts.append(f"[System]: {system_content}")
|
||||
parts.append(f"[Assistant]: {content}")
|
||||
elif role == "tool":
|
||||
name = msg.get("name", "unknown")
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
parts.append(f"[Tool {name}]{f' (id={tool_id})' if tool_id else ''} returned: {content}")
|
||||
|
||||
if not parts:
|
||||
return last_user_content or system_content
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def estimate_usage(
|
||||
messages: list[dict[str, Any]],
|
||||
completion_text: str,
|
||||
chars_per_token: int = 4,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Estimate token usage from character counts (OpenAI-like heuristic).
|
||||
|
||||
Args:
|
||||
messages: Request messages for prompt_tokens.
|
||||
completion_text: Response text for completion_tokens.
|
||||
chars_per_token: Approximate chars per token (default 4).
|
||||
|
||||
Returns:
|
||||
Dict with prompt_tokens, completion_tokens, total_tokens.
|
||||
"""
|
||||
prompt_chars = sum(len(_extract_content(m)) for m in messages)
|
||||
completion_chars = len(completion_text)
|
||||
prompt_tokens = max(1, prompt_chars // chars_per_token)
|
||||
completion_tokens = max(1, completion_chars // chars_per_token)
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
|
||||
|
||||
def final_response_to_openai(
    final: FinalResponse,
    task_id: str,
    request_model: str | None = None,
    messages: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """
    Map FusionAGI FinalResponse to OpenAI Chat Completion format.

    Args:
        final: FinalResponse from run_dvadasa.
        task_id: Task ID for response id (truncated to 24 chars).
        request_model: Model ID from request, or default fusionagi-dvadasa.
        messages: Original request messages for usage estimation.

    Returns:
        OpenAI-compatible chat completion dict (single choice, finish "stop").
    """
    model = request_model or "fusionagi-dvadasa"
    usage = estimate_usage(messages or [], final.final_answer)

    return {
        # Fix: the original length-conditional was redundant — slicing a
        # string shorter than 24 chars already returns it unchanged.
        "id": f"chatcmpl-{task_id[:24]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": final.final_answer,
                    "tool_calls": None,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": usage,
    }
|
||||
14
fusionagi/api/routes/__init__.py
Normal file
14
fusionagi/api/routes/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""API routes for Dvādaśa sessions and prompts."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fusionagi.api.routes.sessions import router as sessions_router
|
||||
from fusionagi.api.routes.tts import router as tts_router
|
||||
from fusionagi.api.routes.admin import router as admin_router
|
||||
from fusionagi.api.routes.openai_compat import router as openai_compat_router
|
||||
|
||||
# Aggregate router mounted by the app under /v1.
router = APIRouter()
# Session CRUD / prompt endpoints; TTS shares the /sessions prefix.
router.include_router(sessions_router, prefix="/sessions", tags=["sessions"])
router.include_router(tts_router, prefix="/sessions", tags=["tts"])
router.include_router(admin_router, prefix="/admin", tags=["admin"])
# OpenAI-compatible bridge defines its own paths (/models, /chat/completions).
router.include_router(openai_compat_router)
|
||||
17
fusionagi/api/routes/admin.py
Normal file
17
fusionagi/api/routes/admin.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Admin routes: telemetry, etc."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fusionagi.api.dependencies import get_telemetry_tracer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/telemetry")
def get_telemetry(task_id: str | None = None, limit: int = 100) -> dict:
    """Return telemetry traces (admin). Filter by task_id if provided."""
    tracer = get_telemetry_tracer()
    if not tracer:
        # Telemetry extra not wired; respond with an empty trace list.
        return {"traces": []}
    return {"traces": tracer.get_traces(task_id=task_id, limit=limit)}
|
||||
265
fusionagi/api/routes/openai_compat.py
Normal file
265
fusionagi/api/routes/openai_compat.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""OpenAI-compatible API routes for Cursor Composer and other consumers."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from fusionagi.api.dependencies import (
|
||||
ensure_initialized,
|
||||
get_event_bus,
|
||||
get_orchestrator,
|
||||
get_safety_pipeline,
|
||||
get_openai_bridge_config,
|
||||
verify_openai_bridge_auth,
|
||||
)
|
||||
from fusionagi.api.openai_compat.translators import (
|
||||
messages_to_prompt,
|
||||
final_response_to_openai,
|
||||
estimate_usage,
|
||||
)
|
||||
from fusionagi.core import run_dvadasa
|
||||
from fusionagi.schemas.commands import parse_user_input
|
||||
|
||||
router = APIRouter(tags=["openai-compat"])
|
||||
|
||||
# Chunk size for streaming (chars per SSE delta)
|
||||
_STREAM_CHUNK_SIZE = 50
|
||||
|
||||
|
||||
def _openai_error(status_code: int, message: str, error_type: str) -> HTTPException:
    """Build (not raise) an HTTPException carrying an OpenAI-style error body."""
    body = {"error": {"message": message, "type": error_type}}
    return HTTPException(status_code=status_code, detail=body)
|
||||
|
||||
|
||||
def _ensure_openai_init() -> None:
    """Ensure orchestrator and dependencies are initialized."""
    # Lazy init covers TestClient usage where the startup event may not fire.
    ensure_initialized()
|
||||
|
||||
|
||||
async def _verify_auth_dep(authorization: str | None = Header(default=None)) -> None:
    """Dependency: verify auth for OpenAI bridge routes.

    Raises HTTPException(401) via verify_openai_bridge_auth when auth is
    enabled and the Bearer token is missing or invalid.
    """
    verify_openai_bridge_auth(authorization)
|
||||
|
||||
|
||||
@router.get("/models", dependencies=[Depends(_verify_auth_dep)])
async def list_models() -> dict[str, Any]:
    """
    List available models (OpenAI-compatible).
    Returns fusionagi-dvadasa as the single model.
    """
    cfg = get_openai_bridge_config()
    model_entry = {
        "id": cfg.model_id,
        "object": "model",
        "created": 1704067200,
        "owned_by": "fusionagi",
    }
    return {"object": "list", "data": [model_entry]}
|
||||
|
||||
|
||||
@router.post(
    "/chat/completions",
    dependencies=[Depends(_verify_auth_dep)],
    response_model=None,
)
async def create_chat_completion(request: Request):
    """
    Create chat completion (OpenAI-compatible).
    Supports both sync (stream=false) and streaming (stream=true).

    Returns either an OpenAI chat-completion dict or a text/event-stream
    StreamingResponse of chat.completion.chunk events.
    """
    _ensure_openai_init()

    try:
        body = await request.json()
    except Exception as e:
        raise _openai_error(400, f"Invalid JSON body: {e}", "invalid_request_error")

    messages = body.get("messages")
    if not messages or not isinstance(messages, list):
        raise _openai_error(
            400,
            "messages is required and must be a non-empty array",
            "invalid_request_error",
        )

    # NOTE(review): imports a private helper from translators — consider
    # promoting _extract_content to the module's public API.
    from fusionagi.api.openai_compat.translators import _extract_content

    has_content = any(_extract_content(m).strip() for m in messages)
    if not has_content:
        raise _openai_error(
            400,
            "messages must contain at least one user or assistant message with content",
            "invalid_request_error",
        )

    prompt = messages_to_prompt(messages)
    if not prompt.strip():
        raise _openai_error(
            400,
            "messages must contain at least one user or assistant message with content",
            "invalid_request_error",
        )

    # Input moderation before any task is created.
    pipeline = get_safety_pipeline()
    if pipeline:
        pre_result = pipeline.pre_check(prompt)
        if not pre_result.allowed:
            raise _openai_error(
                400,
                pre_result.reason or "Input moderation failed",
                "invalid_request_error",
            )

    orch = get_orchestrator()
    bus = get_event_bus()
    if not orch:
        raise _openai_error(503, "Service not initialized", "internal_error")

    cfg = get_openai_bridge_config()
    request_model = body.get("model") or cfg.model_id
    # `is True` keeps truthy non-bool values (e.g. the string "false") from
    # accidentally enabling streaming.
    stream = body.get("stream", False) is True

    # The task record's goal is capped at 200 chars; the full prompt is still
    # passed to run_dvadasa below.
    task_id = orch.submit_task(goal=prompt[:200])
    parsed = parse_user_input(prompt)

    if stream:
        return StreamingResponse(
            _stream_chat_completion(
                orch=orch,
                bus=bus,
                task_id=task_id,
                prompt=prompt,
                parsed=parsed,
                request_model=request_model,
                messages=messages,
                pipeline=pipeline,
                cfg=cfg,
            ),
            media_type="text/event-stream",
        )

    # Sync path
    final = run_dvadasa(
        orchestrator=orch,
        task_id=task_id,
        user_prompt=prompt,
        parsed=parsed,
        event_bus=bus,
        timeout_per_head=cfg.timeout_per_head,
    )

    if not final:
        raise _openai_error(500, "Dvādaśa failed to produce response", "internal_error")

    # Output moderation on the composed answer.
    if pipeline:
        post_result = pipeline.post_check(final.final_answer)
        if not post_result.passed:
            raise _openai_error(
                400,
                f"Output scan failed: {', '.join(post_result.flags)}",
                "invalid_request_error",
            )

    result = final_response_to_openai(
        final=final,
        task_id=task_id,
        request_model=request_model,
        messages=messages,
    )
    return result
|
||||
|
||||
|
||||
async def _stream_chat_completion(
    orch: Any,
    bus: Any,
    task_id: str,
    prompt: str,
    parsed: Any,
    request_model: str,
    messages: list[dict[str, Any]],
    pipeline: Any,
    cfg: Any,
):
    """
    Async generator that runs Dvādaśa and streams the final_answer as SSE chunks.

    Yields OpenAI-style `chat.completion.chunk` events, then a final chunk with
    finish_reason/usage, then the `[DONE]` sentinel. Errors are reported as a
    single SSE error event (the HTTP response has already started, so raising
    would only abort the stream).
    """

    def run() -> Any:
        # Blocking Dvādaśa pipeline; executed off the event loop below.
        return run_dvadasa(
            orchestrator=orch,
            task_id=task_id,
            user_prompt=prompt,
            parsed=parsed,
            event_bus=bus,
            timeout_per_head=cfg.timeout_per_head,
        )

    # Fixes: use get_running_loop() — get_event_loop() is deprecated inside
    # coroutines — and shut the single-use executor down so its worker thread
    # is not leaked on every streaming request.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        try:
            final = await asyncio.get_running_loop().run_in_executor(executor, run)
        except Exception as e:
            yield f"data: {json.dumps({'error': {'message': str(e), 'type': 'internal_error'}})}\n\n"
            return
    finally:
        executor.shutdown(wait=False)

    if not final:
        yield f"data: {json.dumps({'error': {'message': 'Dvādaśa failed', 'type': 'internal_error'}})}\n\n"
        return

    # Output moderation before anything is streamed to the client.
    if pipeline:
        post_result = pipeline.post_check(final.final_answer)
        if not post_result.passed:
            yield f"data: {json.dumps({'error': {'message': 'Output scan failed', 'type': 'invalid_request_error'}})}\n\n"
            return

    chat_id = f"chatcmpl-{task_id[:24]}" if len(task_id) >= 24 else f"chatcmpl-{task_id}"

    # Stream final_answer in fixed-size character chunks.
    text = final.final_answer
    for i in range(0, len(text), _STREAM_CHUNK_SIZE):
        chunk = text[i : i + _STREAM_CHUNK_SIZE]
        chunk_json = {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": 0,
            "model": request_model,
            "choices": [
                {
                    "index": 0,
                    "delta": {"content": chunk},
                    "finish_reason": None,
                }
            ],
        }
        yield f"data: {json.dumps(chunk_json)}\n\n"

    # Final chunk with finish_reason and estimated usage.
    usage = estimate_usage(messages, text)
    final_chunk = {
        "id": chat_id,
        "object": "chat.completion.chunk",
        "created": 0,
        "model": request_model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "finish_reason": "stop",
            }
        ],
        "usage": usage,
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
    yield "data: [DONE]\n\n"
|
||||
147
fusionagi/api/routes/sessions.py
Normal file
147
fusionagi/api/routes/sessions.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Session and prompt routes."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from fusionagi.api.dependencies import get_orchestrator, get_session_store, get_event_bus, get_safety_pipeline
|
||||
from fusionagi.api.websocket import handle_stream
|
||||
from fusionagi.core import run_dvadasa, select_heads_for_complexity, extract_sources_from_head_outputs
|
||||
from fusionagi.schemas.commands import parse_user_input, UserIntent
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _ensure_init():
    """Initialize shared API dependencies (import deferred to avoid cycles)."""
    from fusionagi.api.dependencies import ensure_initialized

    ensure_initialized()
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_session(user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Create a new session."""
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
if not store:
|
||||
raise HTTPException(status_code=503, detail="Session store not initialized")
|
||||
session_id = str(uuid.uuid4())
|
||||
store.create(session_id, user_id)
|
||||
return {"session_id": session_id, "user_id": user_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/prompt")
|
||||
def submit_prompt(session_id: str, body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Submit a prompt and receive FinalResponse (sync)."""
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
orch = get_orchestrator()
|
||||
bus = get_event_bus()
|
||||
if not store or not orch:
|
||||
raise HTTPException(status_code=503, detail="Service not initialized")
|
||||
|
||||
sess = store.get(session_id)
|
||||
if not sess:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
prompt = body.get("prompt", "")
|
||||
parsed = parse_user_input(prompt)
|
||||
|
||||
if not prompt or not parsed.cleaned_prompt.strip():
|
||||
if parsed.intent in (UserIntent.SHOW_DISSENT, UserIntent.RERUN_RISK, UserIntent.EXPLAIN_REASONING, UserIntent.SOURCES):
|
||||
hist = sess.get("history", [])
|
||||
if hist:
|
||||
prompt = hist[-1].get("prompt", "")
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=400, detail="No previous prompt; provide a prompt for this command")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="prompt is required")
|
||||
|
||||
effective_prompt = parsed.cleaned_prompt.strip() or prompt
|
||||
pipeline = get_safety_pipeline()
|
||||
if pipeline:
|
||||
pre_result = pipeline.pre_check(effective_prompt)
|
||||
if not pre_result.allowed:
|
||||
raise HTTPException(status_code=400, detail=pre_result.reason or "Input moderation failed")
|
||||
|
||||
task_id = orch.submit_task(goal=effective_prompt[:200])
|
||||
|
||||
# Dynamic head selection
|
||||
head_ids = select_heads_for_complexity(effective_prompt)
|
||||
if parsed.intent.value == "head_strategy" and parsed.head_id:
|
||||
head_ids = [parsed.head_id]
|
||||
|
||||
force_second = parsed.intent == UserIntent.RERUN_RISK
|
||||
return_heads = parsed.intent == UserIntent.SOURCES
|
||||
|
||||
result = run_dvadasa(
|
||||
orchestrator=orch,
|
||||
task_id=task_id,
|
||||
user_prompt=effective_prompt,
|
||||
parsed=parsed,
|
||||
head_ids=head_ids if parsed.intent.value != "normal" or body.get("use_all_heads") else None,
|
||||
event_bus=bus,
|
||||
force_second_pass=force_second,
|
||||
return_head_outputs=return_heads,
|
||||
)
|
||||
|
||||
if return_heads and isinstance(result, tuple):
|
||||
final, head_outputs = result
|
||||
else:
|
||||
final = result
|
||||
head_outputs = []
|
||||
|
||||
if not final:
|
||||
raise HTTPException(status_code=500, detail="Failed to produce response")
|
||||
|
||||
if pipeline:
|
||||
post_result = pipeline.post_check(final.final_answer)
|
||||
if not post_result.passed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Output scan failed: {', '.join(post_result.flags)}",
|
||||
)
|
||||
|
||||
entry = {
|
||||
"prompt": effective_prompt,
|
||||
"final_answer": final.final_answer,
|
||||
"confidence_score": final.confidence_score,
|
||||
"head_contributions": final.head_contributions,
|
||||
}
|
||||
store.append_history(session_id, entry)
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"final_answer": final.final_answer,
|
||||
"transparency_report": final.transparency_report.model_dump(),
|
||||
"head_contributions": final.head_contributions,
|
||||
"confidence_score": final.confidence_score,
|
||||
}
|
||||
if parsed.intent == UserIntent.SHOW_DISSENT:
|
||||
response["response_mode"] = "show_dissent"
|
||||
response["disputed_claims"] = final.transparency_report.agreement_map.disputed_claims
|
||||
elif parsed.intent == UserIntent.EXPLAIN_REASONING:
|
||||
response["response_mode"] = "explain"
|
||||
elif parsed.intent == UserIntent.SOURCES and head_outputs:
|
||||
response["sources"] = extract_sources_from_head_outputs(head_outputs)
|
||||
return response
|
||||
|
||||
|
||||
@router.websocket("/{session_id}/stream")
|
||||
async def stream_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||
"""WebSocket for streaming Dvādaśa response. Send {\"prompt\": \"...\"} to start."""
|
||||
await websocket.accept()
|
||||
try:
|
||||
data = await websocket.receive_json()
|
||||
prompt = data.get("prompt", "")
|
||||
async def send_evt(evt: dict) -> None:
|
||||
await websocket.send_json(evt)
|
||||
await handle_stream(session_id, prompt, send_evt)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
try:
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
except Exception:
|
||||
pass
|
||||
49
fusionagi/api/routes/tts.py
Normal file
49
fusionagi/api/routes/tts.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""TTS synthesis routes for per-head voice output."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from fusionagi.api.dependencies import get_session_store
|
||||
from fusionagi.config.head_voices import get_voice_id_for_head
|
||||
from fusionagi.schemas.head import HeadId
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/{session_id}/synthesize")
|
||||
async def synthesize(
|
||||
session_id: str,
|
||||
body: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Synthesize text to audio for a head.
|
||||
Body: { "text": "...", "head_id": "logic" }
|
||||
Returns: { "audio_base64": "..." } or { "audio_base64": null } if TTS not configured.
|
||||
"""
|
||||
store = get_session_store()
|
||||
if not store:
|
||||
raise HTTPException(status_code=503, detail="Service not initialized")
|
||||
sess = store.get(session_id)
|
||||
if not sess:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
text = body.get("text", "")
|
||||
head_id_str = body.get("head_id", "")
|
||||
if not text:
|
||||
raise HTTPException(status_code=400, detail="text is required")
|
||||
|
||||
try:
|
||||
head_id = HeadId(head_id_str)
|
||||
except ValueError:
|
||||
head_id = HeadId.LOGIC
|
||||
|
||||
voice_id = get_voice_id_for_head(head_id)
|
||||
audio_base64 = None
|
||||
# TODO: Wire TTSAdapter (ElevenLabs, Azure, etc.) and synthesize
|
||||
# if tts_adapter:
|
||||
# audio_bytes = await tts_adapter.synthesize(text, voice_id=voice_id)
|
||||
# if audio_bytes:
|
||||
# import base64
|
||||
# audio_base64 = base64.b64encode(audio_bytes).decode()
|
||||
return {"audio_base64": audio_base64, "voice_id": voice_id}
|
||||
99
fusionagi/api/websocket.py
Normal file
99
fusionagi/api/websocket.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""WebSocket streaming for Dvādaśa responses."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.api.dependencies import get_orchestrator, get_session_store, get_event_bus
|
||||
from fusionagi.core import run_heads_parallel, run_witness, select_heads_for_complexity
|
||||
from fusionagi.schemas.commands import parse_user_input
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput
|
||||
|
||||
|
||||
async def handle_stream(
    session_id: str,
    prompt: str,
    send_fn: Any,
) -> None:
    """
    Run the Dvādaśa flow and stream progress events to a WebSocket client.

    Events actually emitted: heads_running, head_complete, head_speak,
    witness_running, complete — or error at any stage.

    Args:
        session_id: Existing session to attach history to.
        prompt: User prompt; must be non-empty.
        send_fn: Async callable taking one event dict and delivering it.
    """
    # Local import to avoid a circular import at module load time.
    from fusionagi.api.dependencies import ensure_initialized
    ensure_initialized()
    store = get_session_store()
    orch = get_orchestrator()
    # NOTE: the event bus is not needed here; heads and witness are invoked directly.
    if not store or not orch:
        await send_fn({"type": "error", "message": "Service not initialized"})
        return

    sess = store.get(session_id)
    if not sess:
        await send_fn({"type": "error", "message": "Session not found"})
        return

    if not prompt:
        await send_fn({"type": "error", "message": "prompt is required"})
        return

    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        parsed = parse_user_input(prompt)
        task_id = orch.submit_task(goal=prompt[:200])
        head_ids = select_heads_for_complexity(prompt)
        if parsed.intent.value == "head_strategy" and parsed.head_id:
            # Explicit "ask head X" command overrides automatic selection.
            head_ids = [parsed.head_id]

        await send_fn({"type": "heads_running", "message": "Heads running…"})

        def run_heads():
            # Blocking call, executed off the event loop.
            return run_heads_parallel(orch, task_id, prompt, head_ids=head_ids)

        try:
            head_outputs = await loop.run_in_executor(executor, run_heads)
        except Exception as e:
            await send_fn({"type": "error", "message": str(e)})
            return

        for ho in head_outputs:
            await send_fn({
                "type": "head_complete",
                "head_id": ho.head_id.value,
                "summary": ho.summary,
            })
            await send_fn({
                "type": "head_speak",
                "head_id": ho.head_id.value,
                "summary": ho.summary,
                "audio_base64": None,
            })

        await send_fn({"type": "witness_running", "message": "Witness composing…"})

        def run_wit():
            # Blocking call, executed off the event loop.
            return run_witness(orch, task_id, head_outputs, prompt)

        try:
            final = await loop.run_in_executor(executor, run_wit)
        except Exception as e:
            await send_fn({"type": "error", "message": str(e)})
            return
    finally:
        # Release the worker thread on success, error, or early return.
        executor.shutdown(wait=False)

    if final:
        await send_fn({
            "type": "complete",
            "final_answer": final.final_answer,
            "transparency_report": final.transparency_report.model_dump(),
            "head_contributions": final.head_contributions,
            "confidence_score": final.confidence_score,
        })
        store.append_history(session_id, {
            "prompt": prompt,
            "final_answer": final.final_answer,
            "confidence_score": final.confidence_score,
        })
    else:
        await send_fn({"type": "error", "message": "Failed to produce response"})
|
||||
11
fusionagi/config/__init__.py
Normal file
11
fusionagi/config/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Configuration for Dvādaśa heads, voices, and services."""
|
||||
|
||||
from fusionagi.config.head_voices import get_voice_id_for_head, HEAD_VOICE_MAP
|
||||
from fusionagi.config.head_personas import get_persona, HEAD_PERSONAS
|
||||
|
||||
__all__ = [
|
||||
"get_voice_id_for_head",
|
||||
"HEAD_VOICE_MAP",
|
||||
"get_persona",
|
||||
"HEAD_PERSONAS",
|
||||
]
|
||||
23
fusionagi/config/head_personas.py
Normal file
23
fusionagi/config/head_personas.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Per-head persona metadata for expressions and tone."""
|
||||
|
||||
from fusionagi.schemas.head import HeadId
|
||||
|
||||
# Persona hints per head: "expression" is an affect/expression label and
# "tone" a style descriptor. Consumed by get_persona() below.
HEAD_PERSONAS: dict[HeadId, dict[str, str]] = {
    HeadId.LOGIC: {"expression": "analytical", "tone": "precise"},
    HeadId.RESEARCH: {"expression": "curious", "tone": "thorough"},
    HeadId.SYSTEMS: {"expression": "technical", "tone": "architectural"},
    HeadId.STRATEGY: {"expression": "visionary", "tone": "strategic"},
    HeadId.PRODUCT: {"expression": "empathetic", "tone": "user-focused"},
    HeadId.SECURITY: {"expression": "vigilant", "tone": "cautious"},
    HeadId.SAFETY: {"expression": "protective", "tone": "guardian"},
    HeadId.RELIABILITY: {"expression": "steady", "tone": "dependable"},
    HeadId.COST: {"expression": "pragmatic", "tone": "efficient"},
    HeadId.DATA: {"expression": "structured", "tone": "precise"},
    HeadId.DEVEX: {"expression": "helpful", "tone": "practical"},
    HeadId.WITNESS: {"expression": "composed", "tone": "synthesizing"},
}
|
||||
|
||||
|
||||
def get_persona(head_id: HeadId) -> dict[str, str]:
    """Return persona metadata for a head, or a neutral default if unmapped."""
    neutral = {"expression": "neutral", "tone": "balanced"}
    return HEAD_PERSONAS.get(head_id, neutral)
|
||||
25
fusionagi/config/head_voices.py
Normal file
25
fusionagi/config/head_voices.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Head-to-voice mapping for per-head TTS in Dvādaśa."""
|
||||
|
||||
from fusionagi.schemas.head import HeadId
|
||||
|
||||
# Map each HeadId to a voice profile id (from VoiceLibrary).
|
||||
# In production, create VoiceProfiles and register with VoiceLibrary.
|
||||
# One voice profile id per head; unmapped heads fall back to "voice_default"
# in get_voice_id_for_head() below.
HEAD_VOICE_MAP: dict[HeadId, str] = {
    HeadId.LOGIC: "voice_logic",
    HeadId.RESEARCH: "voice_research",
    HeadId.SYSTEMS: "voice_systems",
    HeadId.STRATEGY: "voice_strategy",
    HeadId.PRODUCT: "voice_product",
    HeadId.SECURITY: "voice_security",
    HeadId.SAFETY: "voice_safety",
    HeadId.RELIABILITY: "voice_reliability",
    HeadId.COST: "voice_cost",
    HeadId.DATA: "voice_data",
    HeadId.DEVEX: "voice_devex",
    HeadId.WITNESS: "voice_witness",
}
|
||||
|
||||
|
||||
def get_voice_id_for_head(head_id: HeadId) -> str:
    """Return the voice profile id for a head ("voice_default" if unmapped)."""
    try:
        return HEAD_VOICE_MAP[head_id]
    except KeyError:
        return "voice_default"
|
||||
57
fusionagi/core/__init__.py
Normal file
57
fusionagi/core/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Core orchestration: event bus, state manager, orchestrator, goal manager, scheduler, blockers, persistence."""
|
||||
|
||||
from fusionagi.core.event_bus import EventBus
|
||||
from fusionagi.core.state_manager import StateManager
|
||||
from fusionagi.core.orchestrator import (
|
||||
Orchestrator,
|
||||
InvalidStateTransitionError,
|
||||
VALID_STATE_TRANSITIONS,
|
||||
AgentProtocol,
|
||||
)
|
||||
from fusionagi.core.persistence import StateBackend
|
||||
from fusionagi.core.json_file_backend import JsonFileBackend
|
||||
from fusionagi.core.goal_manager import GoalManager
|
||||
from fusionagi.core.scheduler import Scheduler, SchedulerMode, FallbackMode
|
||||
from fusionagi.core.blockers import BlockersAndCheckpoints
|
||||
from fusionagi.core.head_orchestrator import (
|
||||
run_heads_parallel,
|
||||
run_witness,
|
||||
run_dvadasa,
|
||||
run_second_pass,
|
||||
select_heads_for_complexity,
|
||||
extract_sources_from_head_outputs,
|
||||
MVP_HEADS,
|
||||
ALL_CONTENT_HEADS,
|
||||
)
|
||||
from fusionagi.core.super_big_brain import (
|
||||
run_super_big_brain,
|
||||
SuperBigBrainConfig,
|
||||
SuperBigBrainReasoningProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EventBus",
|
||||
"StateManager",
|
||||
"Orchestrator",
|
||||
"StateBackend",
|
||||
"JsonFileBackend",
|
||||
"InvalidStateTransitionError",
|
||||
"VALID_STATE_TRANSITIONS",
|
||||
"AgentProtocol",
|
||||
"GoalManager",
|
||||
"Scheduler",
|
||||
"SchedulerMode",
|
||||
"FallbackMode",
|
||||
"BlockersAndCheckpoints",
|
||||
"run_heads_parallel",
|
||||
"run_witness",
|
||||
"run_dvadasa",
|
||||
"run_second_pass",
|
||||
"select_heads_for_complexity",
|
||||
"extract_sources_from_head_outputs",
|
||||
"MVP_HEADS",
|
||||
"ALL_CONTENT_HEADS",
|
||||
"run_super_big_brain",
|
||||
"SuperBigBrainConfig",
|
||||
"SuperBigBrainReasoningProvider",
|
||||
]
|
||||
35
fusionagi/core/blockers.py
Normal file
35
fusionagi/core/blockers.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Blockers and checkpoints for AGI state machine."""
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
from fusionagi.schemas.goal import Blocker, Checkpoint
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class BlockersAndCheckpoints:
    """Tracks blockers (why a task is stuck) and checkpoints (resumable points)."""

    def __init__(self) -> None:
        # Both maps are keyed by task_id.
        self._blockers: dict[str, list[Blocker]] = {}
        self._checkpoints: dict[str, list[Checkpoint]] = {}

    def add_blocker(self, blocker: Blocker) -> None:
        """Record a blocker under its task and log it."""
        bucket = self._blockers.setdefault(blocker.task_id, [])
        bucket.append(blocker)
        reason = blocker.reason[:80] if blocker.reason else ""
        logger.info("Blocker added", extra={"task_id": blocker.task_id, "reason": reason})

    def get_blockers(self, task_id: str) -> list[Blocker]:
        """Return a copy of the blockers recorded for a task."""
        return list(self._blockers.get(task_id, []))

    def clear_blockers(self, task_id: str) -> None:
        """Drop all blockers for a task (no-op if none exist)."""
        self._blockers.pop(task_id, None)

    def add_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Record a checkpoint under its task and log it."""
        self._checkpoints.setdefault(checkpoint.task_id, []).append(checkpoint)
        logger.debug("Checkpoint added", extra={"task_id": checkpoint.task_id})

    def get_latest_checkpoint(self, task_id: str) -> Checkpoint | None:
        """Return the most recent checkpoint for a task, or None."""
        checkpoints = self._checkpoints.get(task_id)
        if not checkpoints:
            return None
        return checkpoints[-1]

    def list_checkpoints(self, task_id: str) -> list[Checkpoint]:
        """Return a copy of all checkpoints for a task."""
        return list(self._checkpoints.get(task_id, []))
|
||||
77
fusionagi/core/event_bus.py
Normal file
77
fusionagi/core/event_bus.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""In-process pub/sub event bus for task lifecycle and agent messages."""
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi._time import utc_now_iso
|
||||
|
||||
# Type for event handlers: (event_type, payload) -> None
|
||||
EventHandler = Callable[[str, dict[str, Any]], None]
|
||||
|
||||
|
||||
class EventBus:
    """Simple in-process event bus: event type -> list of handlers; optional event history."""

    def __init__(self, history_size: int = 0) -> None:
        """
        Initialize event bus.

        Args:
            history_size: If > 0, keep the last N events for get_recent_events().
        """
        self._handlers: dict[str, list[EventHandler]] = defaultdict(list)
        self._history_size = max(0, history_size)
        if self._history_size:
            self._history: deque[dict[str, Any]] = deque(maxlen=self._history_size)
        else:
            # Unused when history is disabled; kept so clear() can call .clear() safely.
            self._history = deque()

    def subscribe(self, event_type: str, handler: EventHandler) -> None:
        """Register a handler for an event type."""
        self._handlers[event_type].append(handler)

    def unsubscribe(self, event_type: str, handler: EventHandler) -> None:
        """Remove a handler for an event type (no-op if absent)."""
        handlers = self._handlers.get(event_type)
        if handlers is None:
            return
        try:
            handlers.remove(handler)
        except ValueError:
            pass

    def publish(self, event_type: str, payload: dict[str, Any] | None = None) -> None:
        """Publish an event; all registered handlers are invoked."""
        payload = payload or {}
        if self._history_size > 0:
            record = {
                "event_type": event_type,
                "payload": dict(payload),
                "timestamp": utc_now_iso(),
            }
            self._history.append(record)
        logger.debug(
            "Event published",
            extra={"event_type": event_type, "task_id": payload.get("task_id", "")},
        )
        # Iterate a snapshot so handlers may (un)subscribe during dispatch.
        for handler in self._handlers[event_type][:]:
            try:
                handler(event_type, payload)
            except Exception:
                # One failing handler must not block the rest.
                logger.exception(
                    "Event handler failed",
                    extra={"event_type": event_type},
                )

    def get_recent_events(self, limit: int = 50) -> list[dict[str, Any]]:
        """Return the most recent events (oldest first in slice). Only available if history_size > 0."""
        if not self._history_size:
            return []
        snapshot = list(self._history)
        if not limit:
            return snapshot
        return snapshot[-limit:]

    def clear(self, event_type: str | None = None) -> None:
        """Clear handlers for one event type or all; clear history when clearing all."""
        if event_type is None:
            self._handlers.clear()
            if self._history:
                self._history.clear()
        elif event_type in self._handlers:
            del self._handlers[event_type]
|
||||
82
fusionagi/core/goal_manager.py
Normal file
82
fusionagi/core/goal_manager.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Goal manager: objectives, priorities, constraints, time/compute budget for AGI."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.goal import Goal, GoalBudget, GoalStatus
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class GoalManager:
    """
    Manages goals with budgets. Tracks time/compute and can signal
    when a goal is over budget (abort or degrade).
    """

    def __init__(self) -> None:
        # goal_id -> Goal
        self._goals: dict[str, Goal] = {}
        # goal_id -> {"time_used": float, "compute_used": float}
        self._budget_used: dict[str, dict[str, float]] = {}

    def add_goal(self, goal: Goal) -> None:
        """Register a goal and zero its usage counters."""
        self._goals[goal.goal_id] = goal
        self._budget_used[goal.goal_id] = {"time_used": 0.0, "compute_used": 0.0}
        logger.info("Goal added", extra={"goal_id": goal.goal_id, "objective": goal.objective[:80]})

    def get_goal(self, goal_id: str) -> Goal | None:
        """Return goal by id or None."""
        return self._goals.get(goal_id)

    def set_status(self, goal_id: str, status: GoalStatus) -> None:
        """Update goal status (no-op for unknown goal_id)."""
        g = self._goals.get(goal_id)
        if g:
            # Goal is treated as immutable; store an updated copy instead.
            self._goals[goal_id] = g.model_copy(update={"status": status})
            logger.debug("Goal status set", extra={"goal_id": goal_id, "status": status.value})

    def _usage(self, goal_id: str) -> dict[str, float]:
        """Return the usage counters for a goal, creating them if needed."""
        return self._budget_used.setdefault(goal_id, {"time_used": 0.0, "compute_used": 0.0})

    def record_time(self, goal_id: str, seconds: float) -> None:
        """Record elapsed time for a goal; check budget."""
        self._usage(goal_id)["time_used"] += seconds
        self._check_budget(goal_id)

    def record_compute(self, goal_id: str, units: float) -> None:
        """Record compute units for a goal; check budget."""
        self._usage(goal_id)["compute_used"] += units
        self._check_budget(goal_id)

    def _check_budget(self, goal_id: str) -> None:
        """If over budget, set goal to BLOCKED and log.

        Delegates to is_over_budget() so the threshold logic lives in one place.
        """
        if self.is_over_budget(goal_id):
            self.set_status(goal_id, GoalStatus.BLOCKED)
            logger.warning(
                "Goal over budget",
                extra={"goal_id": goal_id, "used": self._budget_used.get(goal_id, {})},
            )

    def is_over_budget(self, goal_id: str) -> bool:
        """Return True if goal has exceeded its time or compute budget."""
        g = self._goals.get(goal_id)
        if not g or not g.budget:
            return False
        used = self._budget_used.get(goal_id, {})
        if g.budget.time_seconds is not None and used.get("time_used", 0) >= g.budget.time_seconds:
            return True
        if g.budget.compute_budget is not None and used.get("compute_used", 0) >= g.budget.compute_budget:
            return True
        return False

    def list_goals(self, status: GoalStatus | None = None) -> list[Goal]:
        """Return goals, optionally filtered by status."""
        goals = list(self._goals.values())
        if status is not None:
            goals = [g for g in goals if g.status == status]
        return goals
|
||||
339
fusionagi/core/head_orchestrator.py
Normal file
339
fusionagi/core/head_orchestrator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Dvādaśa head orchestrator: parallel head dispatch, Witness coordination, second-pass."""
|
||||
|
||||
from __future__ import annotations

import math
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
from typing import TYPE_CHECKING, Any

from fusionagi._logger import logger
# HeadId/HeadOutput/FinalResponse/UserIntent are used at runtime (module-level
# constants and model_validate calls), so they must not live under TYPE_CHECKING.
from fusionagi.schemas.commands import UserIntent
from fusionagi.schemas.head import HeadId, HeadOutput
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
from fusionagi.schemas.witness import FinalResponse

if TYPE_CHECKING:
    from fusionagi.core.orchestrator import Orchestrator
    from fusionagi.schemas.commands import ParsedCommand
|
||||
|
||||
# MVP: 5 heads. Full: 11.
|
||||
MVP_HEADS: list[HeadId] = [
    HeadId.LOGIC,
    HeadId.RESEARCH,
    HeadId.STRATEGY,
    HeadId.SECURITY,
    HeadId.SAFETY,
]

# Every head except the Witness, which synthesizes rather than contributes content.
ALL_CONTENT_HEADS: list[HeadId] = [h for h in HeadId if h != HeadId.WITNESS]

# Heads for second-pass when risk/conflict/security
SECOND_PASS_HEADS: list[HeadId] = [HeadId.SECURITY, HeadId.SAFETY, HeadId.LOGIC]

# Thresholds for automatic second-pass
SECOND_PASS_CONFIG: dict[str, Any] = {
    "min_confidence": 0.5,  # rerun when agreement confidence falls below this
    "max_disputed": 3,  # rerun when more than this many claims are disputed
    "security_keywords": ("security", "risk", "threat", "vulnerability"),
}
|
||||
|
||||
|
||||
def run_heads_parallel(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    head_ids: list[HeadId] | None = None,
    sender: str = "head_orchestrator",
    timeout_per_head: float = 60.0,
    min_heads_ratio: float | None = 0.6,
) -> list[HeadOutput]:
    """
    Dispatch head_request to multiple heads in parallel; collect HeadOutput.

    Args:
        orchestrator: Orchestrator with registered head agents.
        task_id: Task identifier.
        user_prompt: User's prompt/question.
        head_ids: Heads to run (default: MVP_HEADS).
        sender: Sender identity for messages.
        timeout_per_head: Max seconds per head.
        min_heads_ratio: Return early once this fraction of heads respond (0.6 = 60%).
            None = wait for all heads. Reduces latency when some heads are slow.

    Returns:
        List of HeadOutput (may be partial on timeout/failure).
    """
    heads = head_ids or MVP_HEADS
    # The Witness is a synthesizer, never a content head.
    heads = [h for h in heads if h != HeadId.WITNESS]
    if not heads:
        return []

    envelopes = [
        AgentMessageEnvelope(
            message=AgentMessage(
                sender=sender,
                recipient=hid.value,
                intent="head_request",
                payload={"prompt": user_prompt},
            ),
            task_id=task_id,
        )
        for hid in heads
    ]

    results: list[HeadOutput] = []
    min_required = (
        max(1, math.ceil(len(heads) * min_heads_ratio))
        if min_heads_ratio is not None
        else len(heads)
    )

    def run_one(env: AgentMessageEnvelope) -> HeadOutput | None:
        """Send one head_request; parse the head_output reply (None on any failure)."""
        resp = orchestrator.route_message_return(env)
        if resp is None or resp.message.intent != "head_output":
            return None
        payload = resp.message.payload or {}
        ho = payload.get("head_output")
        if not isinstance(ho, dict):
            return None
        try:
            return HeadOutput.model_validate(ho)
        except Exception as e:
            logger.warning("HeadOutput parse failed", extra={"error": str(e)})
            return None

    with ThreadPoolExecutor(max_workers=len(heads)) as ex:
        future_to_env = {ex.submit(run_one, env): env for env in envelopes}
        try:
            for future in as_completed(future_to_env, timeout=timeout_per_head * len(heads) + 5):
                try:
                    out = future.result(timeout=1)
                    if out is not None:
                        results.append(out)
                        if len(results) >= min_required:
                            # Cancel heads that have not started yet so executor
                            # shutdown (end of the `with` block) does not wait
                            # for them — otherwise early exit saves nothing.
                            for pending in future_to_env:
                                pending.cancel()
                            logger.info(
                                "Early exit: sufficient heads responded",
                                extra={"responded": len(results), "required": min_required},
                            )
                            break
                except FuturesTimeoutError:
                    env = future_to_env[future]
                    logger.warning("Head timeout", extra={"recipient": env.message.recipient})
                except Exception as e:
                    logger.exception("Head execution failed", extra={"error": str(e)})
        except FuturesTimeoutError:
            # Overall deadline reached while heads were still pending: return the
            # partial results promised by the docstring instead of raising.
            logger.warning(
                "Heads overall timeout; returning partial results",
                extra={"responded": len(results), "required": min_required},
            )

    return results
|
||||
|
||||
|
||||
def run_witness(
    orchestrator: Orchestrator,
    task_id: str,
    head_outputs: list[HeadOutput],
    user_prompt: str,
    sender: str = "head_orchestrator",
) -> FinalResponse | None:
    """
    Route head outputs to Witness; return FinalResponse (None on any failure).
    """
    request = AgentMessage(
        sender=sender,
        recipient=HeadId.WITNESS.value,
        intent="witness_request",
        payload={
            "head_outputs": [h.model_dump() for h in head_outputs],
            "prompt": user_prompt,
        },
    )
    reply = orchestrator.route_message_return(
        AgentMessageEnvelope(message=request, task_id=task_id)
    )
    # No reply, or a reply with the wrong intent, means the Witness failed.
    if reply is None or reply.message.intent != "witness_output":
        return None
    raw = (reply.message.payload or {}).get("final_response")
    if not isinstance(raw, dict):
        return None
    try:
        return FinalResponse.model_validate(raw)
    except Exception as e:
        logger.warning("FinalResponse parse failed", extra={"error": str(e)})
        return None
|
||||
|
||||
|
||||
def run_second_pass(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    initial_outputs: list[HeadOutput],
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
) -> list[HeadOutput]:
    """
    Run second-pass heads (Security, Safety, Logic) and merge with initial outputs.

    Outputs produced by the second-pass heads replace the corresponding
    first-pass outputs; all other first-pass outputs are kept unchanged.
    The Witness head is always excluded from the second pass.
    """
    candidates = head_ids if head_ids else SECOND_PASS_HEADS
    heads = [h for h in candidates if h != HeadId.WITNESS]
    if not heads:
        return initial_outputs

    refreshed = run_heads_parallel(
        orchestrator,
        task_id,
        user_prompt,
        head_ids=heads,
        timeout_per_head=timeout_per_head,
    )

    # Later entries win: second-pass outputs overwrite first-pass ones per head.
    merged: dict[HeadId, HeadOutput] = {o.head_id: o for o in initial_outputs}
    merged.update((o.head_id, o) for o in refreshed)
    return list(merged.values())
|
||||
|
||||
|
||||
def _should_run_second_pass(
    final: FinalResponse,
    force: bool = False,
    second_pass_config: dict[str, Any] | None = None,
) -> bool:
    """
    Decide from the transparency report whether a second head pass is warranted.

    Triggers when forced, when agreement confidence is below min_confidence,
    when there are more disputed claims than max_disputed, or when the safety
    report mentions any configured security keyword.
    """
    if force:
        return True
    cfg = dict(SECOND_PASS_CONFIG)
    if second_pass_config:
        cfg.update(second_pass_config)
    report = final.transparency_report
    agreement = report.agreement_map
    if agreement.confidence_score < cfg.get("min_confidence", 0.5):
        return True
    if len(agreement.disputed_claims) > cfg.get("max_disputed", 3):
        return True
    safety_text = (report.safety_report or "").lower()
    return any(kw in safety_text for kw in cfg.get("security_keywords", ()))
|
||||
|
||||
|
||||
def run_dvadasa(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    parsed: ParsedCommand | None = None,
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
    event_bus: Any | None = None,
    force_second_pass: bool = False,
    return_head_outputs: bool = False,
    second_pass_config: dict[str, Any] | None = None,
    min_heads_ratio: float | None = 0.6,
) -> FinalResponse | tuple[FinalResponse, list[HeadOutput]] | tuple[None, list[HeadOutput]] | None:
    """
    Full Dvādaśa flow: run heads in parallel, then Witness.

    Args:
        orchestrator: Orchestrator with heads and witness registered.
        task_id: Task identifier.
        user_prompt: User's prompt (or use parsed.cleaned_prompt when HEAD_STRATEGY).
        parsed: Optional ParsedCommand from parse_user_input.
        head_ids: Override heads to run (e.g. single head for HEAD_STRATEGY).
        timeout_per_head: Max seconds per head.
        event_bus: Optional EventBus to publish dvadasa_complete.
        force_second_pass: Always run the second pass, regardless of the report.
        return_head_outputs: Also return the final list of HeadOutput objects.
        second_pass_config: Override SECOND_PASS_CONFIG (min_confidence, max_disputed, etc).
        min_heads_ratio: Early exit once this fraction of heads respond; None = wait all.
            NOTE(review): accepted but currently not forwarded to run_heads_parallel —
            confirm whether it should be wired through.

    Returns:
        FinalResponse, or (FinalResponse, head_outputs) when return_head_outputs is
        True; None (or (None, [])) when no head produced output or Witness failed.
    """
    prompt = user_prompt
    heads = head_ids

    # HEAD_STRATEGY routes to exactly one head; prefer the cleaned prompt when
    # present. (Replaces the old duplicated if/elif pair that differed only in
    # whether cleaned_prompt was set.)
    if parsed and parsed.intent == UserIntent.HEAD_STRATEGY and parsed.head_id:
        heads = [parsed.head_id]
        if parsed.cleaned_prompt:
            prompt = parsed.cleaned_prompt

    if heads is None:
        heads = select_heads_for_complexity(prompt)

    head_outputs = run_heads_parallel(
        orchestrator,
        task_id,
        prompt,
        head_ids=heads,
        timeout_per_head=timeout_per_head,
    )
    if not head_outputs:
        logger.warning("No head outputs; cannot run Witness")
        return (None, []) if return_head_outputs else None

    final = run_witness(orchestrator, task_id, head_outputs, prompt)

    # _should_run_second_pass already honors force=..., so the previous extra
    # "force_second_pass or ..." check was redundant and has been removed.
    if final and _should_run_second_pass(
        final, force=force_second_pass, second_pass_config=second_pass_config
    ):
        head_outputs = run_second_pass(
            orchestrator,
            task_id,
            prompt,
            head_outputs,
            timeout_per_head=timeout_per_head,
        )
        final = run_witness(orchestrator, task_id, head_outputs, prompt)

    if final and event_bus:
        # Publishing is best-effort: a broken subscriber must not lose the response.
        try:
            event_bus.publish(
                "dvadasa_complete",
                {
                    "task_id": task_id,
                    "final_response": final.model_dump(),
                    "head_count": len(head_outputs),
                },
            )
        except Exception as e:
            logger.warning("Failed to publish dvadasa_complete", extra={"error": str(e)})

    if return_head_outputs:
        return (final, head_outputs)
    return final
|
||||
|
||||
|
||||
def extract_sources_from_head_outputs(head_outputs: list[HeadOutput]) -> list[dict[str, Any]]:
    """
    Collect citations from head outputs for the SOURCES command.

    Evidence entries without a source_id are skipped; duplicates are removed
    per (head_id, source_id) pair, keeping the first occurrence.
    """
    collected: list[dict[str, Any]] = []
    visited: set[tuple[str, str]] = set()
    for output in head_outputs:
        head_name = output.head_id.value
        for claim in output.claims:
            for ev in claim.evidence:
                if not ev.source_id:
                    continue
                key = (head_name, ev.source_id)
                if key in visited:
                    continue
                visited.add(key)
                collected.append(
                    {
                        "head_id": head_name,
                        "source_id": ev.source_id,
                        "excerpt": ev.excerpt or "",
                        "confidence": ev.confidence,
                    }
                )
    return collected
|
||||
|
||||
|
||||
def select_heads_for_complexity(
    prompt: str,
    mvp_heads: list[HeadId] | None = None,
    all_heads: list[HeadId] | None = None,
) -> list[HeadId]:
    """
    Dynamic routing: simple prompts use fewer heads.

    Heuristic: a long prompt (> 50 words) or any complexity keyword selects
    the full head set; otherwise the reduced MVP set is used.

    Args:
        prompt: The user prompt to classify.
        mvp_heads: Reduced head set for simple prompts (defaults to MVP_HEADS).
        all_heads: Full head set for complex prompts (defaults to ALL_CONTENT_HEADS).

    Returns:
        A fresh list of HeadId values; callers may mutate it without affecting
        the module-level defaults.
    """
    # None sentinel instead of a module-level list as a default argument:
    # avoids the mutable-default pitfall (the old signature bound MVP_HEADS
    # at definition time and returned the shared list itself).
    mvp = MVP_HEADS if mvp_heads is None else mvp_heads
    full = all_heads or ALL_CONTENT_HEADS
    prompt_lower = prompt.lower()
    complex_keywords = (
        "security", "risk", "architecture", "scalability", "compliance",
        "critical", "production", "audit", "privacy", "sensitive",
    )
    if len(prompt.split()) > 50 or any(kw in prompt_lower for kw in complex_keywords):
        return list(full)
    return list(mvp)
|
||||
69
fusionagi/core/json_file_backend.py
Normal file
69
fusionagi/core/json_file_backend.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""JSON file persistence backend for StateManager."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState
|
||||
from fusionagi.core.persistence import StateBackend
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class JsonFileBackend(StateBackend):
    """
    StateBackend that persists tasks and traces to a JSON file.

    Use with StateManager(backend=JsonFileBackend(path="state.json")).
    File is created on first write; directory must exist or be creatable.
    Writes are atomic (temp file + rename), so a crash mid-write cannot
    leave a truncated or corrupt state file behind.
    """

    def __init__(self, path: str | Path) -> None:
        """Load any existing state from *path*; start empty otherwise."""
        self._path = Path(path)
        self._tasks: dict[str, dict[str, Any]] = {}  # task_id -> JSON-serializable task dump
        self._traces: dict[str, list[dict[str, Any]]] = {}  # task_id -> trace entries
        self._load()

    def _load(self) -> None:
        """Best-effort load; a corrupt or unreadable file logs a warning and starts empty."""
        if not self._path.exists():
            return
        try:
            data = json.loads(self._path.read_text(encoding="utf-8"))
            self._tasks = data.get("tasks", {})
            self._traces = data.get("traces", {})
        except Exception as e:
            logger.warning(
                "JsonFileBackend load failed",
                extra={"path": str(self._path), "error": str(e)},
            )

    def _save(self) -> None:
        """Atomically persist tasks and traces to the JSON file."""
        self._path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps({"tasks": self._tasks, "traces": self._traces}, indent=2)
        # Write to a sibling temp file and atomically replace the target, so a
        # crash during the write cannot corrupt the existing state file.
        tmp = self._path.with_suffix(self._path.suffix + ".tmp")
        tmp.write_text(payload, encoding="utf-8")
        tmp.replace(self._path)

    def get_task(self, task_id: str) -> Task | None:
        """Load task by id; returns None when missing or no longer parseable."""
        raw = self._tasks.get(task_id)
        if raw is None:
            return None
        try:
            return Task.model_validate(raw)
        except Exception:
            # Stored payload no longer matches the Task schema; treat as absent.
            return None

    def set_task(self, task: Task) -> None:
        """Store a task (as a JSON-serializable dump) and persist immediately."""
        self._tasks[task.task_id] = task.model_dump(mode="json")
        self._save()

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current task state or None if task unknown."""
        task = self.get_task(task_id)
        return task.state if task else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update state for an existing task; unknown task ids are ignored."""
        task = self.get_task(task_id)
        if task:
            self.set_task(task.model_copy(update={"state": state}))

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append a trace entry for a task and persist immediately."""
        self._traces.setdefault(task_id, []).append(entry)
        self._save()

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return a copy of the trace for a task (empty list if none)."""
        return list(self._traces.get(task_id, []))
|
||||
310
fusionagi/core/orchestrator.py
Normal file
310
fusionagi/core/orchestrator.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Orchestrator: task lifecycle, agent registry, wiring to event bus and state."""
|
||||
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState, TaskPriority, VALID_TASK_TRANSITIONS
|
||||
from fusionagi.schemas.messages import AgentMessageEnvelope
|
||||
|
||||
from fusionagi.core.event_bus import EventBus
|
||||
from fusionagi.core.state_manager import StateManager
|
||||
from fusionagi._logger import logger
|
||||
|
||||
# Single source of truth: re-export from schemas for backward compatibility.
# The mapping itself lives in fusionagi.schemas.task as VALID_TASK_TRANSITIONS;
# this alias preserves the old import path from this module.
VALID_STATE_TRANSITIONS = VALID_TASK_TRANSITIONS
|
||||
|
||||
|
||||
class InvalidStateTransitionError(Exception):
    """Raised when a task is asked to move between states that are not connected."""

    def __init__(self, task_id: str, from_state: TaskState, to_state: TaskState) -> None:
        message = (
            f"Invalid state transition for task {task_id}: "
            f"{from_state.value} -> {to_state.value}"
        )
        super().__init__(message)
        self.task_id = task_id
        self.from_state = from_state
        self.to_state = to_state
|
||||
|
||||
|
||||
@runtime_checkable
class AgentProtocol(Protocol):
    """Structural protocol for agents that can handle messages.

    runtime_checkable allows isinstance() checks against duck-typed agents;
    the Orchestrator only requires a handle_message attribute at dispatch time.
    """

    # Stable identifier for the agent (the id it is registered under).
    identity: str

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Handle an incoming message and optionally return a response envelope."""
        ...
|
||||
|
||||
|
||||
class TaskGraphEntry(BaseModel):
    """Per-task plan/metadata storage (plan cache)."""

    # None until a plan is stored via Orchestrator.set_task_plan.
    plan: dict[str, Any] | None = Field(default=None, description="Stored plan for the task")
|
||||
|
||||
|
||||
class Orchestrator:
    """
    Global task lifecycle and agent coordination; holds task plans, event bus, state, agent registry.

    Task state lifecycle: submit_task creates PENDING. Callers/supervisors must call set_task_state
    to transition to ACTIVE, COMPLETED, FAILED, or CANCELLED. The orchestrator validates state
    transitions according to VALID_STATE_TRANSITIONS.

    Valid transitions:
        PENDING -> ACTIVE, CANCELLED
        ACTIVE -> COMPLETED, FAILED, CANCELLED
        FAILED -> PENDING (retry), CANCELLED
        COMPLETED -> (terminal)
        CANCELLED -> (terminal)
    """

    def __init__(
        self,
        event_bus: EventBus,
        state_manager: StateManager,
        validate_transitions: bool = True,
    ) -> None:
        """
        Initialize the orchestrator.

        Args:
            event_bus: Event bus for publishing events.
            state_manager: State manager for task state.
            validate_transitions: If True, validate state transitions (default True).
        """
        self._event_bus = event_bus
        self._state = state_manager
        self._validate_transitions = validate_transitions
        self._agents: dict[str, AgentProtocol | Any] = {}  # agent_id -> agent instance
        self._sub_agents: dict[str, list[str]] = {}  # parent_id -> [child_id]
        self._task_plans: dict[str, TaskGraphEntry] = {}  # task_id -> plan/metadata per task
        self._async_executor: ThreadPoolExecutor = ThreadPoolExecutor(
            max_workers=8, thread_name_prefix="orch_async"
        )

    def register_agent(self, agent_id: str, agent: Any) -> None:
        """Register an agent by id for routing and assignment."""
        self._agents[agent_id] = agent
        logger.info("Agent registered", extra={"agent_id": agent_id})

    def unregister_agent(self, agent_id: str) -> None:
        """Remove an agent from the registry and from any parent's sub-agent list."""
        self._agents.pop(agent_id, None)
        self._sub_agents.pop(agent_id, None)
        # Also scrub the agent from every remaining parent's child list.
        for parent_id, children in list(self._sub_agents.items()):
            if agent_id in children:
                self._sub_agents[parent_id] = [c for c in children if c != agent_id]
        logger.info("Agent unregistered", extra={"agent_id": agent_id})

    def register_sub_agent(self, parent_id: str, child_id: str, agent: Any) -> None:
        """Register a sub-agent under a parent; child can be delegated sub-tasks."""
        self._agents[child_id] = agent
        self._sub_agents.setdefault(parent_id, []).append(child_id)
        logger.info("Sub-agent registered", extra={"parent_id": parent_id, "child_id": child_id})

    def get_sub_agents(self, parent_id: str) -> list[str]:
        """Return list of child agent ids for a parent (a copy; safe to mutate)."""
        return list(self._sub_agents.get(parent_id, []))

    def get_agent(self, agent_id: str) -> Any | None:
        """Return registered agent by id or None."""
        return self._agents.get(agent_id)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the async executor used for route_message_async. Call when the orchestrator is no longer needed."""
        self._async_executor.shutdown(wait=wait)
        logger.debug("Orchestrator async executor shut down", extra={"wait": wait})

    def submit_task(
        self,
        goal: str,
        constraints: list[str] | None = None,
        priority: TaskPriority = TaskPriority.NORMAL,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Create a task and publish task_created; returns task_id."""
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            goal=goal,
            constraints=constraints or [],
            priority=priority,
            state=TaskState.PENDING,
            metadata=metadata or {},
        )
        self._state.set_task(task)
        self._task_plans[task_id] = TaskGraphEntry()
        logger.info(
            "Task created",
            # Goal is truncated in the log only, never in the stored task.
            extra={"task_id": task_id, "goal": goal[:200] if goal else ""},
        )
        self._event_bus.publish(
            "task_created",
            {"task_id": task_id, "goal": goal, "constraints": task.constraints},
        )
        return task_id

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current state of a task or None if unknown."""
        return self._state.get_task_state(task_id)

    def get_task(self, task_id: str) -> Task | None:
        """Return full task or None."""
        return self._state.get_task(task_id)

    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None:
        """Store plan for a known task; silently ignored for unknown task ids."""
        if task_id in self._task_plans:
            self._task_plans[task_id].plan = plan

    def get_task_plan(self, task_id: str) -> dict[str, Any] | None:
        """Return stored plan for a task or None."""
        entry = self._task_plans.get(task_id)
        return entry.plan if entry else None

    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None:
        """
        Update task state with transition validation.

        Args:
            task_id: The task identifier.
            state: The new state to transition to.
            force: If True, skip transition validation (use with caution).

        Raises:
            InvalidStateTransitionError: If the transition is not allowed and force=False.
            ValueError: If task_id is unknown.
        """
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            raise ValueError(f"Unknown task: {task_id}")

        if not force and self._validate_transitions:
            allowed = VALID_TASK_TRANSITIONS.get(current_state, set())
            # Self-transitions are always permitted (idempotent updates).
            if state not in allowed and state != current_state:
                raise InvalidStateTransitionError(task_id, current_state, state)

        self._state.set_task_state(task_id, state)
        logger.debug(
            "Task state set",
            extra={
                "task_id": task_id,
                "from_state": current_state.value,
                "to_state": state.value,
            },
        )
        self._event_bus.publish(
            "task_state_changed",
            {"task_id": task_id, "from_state": current_state.value, "to_state": state.value},
        )

    def can_transition(self, task_id: str, state: TaskState) -> bool:
        """Check if a state transition is valid without performing it."""
        current_state = self._state.get_task_state(task_id)
        if current_state is None:
            return False
        if state == current_state:
            return True
        allowed = VALID_TASK_TRANSITIONS.get(current_state, set())
        return state in allowed

    def _dispatch(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Shared delivery path for route_message / route_message_return: log the
        routing, publish message_received, invoke the recipient's handle_message
        (when the agent exists and supports it), and return its response.
        """
        recipient = envelope.message.recipient
        intent = envelope.message.intent
        task_id = envelope.task_id or ""
        logger.info(
            "Message routed",
            extra={"task_id": task_id, "recipient": recipient, "intent": intent},
        )
        agent = self._agents.get(recipient)
        self._event_bus.publish(
            "message_received",
            {
                "task_id": envelope.task_id,
                "recipient": recipient,
                "intent": intent,
            },
        )
        if agent is not None and hasattr(agent, "handle_message"):
            return agent.handle_message(envelope)
        return None

    def route_message(self, envelope: AgentMessageEnvelope) -> None:
        """
        Deliver an envelope to the recipient agent and publish message_received.
        Does not route the agent's response; use route_message_return to get and optionally
        re-route the response envelope.
        """
        # Same dispatch path as route_message_return; the response is discarded.
        self._dispatch(envelope)

    def route_message_return(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """
        Deliver an envelope to the recipient agent and return the response envelope, if any.
        Use this when the caller needs to handle or re-route the agent's response.
        """
        return self._dispatch(envelope)

    def route_messages_batch(
        self,
        envelopes: list[AgentMessageEnvelope],
    ) -> list[AgentMessageEnvelope | None]:
        """
        Route multiple messages; return responses in same order.
        Uses concurrent execution for parallel dispatch.
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Guard: ThreadPoolExecutor(max_workers=0) raises ValueError.
        if not envelopes:
            return []

        results: list[AgentMessageEnvelope | None] = [None] * len(envelopes)

        def route_one(i: int, env: AgentMessageEnvelope) -> tuple[int, AgentMessageEnvelope | None]:
            return i, self.route_message_return(env)

        with ThreadPoolExecutor(max_workers=min(len(envelopes), 32)) as ex:
            futures = [ex.submit(route_one, i, env) for i, env in enumerate(envelopes)]
            for fut in as_completed(futures):
                idx, resp = fut.result()
                results[idx] = resp

        return results

    def route_message_async(
        self,
        envelope: AgentMessageEnvelope,
        callback: Callable[[AgentMessageEnvelope | None], None] | None = None,
    ) -> Any:
        """
        Route message in background; optionally invoke callback with response.
        Returns Future for non-blocking await.
        """
        from concurrent import futures

        def run() -> AgentMessageEnvelope | None:
            return self.route_message_return(envelope)

        future = self._async_executor.submit(run)
        if callback:

            def done(f: futures.Future) -> None:
                try:
                    callback(f.result())
                except Exception:
                    # Callback failures must not propagate into the executor thread.
                    logger.exception("Async route callback failed")

            future.add_done_callback(done)
        return future
|
||||
44
fusionagi/core/persistence.py
Normal file
44
fusionagi/core/persistence.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Optional persistence interface for state manager; in-memory is default."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState
|
||||
|
||||
|
||||
class StateBackend(ABC):
    """
    Abstract backend for task state and traces; replace StateManager internals for persistence.

    Any backend used to replace StateManager storage must implement get_task_state and
    set_task_state in addition to the task and trace methods below. Implementations
    decide where data lives (file, database, ...); StateManager treats this purely
    as a key/value store keyed by task_id.
    """

    @abstractmethod
    def get_task(self, task_id: str) -> Task | None:
        """Load task by id; return None when the task is unknown."""
        ...

    @abstractmethod
    def set_task(self, task: Task) -> None:
        """Save (create or overwrite) a task."""
        ...

    @abstractmethod
    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current task state or None if task unknown."""
        ...

    @abstractmethod
    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update task state; creates no task if missing."""
        ...

    @abstractmethod
    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append a trace entry for a task."""
        ...

    @abstractmethod
    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Load the trace entries for a task (empty when none)."""
        ...
|
||||
89
fusionagi/core/scheduler.py
Normal file
89
fusionagi/core/scheduler.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Scheduler: think vs act, tool selection, retry logic, fallback modes for AGI."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class SchedulerMode(str, Enum):
    """Whether to think (reason) or act (tool) next."""

    # Spend the next step on reasoning/planning.
    THINK = "think"
    # Execute the next concrete step (e.g. invoke a tool).
    ACT = "act"
|
||||
|
||||
|
||||
class FallbackMode(str, Enum):
    """Fallback when primary path fails; Scheduler hands these out in sequence."""

    # Try the same step again.
    RETRY = "retry"
    # Re-plan with a simpler plan.
    SIMPLIFY_PLAN = "simplify_plan"
    # Escalate to a human operator.
    HUMAN_HANDOFF = "human_handoff"
    # Give up on the task.
    ABORT = "abort"
|
||||
|
||||
|
||||
class Scheduler:
    """
    Decides think vs act, tool selection policy, retry/backoff, fallback.
    Callers (e.g. Supervisor) query next_action() and record outcomes.
    """

    def __init__(
        self,
        default_mode: SchedulerMode = SchedulerMode.ACT,
        max_retries_per_step: int = 2,
        fallback_sequence: list[FallbackMode] | None = None,
    ) -> None:
        self._default_mode = default_mode
        self._max_retries = max_retries_per_step
        self._fallback_sequence = fallback_sequence or [
            FallbackMode.RETRY,
            FallbackMode.SIMPLIFY_PLAN,
            FallbackMode.HUMAN_HANDOFF,
            FallbackMode.ABORT,
        ]
        # "task:step" key -> number of retries recorded so far
        self._retry_counts: dict[str, int] = {}
        # task_id -> position of the next fallback to hand out
        self._fallback_index: dict[str, int] = {}

    def next_mode(
        self,
        task_id: str,
        step_id: str,
        context: dict[str, Any] | None = None,
    ) -> SchedulerMode:
        """
        Return whether to think (reason more) or act (execute step).
        Override via context["force_think"] or context["force_act"].
        """
        ctx = context or {}
        if ctx.get("force_think"):
            return SchedulerMode.THINK
        if ctx.get("force_act"):
            return SchedulerMode.ACT
        return self._default_mode

    def should_retry(self, task_id: str, step_id: str) -> bool:
        """Return True while the step has been retried fewer than max_retries times."""
        return self._retry_counts.get(f"{task_id}:{step_id}", 0) < self._max_retries

    def record_retry(self, task_id: str, step_id: str) -> None:
        """Increment the retry counter for a step."""
        key = f"{task_id}:{step_id}"
        self._retry_counts[key] = 1 + self._retry_counts.get(key, 0)
        logger.debug("Scheduler recorded retry", extra={"task_id": task_id, "step_id": step_id})

    def next_fallback(self, task_id: str) -> FallbackMode | None:
        """Hand out the next fallback mode for a task, or None once exhausted."""
        position = self._fallback_index.get(task_id, 0)
        if position >= len(self._fallback_sequence):
            return None
        mode = self._fallback_sequence[position]
        self._fallback_index[task_id] = position + 1
        logger.info("Scheduler fallback", extra={"task_id": task_id, "fallback": mode.value})
        return mode

    def reset_fallback(self, task_id: str) -> None:
        """Forget fallback progress for a task (e.g. after success)."""
        self._fallback_index.pop(task_id, None)
|
||||
111
fusionagi/core/state_manager.py
Normal file
111
fusionagi/core/state_manager.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""In-memory store for task state and execution traces; replaceable with persistent backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState
|
||||
from fusionagi._logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.persistence import StateBackend
|
||||
|
||||
|
||||
class StateManager:
    """
    Manages task state and execution traces.

    Supports optional persistent backend via dependency injection. When a backend
    is provided, all operations are persisted. In-memory cache is always maintained
    for fast access.
    """

    def __init__(self, backend: StateBackend | None = None) -> None:
        """
        Initialize StateManager with optional persistence backend.

        Args:
            backend: Optional StateBackend for persistence. If None, uses in-memory only.
        """
        self._backend = backend
        self._tasks: dict[str, Task] = {}  # in-memory task cache (authoritative when no backend)
        self._traces: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def get_task(self, task_id: str) -> Task | None:
        """Return the task by id or None. Checks memory first, then backend (and caches the hit)."""
        if task_id in self._tasks:
            return self._tasks[task_id]
        if self._backend:
            task = self._backend.get_task(task_id)
            if task:
                self._tasks[task_id] = task
                return task
        return None

    def set_task(self, task: Task) -> None:
        """Store or update a task in memory and backend."""
        self._tasks[task.task_id] = task
        if self._backend:
            self._backend.set_task(task)
        logger.debug("Task set", extra={"task_id": task.task_id})

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current task state or None if task unknown."""
        task = self.get_task(task_id)
        return task.state if task else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update task state; silently does nothing for an unknown task (creates no task)."""
        # get_task already resolves memory-first-then-backend and caches the
        # result, so a single lookup replaces the old duplicated branches.
        task = self.get_task(task_id)
        if task is None:
            return
        task.state = state
        self._tasks[task_id] = task
        if self._backend:
            self._backend.set_task_state(task_id, state)
        logger.debug("Task state set", extra={"task_id": task_id, "state": state.value})

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append an entry to the execution trace for a task."""
        self._traces[task_id].append(entry)
        if self._backend:
            self._backend.append_trace(task_id, entry)
        # Most descriptive key available, used purely for debug logging.
        tool = entry.get("tool") or entry.get("step") or "entry"
        logger.debug("Trace appended", extra={"task_id": task_id, "entry_key": tool})

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return the execution trace for a task (copy). Checks backend if not in memory."""
        if task_id in self._traces and self._traces[task_id]:
            return list(self._traces[task_id])
        if self._backend:
            trace = self._backend.get_trace(task_id)
            if trace:
                self._traces[task_id] = list(trace)
                return trace
        return list(self._traces.get(task_id, []))

    def clear_task(self, task_id: str) -> None:
        """Remove task and its trace (for tests or cleanup). Does not clear backend."""
        self._tasks.pop(task_id, None)
        self._traces.pop(task_id, None)

    def list_tasks(self, state: TaskState | None = None) -> list[Task]:
        """Return all tasks, optionally filtered by state.

        When a persistence backend is configured, only tasks currently loaded
        in memory are returned; tasks that exist only in the backend are not included.
        """
        tasks = list(self._tasks.values())
        if state is not None:
            tasks = [t for t in tasks if t.state == state]
        return tasks

    def task_count(self) -> int:
        """Return total number of tasks in memory."""
        return len(self._tasks)
|
||||
136
fusionagi/core/super_big_brain.py
Normal file
136
fusionagi/core/super_big_brain.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Super Big Brain orchestrator: tokenless, recursive, graph-backed reasoning."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit, DecompositionResult
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim, HeadRisk
|
||||
from fusionagi.schemas.grounding import Citation
|
||||
from fusionagi.reasoning.decomposition import decompose_recursive
|
||||
from fusionagi.reasoning.context_loader import load_context_for_reasoning, build_compact_prompt
|
||||
from fusionagi.reasoning.tot import ThoughtNode, expand_node, prune_subtree, merge_subtrees
|
||||
from fusionagi.reasoning.multi_path import generate_and_score_parallel
|
||||
from fusionagi.reasoning.recomposition import recompose, RecomposedResponse
|
||||
from fusionagi.reasoning.meta_reasoning import challenge_assumptions, detect_contradictions
|
||||
from fusionagi.memory.semantic_graph import SemanticGraphMemory
|
||||
from fusionagi.memory.sharding import shard_context
|
||||
from fusionagi.memory.scratchpad import LatentScratchpad
|
||||
from fusionagi.memory.thought_versioning import ThoughtVersioning
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class SuperBigBrainConfig:
    """Configuration for Super Big Brain pipeline."""

    # Maximum recursion depth passed to decompose_recursive.
    max_decomposition_depth: int = 3
    # Require at least this ToT node depth before accepting a conclusion;
    # a shallower best node is expanded once more.
    min_depth_before_conclusion: int = 1
    # Number of candidate hypotheses taken from the top decomposition units.
    parallel_hypotheses: int = 3
    # NOTE(review): presumably the score below which ToT subtrees are pruned
    # (not referenced in the visible pipeline code) — confirm usage.
    prune_threshold: float = 0.3
    # Character cap for the compact reasoning prompt built from units.
    max_context_chars: int = 4000
|
||||
|
||||
|
||||
def run_super_big_brain(
    prompt: str,
    semantic_graph: SemanticGraphMemory,
    config: SuperBigBrainConfig | None = None,
    adapter: Any | None = None,
) -> RecomposedResponse:
    """
    End-to-end Super Big Brain pipeline:

    1. Decompose prompt -> atomic units
    2. Shard and load context
    3. Run hierarchical ToT with multi-path inference
    4. Recompose with traceability
    5. Persist units/relations to semantic graph

    Args:
        prompt: Raw text to reason over.
        semantic_graph: Graph memory that receives the decomposed units/relations.
        config: Pipeline tuning knobs; a default SuperBigBrainConfig is used if None.
        adapter: NOTE(review): accepted but never referenced in this function —
            presumably reserved for a future model adapter; confirm or remove.

    Returns:
        RecomposedResponse whose metadata carries assumption/contradiction
        counters and the depth of the chosen thought node.
    """
    cfg = config or SuperBigBrainConfig()
    decomp = decompose_recursive(prompt, max_depth=cfg.max_decomposition_depth)
    if not decomp.units:
        # Nothing decomposable: return an empty, zero-confidence response.
        return RecomposedResponse(summary="No content to reason over.", confidence=0.0)

    # Persist the decomposition before reasoning so the graph reflects this prompt.
    semantic_graph.ingest_decomposition(decomp.units, decomp.relations)
    # NOTE(review): `ctx` is never read below; the call is kept here, presumably
    # for side effects inside load_context_for_reasoning — confirm, and drop the
    # binding (or the call) if it is pure.
    ctx = load_context_for_reasoning(decomp.units, semantic_graph=semantic_graph, sharder=shard_context)
    compact = build_compact_prompt(decomp.units, max_chars=cfg.max_context_chars)

    # Seed hypotheses from the first N non-empty units; fall back to a prompt slice.
    hypotheses = [u.content for u in decomp.units[:cfg.parallel_hypotheses] if u.content]
    if not hypotheses:
        hypotheses = [compact[:500]]

    scored = generate_and_score_parallel(hypotheses, decomp.units)
    # Order candidates best-score-first; the top one becomes the working node.
    nodes = [n for n, _ in sorted(scored, key=lambda x: x[1], reverse=True)]
    best = nodes[0] if nodes else ThoughtNode(thought=compact[:300], unit_refs=[u.unit_id for u in decomp.units[:5]])

    # Enforce a minimum reasoning depth: expand once more if the best node is
    # shallower than min_depth_before_conclusion.
    if cfg.min_depth_before_conclusion > 0 and best.depth < cfg.min_depth_before_conclusion:
        child = expand_node(best, compact[:200], unit_refs=best.unit_refs)
        child.score = best.score  # forced child inherits the parent's score
        best = child

    prune_subtree(best, cfg.prune_threshold)
    assumptions = challenge_assumptions(decomp.units, best.thought)
    contradictions = detect_contradictions(decomp.units)

    recomp = recompose([best], decomp.units)
    # Surface meta-reasoning counters so callers can gauge reliability.
    recomp.metadata["assumptions_flagged"] = len(assumptions)
    recomp.metadata["contradictions"] = len(contradictions)
    recomp.metadata["depth"] = best.depth

    logger.info(
        "Super Big Brain complete",
        extra={"units": len(decomp.units), "confidence": recomp.confidence},
    )
    return recomp
|
||||
|
||||
|
||||
def _recomposed_to_head_output(
    recomp: RecomposedResponse,
    head_id: HeadId,
) -> HeadOutput:
    """Convert RecomposedResponse to HeadOutput for Dvādaśa integration."""

    def _evidence() -> list[Citation]:
        # One citation per referenced unit (first three); excerpts left empty.
        return [
            Citation(source_id=uid, excerpt="", confidence=recomp.confidence)
            for uid in recomp.unit_refs[:3]
        ]

    # At most five claims, each carrying the overall confidence and evidence.
    claims = []
    for text in recomp.key_claims[:5]:
        claims.append(
            HeadClaim(
                claim_text=text,
                confidence=recomp.confidence,
                evidence=_evidence(),
                assumptions=[],
            )
        )
    if not claims:
        # No key claims: fall back to a single claim carrying the summary.
        claims.append(
            HeadClaim(claim_text=recomp.summary, confidence=recomp.confidence, evidence=[], assumptions=[])
        )

    # Translate meta-reasoning counters into risk entries.
    risks = []
    if recomp.metadata.get("assumptions_flagged", 0) > 0:
        risks.append(HeadRisk(description="Assumptions flagged; verify before acting", severity="medium"))
    if recomp.metadata.get("contradictions", 0) > 0:
        risks.append(HeadRisk(description="Contradictions detected in context", severity="high"))

    return HeadOutput(
        head_id=head_id,
        summary=recomp.summary,
        claims=claims,
        risks=risks,
        questions=[],
        recommended_actions=["Consider flagged assumptions", "Resolve contradictions if any"],
        tone_guidance="",
    )
|
||||
|
||||
|
||||
class SuperBigBrainReasoningProvider:
    """ReasoningProvider for HeadAgent: uses Super Big Brain pipeline."""

    def __init__(
        self,
        semantic_graph: SemanticGraphMemory | None = None,
        config: SuperBigBrainConfig | None = None,
    ) -> None:
        # Fall back to fresh defaults when the caller supplies nothing.
        self._graph = semantic_graph or SemanticGraphMemory()
        self._config = config or SuperBigBrainConfig()

    def produce_head_output(self, head_id: HeadId, prompt: str) -> HeadOutput:
        """Produce HeadOutput using Super Big Brain pipeline."""
        result = run_super_big_brain(prompt, self._graph, self._config)
        return _recomposed_to_head_output(result, head_id)
|
||||
32
fusionagi/governance/__init__.py
Normal file
32
fusionagi/governance/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Governance and safety: guardrails, rate limiting, access control, override, audit, policy, intent alignment."""
|
||||
|
||||
from fusionagi.governance.guardrails import Guardrails, PreCheckResult
|
||||
from fusionagi.governance.rate_limiter import RateLimiter
|
||||
from fusionagi.governance.access_control import AccessControl
|
||||
from fusionagi.governance.override import OverrideHooks
|
||||
from fusionagi.governance.audit_log import AuditLog
|
||||
from fusionagi.governance.policy_engine import PolicyEngine
|
||||
from fusionagi.governance.intent_alignment import IntentAlignment
|
||||
from fusionagi.governance.safety_pipeline import (
|
||||
SafetyPipeline,
|
||||
InputModerator,
|
||||
OutputScanner,
|
||||
ModerationResult,
|
||||
OutputScanResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Guardrails",
|
||||
"PreCheckResult",
|
||||
"RateLimiter",
|
||||
"AccessControl",
|
||||
"OverrideHooks",
|
||||
"AuditLog",
|
||||
"PolicyEngine",
|
||||
"IntentAlignment",
|
||||
"SafetyPipeline",
|
||||
"InputModerator",
|
||||
"OutputScanner",
|
||||
"ModerationResult",
|
||||
"OutputScanResult",
|
||||
]
|
||||
30
fusionagi/governance/access_control.py
Normal file
30
fusionagi/governance/access_control.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Tool access control: central policy for which agent may call which tools.
|
||||
|
||||
Optional; not wired to Executor or Orchestrator by default. Wire by passing
|
||||
an AccessControl instance and checking allowed(agent_id, tool_name, task_id)
|
||||
before tool invocation.
|
||||
"""
|
||||
|
||||
|
||||
class AccessControl:
    """Policy: (agent_id, tool_name, task_id) -> allowed.

    Deny rules are global per (agent, tool); per-task allowlists restrict
    which tools may run for a given task. Checks are fail-closed for tasks
    that have an allowlist registered.
    """

    def __init__(self) -> None:
        # Globally denied (agent_id, tool_name) pairs.
        self._deny: set[tuple[str, str]] = set()
        # Per-task allowlist of tool names; tasks absent from the map are unrestricted.
        self._task_tools: dict[str, set[str]] = {}

    def deny(self, agent_id: str, tool_name: str) -> None:
        """Deny agent from using tool (global)."""
        self._deny.add((agent_id, tool_name))

    def allow_tools_for_task(self, task_id: str, tool_names: list[str]) -> None:
        """Set the allowlist of tools for a task.

        Note: registering an EMPTY list installs an empty allowlist, which
        denies ALL tools for that task (fail-closed). To leave a task
        unrestricted, simply do not register it.
        (The previous docstring claimed "empty = all allowed", contradicting
        the actual behavior; the text is corrected here.)
        """
        self._task_tools[task_id] = set(tool_names)

    def allowed(self, agent_id: str, tool_name: str, task_id: str | None = None) -> bool:
        """Return True if agent may call tool (optionally for this task)."""
        if (agent_id, tool_name) in self._deny:
            return False
        # A task with a registered allowlist restricts to exactly that set.
        if task_id and task_id in self._task_tools:
            return tool_name in self._task_tools[task_id]
        return True
|
||||
29
fusionagi/governance/audit_log.py
Normal file
29
fusionagi/governance/audit_log.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Structured audit log for AGI."""
|
||||
from typing import Any
|
||||
from fusionagi.schemas.audit import AuditEntry, AuditEventType
|
||||
from fusionagi._logger import logger
|
||||
import uuid
|
||||
|
||||
class AuditLog:
    """Bounded in-memory audit log with per-task and per-event-type indexes.

    Entries are kept oldest-first; once max_entries is reached the oldest
    entry is evicted. Indexes store monotonically increasing sequence
    numbers rather than list positions, so eviction does not invalidate
    them. (The previous scheme stored raw list indices, which all shift by
    one on each eviction — get_by_task/get_by_type then returned the WRONG
    entries.)

    NOTE(review): the per-task/per-type sequence lists themselves are never
    trimmed; for very long-running processes they grow without bound.
    """

    def __init__(self, max_entries=100000):
        self._entries = []  # retained entries, oldest first
        self._max_entries = max_entries
        # Sequence number of self._entries[0]; incremented on every eviction.
        self._offset = 0
        self._by_task = {}  # task_id -> [sequence numbers]
        self._by_type = {}  # event_type value -> [sequence numbers]

    def append(self, event_type, actor, action="", task_id=None, payload=None, outcome=""):
        """Record one audit entry; returns its generated entry_id."""
        entry_id = str(uuid.uuid4())
        entry = AuditEntry(entry_id=entry_id, event_type=event_type, actor=actor, task_id=task_id, action=action, payload=payload or {}, outcome=outcome)
        if len(self._entries) >= self._max_entries:
            # Evict oldest; bump offset so stored sequence numbers stay valid.
            self._entries.pop(0)
            self._offset += 1
        seq = self._offset + len(self._entries)
        self._entries.append(entry)
        if entry.task_id:
            self._by_task.setdefault(entry.task_id, []).append(seq)
        self._by_type.setdefault(entry.event_type.value, []).append(seq)
        return entry_id

    def _resolve(self, seqs, limit):
        # Translate sequence numbers back to live entries, skipping evicted ones.
        out = []
        for seq in seqs[-limit:]:
            i = seq - self._offset
            if 0 <= i < len(self._entries):
                out.append(self._entries[i])
        return out

    def get_by_task(self, task_id, limit=100):
        """Return up to `limit` most recent retained entries for task_id."""
        return self._resolve(self._by_task.get(task_id, []), limit)

    def get_by_type(self, event_type, limit=100):
        """Return up to `limit` most recent retained entries of the given type."""
        return self._resolve(self._by_type.get(event_type.value, []), limit)
|
||||
71
fusionagi/governance/guardrails.py
Normal file
71
fusionagi/governance/guardrails.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Guardrails: pre/post checks for tool calls (block paths, sanitize inputs)."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class PreCheckResult(BaseModel):
    """Result of a guardrails pre-check: allowed, optional sanitized args, optional error message."""

    allowed: bool = Field(..., description="Whether the call is allowed")
    # Set by Guardrails.pre_check on success: the (possibly check-sanitized)
    # args the caller should use for the actual tool invocation.
    sanitized_args: dict[str, Any] | None = Field(default=None, description="Args to use if allowed and sanitized")
    # Populated only on denial; human-readable reason for the block.
    error_message: str | None = Field(default=None, description="Reason for denial if not allowed")
|
||||
|
||||
|
||||
class Guardrails:
    """Pre/post checks for tool invocations."""

    def __init__(self) -> None:
        self._blocked_paths: list[str] = []
        self._blocked_patterns: list[re.Pattern[str]] = []
        self._custom_checks: list[Any] = []

    def block_path_prefix(self, prefix: str) -> None:
        """Block any file path starting with this prefix."""
        normalized = prefix.rstrip("/")
        self._blocked_paths.append(normalized)

    def block_path_pattern(self, pattern: str) -> None:
        """Block paths matching this regex."""
        compiled = re.compile(pattern)
        self._blocked_patterns.append(compiled)

    def add_check(self, check: Any) -> None:
        """
        Add a custom pre-check. Check receives (tool_name, args); must not mutate caller's args.
        Returns (allowed, sanitized_args or error_message): (True, dict) or (True, None) or (False, str).
        Returned sanitized_args are used for subsequent checks and invocation.
        """
        self._custom_checks.append(check)

    def _deny(self, tool_name: str, reason: str) -> PreCheckResult:
        # Shared denial path: log once, then return a failed result.
        logger.info("Guardrails pre_check blocked", extra={"tool_name": tool_name, "reason": reason})
        return PreCheckResult(allowed=False, error_message=reason)

    def pre_check(self, tool_name: str, args: dict[str, Any]) -> PreCheckResult:
        """Run all pre-checks. Returns PreCheckResult (allowed, sanitized_args, error_message)."""
        args = dict(args)  # work on a copy so the caller's mapping stays untouched
        for key in ("path", "file_path"):
            value = args.get(key)
            if not isinstance(value, str):
                continue
            for prefix in self._blocked_paths:
                if value.startswith(prefix) or value.startswith(prefix + "/"):
                    return self._deny(tool_name, "Blocked path prefix: " + prefix)
            for pat in self._blocked_patterns:
                if pat.search(value):
                    return self._deny(tool_name, "Blocked path pattern")
        for check in self._custom_checks:
            allowed, outcome = check(tool_name, args)
            if not allowed:
                message = outcome if isinstance(outcome, str) else "Check failed"
                return self._deny(tool_name, message)
            if isinstance(outcome, dict):
                args = outcome  # adopt sanitized args for the remaining checks
        return PreCheckResult(allowed=True, sanitized_args=args)

    def post_check(self, tool_name: str, result: Any) -> tuple[bool, str]:
        """Optional post-check; return (True, "") or (False, error_message)."""
        return True, ""
|
||||
29
fusionagi/governance/intent_alignment.py
Normal file
29
fusionagi/governance/intent_alignment.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Intent alignment: what user meant vs what user said for AGI."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class IntentAlignment:
    """
    Checks that system interpretation of user goal matches user intent.
    Placeholder: returns True; wire to confirmation or paraphrase flow.
    """

    def __init__(self) -> None:
        # History of (interpreted_goal, user_input) pairs that were checked.
        self._checks: list[tuple[str, str]] = []

    def check(self, interpreted_goal: str, user_input: str, context: dict[str, Any] | None = None) -> tuple[bool, str]:
        """
        Returns (aligned, message). If not aligned, message suggests clarification.
        """
        if not (interpreted_goal and user_input):
            # Nothing to compare; trivially aligned.
            return True, ""
        self._checks.append((interpreted_goal, user_input))
        logger.debug("IntentAlignment check", extra={"goal": interpreted_goal[:80]})
        return True, ""

    def suggest_paraphrase(self, goal: str) -> str:
        """Return suggested paraphrase for user to confirm."""
        return f"Just to confirm, you want: {goal}"
|
||||
44
fusionagi/governance/override.py
Normal file
44
fusionagi/governance/override.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Human override hooks: events the orchestrator can fire before high-risk steps."""
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
# Callback: (event_type, payload) -> proceed: bool
|
||||
OverrideCallback = Callable[[str, dict[str, Any]], bool]
|
||||
|
||||
|
||||
class OverrideHooks:
    """Optional callbacks for human override; no UI, just interface and logging."""

    def __init__(self) -> None:
        self._hooks: list[OverrideCallback] = []
        self._log: list[dict[str, Any]] = []

    def register(self, callback: OverrideCallback) -> None:
        """Register a callback; if any returns False, treat as 'do not proceed'."""
        self._hooks.append(callback)

    def fire(self, event_type: str, payload: dict[str, Any]) -> bool:
        """
        Fire event (e.g. task_paused_for_approval). If no hooks, return True (proceed).
        If any hook returns False, return False (do not proceed). Log all events.
        Exception in a hook implies do not proceed.
        """
        self._log.append({"event": event_type, "payload": payload})
        logger.info("Override fire", extra={"event_type": event_type})
        for hook in self._hooks:
            try:
                proceed = hook(event_type, payload)
            except Exception:
                # Fail closed: a crashing hook means "do not proceed".
                logger.exception("Override hook raised", extra={"event_type": event_type})
                return False
            if not proceed:
                logger.info("Override hook returned do not proceed", extra={"event_type": event_type})
                return False
        logger.debug("Override fire proceed", extra={"event_type": event_type})
        return True

    def get_log(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return recent override events (for auditing)."""
        return list(self._log[-limit:])
|
||||
73
fusionagi/governance/policy_engine.py
Normal file
73
fusionagi/governance/policy_engine.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Policy engine: hard constraints independent of LLM for AGI."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.policy import PolicyEffect, PolicyRule
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class PolicyEngine:
    """Evaluates policy rules; higher priority first; first match wins (allow/deny)."""

    def __init__(self) -> None:
        self._rules: list[PolicyRule] = []

    def _resort(self) -> None:
        # Highest priority first; sort is stable, so ties keep insertion order.
        self._rules.sort(key=lambda rule: -rule.priority)

    def add_rule(self, rule: PolicyRule) -> None:
        self._rules.append(rule)
        self._resort()
        logger.debug("PolicyEngine: rule added", extra={"rule_id": rule.rule_id})

    def get_rules(self) -> list[PolicyRule]:
        """Return all rules (copy)."""
        return list(self._rules)

    def get_rule(self, rule_id: str) -> PolicyRule | None:
        """Return rule by id or None."""
        return next((r for r in self._rules if r.rule_id == rule_id), None)

    def update_rule(self, rule_id: str, updates: dict[str, Any]) -> bool:
        """
        Update an existing rule by id. Updates can include condition, effect, reason, priority.
        Returns True if updated, False if rule_id not found.
        """
        mutable = {"condition", "effect", "reason", "priority"}
        for idx, rule in enumerate(self._rules):
            if rule.rule_id != rule_id:
                continue
            data = rule.model_dump()
            data.update({k: v for k, v in updates.items() if k in mutable})
            # Re-validate through the model so bad updates fail loudly.
            self._rules[idx] = PolicyRule.model_validate(data)
            self._resort()
            logger.debug("PolicyEngine: rule updated", extra={"rule_id": rule_id})
            return True
        return False

    def remove_rule(self, rule_id: str) -> bool:
        """Remove a rule by id. Returns True if removed."""
        for idx, rule in enumerate(self._rules):
            if rule.rule_id == rule_id:
                del self._rules[idx]
                logger.debug("PolicyEngine: rule removed", extra={"rule_id": rule_id})
                return True
        return False

    def check(self, action: str, context: dict[str, Any]) -> tuple[bool, str]:
        """
        Returns (allowed, reason). Context has e.g. tool_name, domain, data_class, agent_id.
        """
        for rule in self._rules:
            if not self._match(rule.condition, context):
                continue
            if rule.effect == PolicyEffect.DENY:
                return False, rule.reason or "Policy denied"
            return True, rule.reason or "Policy allowed"
        # No rule matched: allow by default.
        return True, ""

    def _match(self, condition: dict[str, Any], context: dict[str, Any]) -> bool:
        # Every key in the condition must be present and equal in the context.
        return all(context.get(k) == v for k, v in condition.items())
|
||||
38
fusionagi/governance/rate_limiter.py
Normal file
38
fusionagi/governance/rate_limiter.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Rate limiting: per agent or per tool; reject or queue if exceeded.
|
||||
|
||||
Optional; not wired to Executor or Orchestrator by default. Wire by calling
|
||||
allow(key) before tool invocation or message routing and checking the result.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class RateLimiter:
    """Simple in-memory rate limiter: max N calls per window_seconds per key."""

    def __init__(self, max_calls: int = 60, window_seconds: float = 60.0) -> None:
        self._max_calls = max_calls
        self._window = window_seconds
        # Per-key monotonic timestamps of calls within the current window.
        self._calls: dict[str, list[float]] = defaultdict(list)

    def allow(self, key: str) -> tuple[bool, str]:
        """Record a call for key; return (True, "") or (False, reason)."""
        now = time.monotonic()
        window_start = now - self._window
        # Drop timestamps that have aged out of the window.
        recent = [t for t in self._calls[key] if t > window_start]
        self._calls[key] = recent
        if len(recent) >= self._max_calls:
            reason = f"Rate limit exceeded for {key}"
            logger.info("Rate limiter rejected", extra={"key": key, "reason": reason})
            return False, reason
        recent.append(now)
        return True, ""

    def reset(self, key: str | None = None) -> None:
        """Reset counts for key or all."""
        if key is None:
            self._calls.clear()
            return
        self._calls.pop(key, None)
|
||||
132
fusionagi/governance/safety_pipeline.py
Normal file
132
fusionagi/governance/safety_pipeline.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Safety pipeline: pre-check (input moderation), post-check (output scan)."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.governance.guardrails import Guardrails, PreCheckResult
|
||||
from fusionagi.schemas.audit import AuditEventType
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class ModerationResult:
    """Result of input moderation."""

    # True when the input may proceed to processing.
    allowed: bool
    # Presumably a rewritten form of the input to use instead of the original;
    # NOTE(review): never set by InputModerator in this module — confirm intent.
    transformed: str | None = None
    # Human-readable denial reason; set when allowed is False.
    reason: str | None = None
|
||||
|
||||
|
||||
class InputModerator:
    """Pre-check: block or transform user input before processing."""

    def __init__(self) -> None:
        self._blocked_patterns: list[re.Pattern[str]] = []
        self._blocked_phrases: list[str] = []

    def add_blocked_pattern(self, pattern: str) -> None:
        """Add regex pattern to block (e.g. prompt injection attempts)."""
        self._blocked_patterns.append(re.compile(pattern, re.I))

    def add_blocked_phrase(self, phrase: str) -> None:
        """Add exact phrase to block."""
        self._blocked_phrases.append(phrase.lower())

    def moderate(self, text: str) -> ModerationResult:
        """Check input; return allowed/denied and optional transformed text."""
        if not (text and text.strip()):
            return ModerationResult(allowed=False, reason="Empty input")
        lowered = text.lower()
        # Phrase matching is case-insensitive via lowercasing both sides.
        hit = next((p for p in self._blocked_phrases if p in lowered), None)
        if hit is not None:
            logger.info("Input blocked: blocked phrase", extra={"phrase": hit[:50]})
            return ModerationResult(allowed=False, reason=f"Blocked phrase: {hit[:30]}...")
        match = next((p for p in self._blocked_patterns if p.search(text)), None)
        if match is not None:
            logger.info("Input blocked: pattern match", extra={"pattern": match.pattern[:50]})
            return ModerationResult(allowed=False, reason="Input matched blocked pattern")
        return ModerationResult(allowed=True)
|
||||
|
||||
|
||||
@dataclass
class OutputScanResult:
    """Result of output (final answer) scan."""

    # True when no flags were raised by the scan.
    passed: bool
    # One label per detection, e.g. "potential_pii:ssn" or "blocked_content_detected".
    flags: list[str]
    # Presumably a redacted version of the answer;
    # NOTE(review): never populated by OutputScanner in this module — confirm intent.
    sanitized: str | None = None
|
||||
|
||||
|
||||
class OutputScanner:
    """Post-check: scan final answer for policy violations, PII leakage."""

    def __init__(self) -> None:
        # Built-in detectors: US SSN format and 16-digit card-number shapes.
        self._pii_patterns: list[tuple[str, re.Pattern[str]]] = [
            ("ssn", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")),
            ("credit_card", re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")),
        ]
        self._blocked_patterns: list[re.Pattern[str]] = []

    def add_pii_pattern(self, name: str, pattern: str) -> None:
        """Add PII detection pattern."""
        self._pii_patterns.append((name, re.compile(pattern)))

    def add_blocked_pattern(self, pattern: str) -> None:
        """Add pattern that fails the output."""
        self._blocked_patterns.append(re.compile(pattern, re.I))

    def scan(self, text: str) -> OutputScanResult:
        """Scan output; return passed, flags, optional sanitized."""
        flags = [f"potential_pii:{name}" for name, pat in self._pii_patterns if pat.search(text)]
        flags.extend("blocked_content_detected" for pat in self._blocked_patterns if pat.search(text))
        if flags:
            return OutputScanResult(passed=False, flags=flags)
        return OutputScanResult(passed=True, flags=[])
|
||||
|
||||
|
||||
class SafetyPipeline:
    """Combined pre/post safety checks for Dvādaśa."""

    def __init__(
        self,
        moderator: InputModerator | None = None,
        scanner: OutputScanner | None = None,
        guardrails: Guardrails | None = None,
        audit_log: Any | None = None,
    ) -> None:
        # Default collaborators are created lazily here when not injected.
        self._moderator = moderator or InputModerator()
        self._scanner = scanner or OutputScanner()
        self._guardrails = guardrails or Guardrails()
        self._audit = audit_log

    def _record(self, action: str, payload: dict[str, Any], outcome: str) -> None:
        # Best-effort audit trail; skipped entirely when no audit log is wired in.
        if not self._audit:
            return
        self._audit.append(
            AuditEventType.POLICY_CHECK,
            actor="safety_pipeline",
            action=action,
            payload=payload,
            outcome=outcome,
        )

    def pre_check(self, user_input: str) -> ModerationResult:
        """Run input moderation."""
        result = self._moderator.moderate(user_input)
        if not result.allowed:
            self._record("input_moderation", {"reason": result.reason}, "denied")
        return result

    def post_check(self, final_answer: str) -> OutputScanResult:
        """Run output scan."""
        result = self._scanner.scan(final_answer)
        if not result.passed:
            self._record("output_scan", {"flags": result.flags}, "flagged")
        return result
|
||||
30
fusionagi/interfaces/__init__.py
Normal file
30
fusionagi/interfaces/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Multi-modal interface layer for FusionAGI.
|
||||
|
||||
Provides admin control panel, user interfaces, and sensory interaction adapters.
|
||||
"""
|
||||
|
||||
from fusionagi.interfaces.base import (
|
||||
InterfaceAdapter,
|
||||
InterfaceCapabilities,
|
||||
InterfaceMessage,
|
||||
ModalityType,
|
||||
)
|
||||
from fusionagi.interfaces.voice import VoiceInterface, VoiceLibrary, TTSAdapter, STTAdapter
|
||||
from fusionagi.interfaces.conversation import ConversationManager, ConversationTuner
|
||||
from fusionagi.interfaces.admin_panel import AdminControlPanel
|
||||
from fusionagi.interfaces.multimodal_ui import MultiModalUI
|
||||
|
||||
__all__ = [
|
||||
"InterfaceAdapter",
|
||||
"InterfaceCapabilities",
|
||||
"InterfaceMessage",
|
||||
"ModalityType",
|
||||
"VoiceInterface",
|
||||
"VoiceLibrary",
|
||||
"TTSAdapter",
|
||||
"STTAdapter",
|
||||
"ConversationManager",
|
||||
"ConversationTuner",
|
||||
"AdminControlPanel",
|
||||
"MultiModalUI",
|
||||
]
|
||||
425
fusionagi/interfaces/admin_panel.py
Normal file
425
fusionagi/interfaces/admin_panel.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""Admin control panel for FusionAGI system management.
|
||||
|
||||
Provides administrative interface for:
|
||||
- Voice library management
|
||||
- Conversation tuning
|
||||
- Agent configuration
|
||||
- System monitoring
|
||||
- Governance policies
|
||||
- Manufacturing authority
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now, utc_now_iso
|
||||
from fusionagi.interfaces.voice import VoiceLibrary, VoiceProfile
|
||||
from fusionagi.interfaces.conversation import ConversationTuner, ConversationStyle
|
||||
from fusionagi.core import Orchestrator, EventBus, StateManager
|
||||
from fusionagi.governance import PolicyEngine, AuditLog
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class SystemStatus(BaseModel):
    """System status information."""

    status: Literal["healthy", "degraded", "offline"] = Field(description="Overall system status")
    uptime_seconds: float = Field(description="System uptime in seconds")
    active_tasks: int = Field(description="Number of active tasks")
    active_agents: int = Field(description="Number of registered agents")
    active_sessions: int = Field(description="Number of active user sessions")
    # Resource metrics are optional; left None when the host cannot report them.
    memory_usage_mb: float | None = Field(default=None, description="Memory usage in MB")
    cpu_usage_percent: float | None = Field(default=None, description="CPU usage percentage")
    # ISO-8601 UTC capture time, filled automatically at model creation.
    timestamp: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
    """Configuration for an agent."""

    agent_id: str  # unique agent identifier; registry key in AdminControlPanel
    agent_type: str
    enabled: bool = Field(default=True)
    # NOTE(review): the limits below are stored but not enforced anywhere in this
    # module — enforcement presumably lives with the orchestrator; confirm.
    max_concurrent_tasks: int = Field(default=10)
    timeout_seconds: float = Field(default=300.0)
    retry_policy: dict[str, Any] = Field(default_factory=dict)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AdminControlPanel:
|
||||
"""
|
||||
Administrative control panel for FusionAGI.
|
||||
|
||||
Provides centralized management interface for:
|
||||
- Voice libraries and TTS/STT configuration
|
||||
- Conversation styles and natural language tuning
|
||||
- Agent configuration and monitoring
|
||||
- System health and performance metrics
|
||||
- Governance policies and audit logs
|
||||
- Manufacturing authority (MAA) settings
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    orchestrator: Orchestrator,
    event_bus: EventBus,
    state_manager: StateManager,
    voice_library: VoiceLibrary | None = None,
    conversation_tuner: ConversationTuner | None = None,
    policy_engine: PolicyEngine | None = None,
    audit_log: AuditLog | None = None,
    session_count_callback: Callable[[], int] | None = None,
) -> None:
    """Wire up the admin control panel.

    Args:
        orchestrator: FusionAGI orchestrator.
        event_bus: Event bus for system events; construct with
            EventBus(history_size=N) if event history is needed.
        state_manager: State manager for task state.
        voice_library: Voice library for TTS management (created if omitted).
        conversation_tuner: Conversation tuner for NL configuration (created if omitted).
        policy_engine: Policy engine for governance; may be None.
        audit_log: Audit log for compliance tracking; may be None.
        session_count_callback: Optional callable returning the number of
            active user sessions (e.g. from MultiModalUI).
    """
    # Required runtime collaborators.
    self.orchestrator = orchestrator
    self.event_bus = event_bus
    self.state_manager = state_manager
    # Optional collaborators; fresh defaults where a sensible one exists.
    self.voice_library = voice_library or VoiceLibrary()
    self.conversation_tuner = conversation_tuner or ConversationTuner()
    self.policy_engine = policy_engine
    self.audit_log = audit_log
    self._session_count_callback = session_count_callback

    # Local agent configuration registry and uptime anchor.
    self._agent_configs: dict[str, AgentConfig] = {}
    self._start_time = utc_now()

    logger.info("AdminControlPanel initialized")
|
||||
|
||||
# ========== Voice Management ==========
|
||||
|
||||
def add_voice_profile(self, profile: VoiceProfile) -> str:
    """Add a voice profile to the library and audit the action.

    Args:
        profile: Voice profile to add.

    Returns:
        The new voice ID.
    """
    new_id = self.voice_library.add_voice(profile)
    self._log_admin_action("voice_added", {"voice_id": new_id, "name": profile.name})
    return new_id
|
||||
|
||||
def list_voices(
    self,
    language: str | None = None,
    gender: str | None = None,
    style: str | None = None,
) -> list[VoiceProfile]:
    """Return voice profiles, optionally filtered by language, gender, and style."""
    return self.voice_library.list_voices(
        language=language,
        gender=gender,
        style=style,
    )
|
||||
|
||||
def update_voice_profile(self, voice_id: str, updates: dict[str, Any]) -> bool:
    """Apply field updates to a voice profile.

    Args:
        voice_id: Voice ID to update.
        updates: Mapping of field name to new value.

    Returns:
        True if updated, False if not found.
    """
    updated = self.voice_library.update_voice(voice_id, updates)
    if updated:
        changed_fields = list(updates.keys())
        self._log_admin_action("voice_updated", {"voice_id": voice_id, "fields": changed_fields})
    return updated
|
||||
|
||||
def remove_voice_profile(self, voice_id: str) -> bool:
    """Delete a voice profile; returns True when it existed and was removed."""
    removed = self.voice_library.remove_voice(voice_id)
    if removed:
        self._log_admin_action("voice_removed", {"voice_id": voice_id})
    return removed
|
||||
|
||||
def set_default_voice(self, voice_id: str) -> bool:
|
||||
"""Set the default voice."""
|
||||
success = self.voice_library.set_default_voice(voice_id)
|
||||
if success:
|
||||
self._log_admin_action("default_voice_set", {"voice_id": voice_id})
|
||||
return success
|
||||
|
||||
# ========== Conversation Tuning ==========
|
||||
|
||||
def register_conversation_style(self, name: str, style: ConversationStyle) -> None:
|
||||
"""
|
||||
Register a conversation style.
|
||||
|
||||
Args:
|
||||
name: Style name.
|
||||
style: Conversation style configuration.
|
||||
"""
|
||||
self.conversation_tuner.register_style(name, style)
|
||||
self._log_admin_action("conversation_style_registered", {"name": name})
|
||||
|
||||
def list_conversation_styles(self) -> list[str]:
|
||||
"""List all registered conversation style names."""
|
||||
return self.conversation_tuner.list_styles()
|
||||
|
||||
def get_conversation_style(self, name: str) -> ConversationStyle | None:
|
||||
"""Get a conversation style by name."""
|
||||
return self.conversation_tuner.get_style(name)
|
||||
|
||||
def set_default_conversation_style(self, style: ConversationStyle) -> None:
|
||||
"""Set the default conversation style."""
|
||||
self.conversation_tuner.set_default_style(style)
|
||||
self._log_admin_action("default_conversation_style_set", {})
|
||||
|
||||
# ========== Agent Management ==========
|
||||
|
||||
def configure_agent(self, config: AgentConfig) -> None:
|
||||
"""
|
||||
Configure an agent.
|
||||
|
||||
Args:
|
||||
config: Agent configuration.
|
||||
"""
|
||||
self._agent_configs[config.agent_id] = config
|
||||
self._log_admin_action("agent_configured", {"agent_id": config.agent_id})
|
||||
logger.info("Agent configured", extra={"agent_id": config.agent_id})
|
||||
|
||||
def get_agent_config(self, agent_id: str) -> AgentConfig | None:
|
||||
"""Get agent configuration."""
|
||||
return self._agent_configs.get(agent_id)
|
||||
|
||||
def list_agents(self) -> list[str]:
|
||||
"""List all registered agent IDs."""
|
||||
return list(self.orchestrator._agents.keys())
|
||||
|
||||
def enable_agent(self, agent_id: str) -> bool:
|
||||
"""Enable an agent."""
|
||||
config = self._agent_configs.get(agent_id)
|
||||
if config:
|
||||
config.enabled = True
|
||||
self._log_admin_action("agent_enabled", {"agent_id": agent_id})
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_agent(self, agent_id: str) -> bool:
|
||||
"""Disable an agent."""
|
||||
config = self._agent_configs.get(agent_id)
|
||||
if config:
|
||||
config.enabled = False
|
||||
self._log_admin_action("agent_disabled", {"agent_id": agent_id})
|
||||
return True
|
||||
return False
|
||||
|
||||
# ========== System Monitoring ==========
|
||||
|
||||
def get_system_status(self) -> SystemStatus:
|
||||
"""
|
||||
Get current system status.
|
||||
|
||||
Returns:
|
||||
System status information.
|
||||
"""
|
||||
uptime = (utc_now() - self._start_time).total_seconds()
|
||||
|
||||
# Count active tasks
|
||||
active_tasks = 0
|
||||
failed_count = 0
|
||||
for task_id in self.state_manager._tasks.keys():
|
||||
task = self.state_manager.get_task(task_id)
|
||||
if task:
|
||||
if task.state.value in ("pending", "active"):
|
||||
active_tasks += 1
|
||||
elif task.state.value == "failed":
|
||||
failed_count += 1
|
||||
|
||||
active_agents = len(self.orchestrator._agents)
|
||||
active_sessions = self._session_count_callback() if self._session_count_callback else 0
|
||||
|
||||
# Health: healthy under normal load; degraded if high task count or many failures
|
||||
if active_tasks > 1000 or (failed_count > 50 and active_tasks > 100):
|
||||
status: Literal["healthy", "degraded", "offline"] = "degraded"
|
||||
else:
|
||||
status = "healthy"
|
||||
|
||||
return SystemStatus(
|
||||
status=status,
|
||||
uptime_seconds=uptime,
|
||||
active_tasks=active_tasks,
|
||||
active_agents=active_agents,
|
||||
active_sessions=active_sessions,
|
||||
)
|
||||
|
||||
def get_task_statistics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get task execution statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with task statistics.
|
||||
"""
|
||||
stats = {
|
||||
"total_tasks": len(self.state_manager._tasks),
|
||||
"by_state": {},
|
||||
"by_priority": {},
|
||||
}
|
||||
|
||||
for task_id in self.state_manager._tasks.keys():
|
||||
task = self.state_manager.get_task(task_id)
|
||||
if task:
|
||||
# Count by state
|
||||
state_key = task.state.value
|
||||
stats["by_state"][state_key] = stats["by_state"].get(state_key, 0) + 1
|
||||
|
||||
# Count by priority
|
||||
priority_key = task.priority.value
|
||||
stats["by_priority"][priority_key] = stats["by_priority"].get(priority_key, 0) + 1
|
||||
|
||||
return stats
|
||||
|
||||
def get_recent_events(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get recent system events from the event bus.
|
||||
|
||||
Requires EventBus(history_size=N) at construction for non-empty results.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of events to return.
|
||||
|
||||
Returns:
|
||||
List of recent events (event_type, payload, timestamp).
|
||||
"""
|
||||
if hasattr(self.event_bus, "get_recent_events"):
|
||||
return self.event_bus.get_recent_events(limit=limit)
|
||||
return []
|
||||
|
||||
# ========== Governance & Audit ==========
|
||||
|
||||
def get_audit_entries(
|
||||
self,
|
||||
limit: int = 100,
|
||||
action_type: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get audit log entries.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entries to return.
|
||||
action_type: Optional filter by action type.
|
||||
|
||||
Returns:
|
||||
List of audit entries.
|
||||
"""
|
||||
if not self.audit_log:
|
||||
return []
|
||||
|
||||
entries = self.audit_log.query(limit=limit)
|
||||
|
||||
if action_type:
|
||||
entries = [e for e in entries if e.get("action") == action_type]
|
||||
|
||||
return entries
|
||||
|
||||
def update_policy(self, policy_id: str, policy_data: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update a governance policy.
|
||||
|
||||
Args:
|
||||
policy_id: Policy identifier.
|
||||
policy_data: Policy configuration.
|
||||
|
||||
Returns:
|
||||
True if updated, False if policy engine not available.
|
||||
"""
|
||||
if not self.policy_engine:
|
||||
return False
|
||||
|
||||
rule_id = policy_data.get("rule_id", policy_id)
|
||||
if self.policy_engine.get_rule(rule_id) is None:
|
||||
return False
|
||||
updates = {k: v for k, v in policy_data.items() if k in ("condition", "effect", "reason", "priority")}
|
||||
ok = self.policy_engine.update_rule(rule_id, updates)
|
||||
if ok:
|
||||
self._log_admin_action("policy_updated", {"policy_id": policy_id, "rule_id": rule_id})
|
||||
return ok
|
||||
|
||||
# ========== Utility Methods ==========
|
||||
|
||||
def _log_admin_action(self, action: str, details: dict[str, Any]) -> None:
|
||||
"""
|
||||
Log an administrative action.
|
||||
|
||||
Args:
|
||||
action: Action type.
|
||||
details: Action details.
|
||||
"""
|
||||
logger.info(f"Admin action: {action}", extra=details)
|
||||
|
||||
if self.audit_log:
|
||||
self.audit_log.log(
|
||||
action=action,
|
||||
actor="admin",
|
||||
details=details,
|
||||
timestamp=utc_now_iso(),
|
||||
)
|
||||
|
||||
def export_configuration(self) -> dict[str, Any]:
|
||||
"""
|
||||
Export system configuration.
|
||||
|
||||
Returns:
|
||||
Dictionary with full system configuration.
|
||||
"""
|
||||
return {
|
||||
"voices": [v.model_dump() for v in self.voice_library.list_voices()],
|
||||
"conversation_styles": {
|
||||
name: self.conversation_tuner.get_style(name).model_dump()
|
||||
for name in self.conversation_tuner.list_styles()
|
||||
},
|
||||
"agent_configs": {
|
||||
agent_id: config.model_dump()
|
||||
for agent_id, config in self._agent_configs.items()
|
||||
},
|
||||
"exported_at": utc_now_iso(),
|
||||
}
|
||||
|
||||
def import_configuration(self, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Import system configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary to import.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Import voices
|
||||
if "voices" in config:
|
||||
for voice_data in config["voices"]:
|
||||
profile = VoiceProfile(**voice_data)
|
||||
self.voice_library.add_voice(profile)
|
||||
|
||||
# Import conversation styles
|
||||
if "conversation_styles" in config:
|
||||
for name, style_data in config["conversation_styles"].items():
|
||||
style = ConversationStyle(**style_data)
|
||||
self.conversation_tuner.register_style(name, style)
|
||||
|
||||
# Import agent configs
|
||||
if "agent_configs" in config:
|
||||
for agent_id, config_data in config["agent_configs"].items():
|
||||
agent_config = AgentConfig(**config_data)
|
||||
self._agent_configs[agent_id] = agent_config
|
||||
|
||||
self._log_admin_action("configuration_imported", {"source": "file"})
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Configuration import failed", extra={"error": str(e)})
|
||||
return False
|
||||
121
fusionagi/interfaces/base.py
Normal file
121
fusionagi/interfaces/base.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Base interface adapter for multi-modal interaction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now_iso
|
||||
|
||||
|
||||
class ModalityType(str, Enum):
    """Types of sensory modalities supported.

    Mixes in ``str`` so members compare equal to their string values and
    serialize cleanly in pydantic models and JSON payloads.
    """

    TEXT = "text"  # chat messages, typed commands
    VOICE = "voice"  # speech input/output
    VISUAL = "visual"  # images, video, AR/VR
    HAPTIC = "haptic"  # touch feedback
    GESTURE = "gesture"  # motion control
    BIOMETRIC = "biometric"  # physiological signals / emotion detection
|
||||
|
||||
|
||||
class InterfaceMessage(BaseModel):
    """Message exchanged through an interface.

    A modality-neutral envelope: ``content`` is interpreted according to
    ``modality`` by the sending/receiving InterfaceAdapter.
    """

    id: str = Field(description="Unique message identifier")
    modality: ModalityType = Field(description="Sensory modality of this message")
    content: Any = Field(description="Message content (modality-specific)")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
    # Defaults to creation time as an ISO-8601 UTC string (fusionagi._time.utc_now_iso).
    timestamp: str = Field(
        default_factory=utc_now_iso,
        description="Message timestamp"
    )
    user_id: str | None = Field(default=None, description="User identifier if applicable")
    session_id: str | None = Field(default=None, description="Session identifier")
|
||||
|
||||
|
||||
class InterfaceCapabilities(BaseModel):
    """Capabilities of an interface adapter.

    Returned by ``InterfaceAdapter.capabilities()`` and consulted by
    ``InterfaceAdapter.validate_message`` to check modality support.
    """

    supported_modalities: list[ModalityType] = Field(description="Supported sensory modalities")
    supports_streaming: bool = Field(default=False, description="Supports streaming responses")
    supports_interruption: bool = Field(default=False, description="Supports mid-response interruption")
    supports_multimodal: bool = Field(default=False, description="Supports multiple modalities simultaneously")
    latency_ms: float | None = Field(default=None, description="Expected latency in milliseconds")
    max_concurrent_sessions: int | None = Field(default=None, description="Max concurrent sessions")
|
||||
|
||||
|
||||
class InterfaceAdapter(ABC):
    """
    Abstract base for interface adapters.

    An adapter translates between one or more human sensory modalities
    (voice, visual, haptic, ...) and FusionAGI's internal
    InterfaceMessage format.
    """

    def __init__(self, name: str) -> None:
        # Human-readable adapter name.
        self.name = name

    @abstractmethod
    def capabilities(self) -> InterfaceCapabilities:
        """Describe what this interface supports."""
        ...

    @abstractmethod
    async def send(self, message: InterfaceMessage) -> None:
        """
        Deliver *message* to the user through this interface.

        Args:
            message: Message to send (modality-specific content).
        """
        ...

    @abstractmethod
    async def receive(self, timeout_seconds: float | None = None) -> InterfaceMessage | None:
        """
        Wait for the next user message on this interface.

        Args:
            timeout_seconds: Optional timeout for receiving.

        Returns:
            The received message, or None when the timeout elapses.
        """
        ...

    async def stream_send(self, messages: AsyncIterator[InterfaceMessage]) -> None:
        """
        Send a stream of messages to the user.

        The default implementation forwards each message with ``send``;
        adapters with true streaming support should override this.

        Args:
            messages: Async iterator of messages to stream.
        """
        async for item in messages:
            await self.send(item)

    async def initialize(self) -> None:
        """Set up the interface (connect, authenticate, etc.). No-op by default."""

    async def shutdown(self) -> None:
        """Tear the interface down gracefully. No-op by default."""

    def validate_message(self, message: InterfaceMessage) -> bool:
        """
        Check whether *message* is compatible with this interface.

        Args:
            message: Message to validate.

        Returns:
            True when the message's modality is among the supported ones.
        """
        supported = self.capabilities().supported_modalities
        return message.modality in supported
|
||||
392
fusionagi/interfaces/conversation.py
Normal file
392
fusionagi/interfaces/conversation.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Conversation management and natural language tuning."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now_iso
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class ConversationStyle(BaseModel):
    """Configuration for conversation style and personality.

    Numeric knobs are 0..1 floats validated by pydantic; textual knobs are
    closed Literal sets. Used per session (ConversationContext.style) and
    as admin-registered presets in ConversationTuner.
    """

    formality: Literal["casual", "neutral", "formal"] = Field(
        default="neutral",
        description="Conversation formality level"
    )
    verbosity: Literal["concise", "balanced", "detailed"] = Field(
        default="balanced",
        description="Response length preference"
    )
    personality_traits: list[str] = Field(
        default_factory=list,
        description="Personality traits (e.g., friendly, professional, humorous)"
    )
    empathy_level: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Emotional responsiveness (0=robotic, 1=highly empathetic)"
    )
    proactivity: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Tendency to offer suggestions (0=reactive, 1=proactive)"
    )
    humor_level: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Use of humor (0=serious, 1=playful)"
    )
    technical_depth: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Technical detail level (0=simple, 1=expert)"
    )
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
    """Context for a conversation session.

    Created by ConversationManager.create_session; ``history_length``
    bounds how many turns the manager retains for the session.
    """

    # Auto-generated unique id, e.g. "session_<32 hex chars>".
    session_id: str = Field(default_factory=lambda: f"session_{uuid.uuid4().hex}")
    user_id: str | None = Field(default=None)
    style: ConversationStyle = Field(default_factory=ConversationStyle)
    language: str = Field(default="en", description="Primary language code")
    domain: str | None = Field(default=None, description="Domain/topic of conversation")
    history_length: int = Field(default=10, description="Number of turns to maintain in context")
    # ISO-8601 UTC creation timestamp.
    started_at: str = Field(default_factory=utc_now_iso)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ConversationTurn(BaseModel):
    """A single turn in a conversation.

    Appended via ConversationManager.add_turn; ``session_id`` must match an
    existing session or the turn is dropped.
    """

    turn_id: str = Field(default_factory=lambda: f"turn_{uuid.uuid4().hex[:8]}")
    session_id: str
    speaker: Literal["user", "agent", "system"]
    content: str
    intent: str | None = Field(default=None, description="Detected intent")
    sentiment: float | None = Field(
        default=None,
        ge=-1.0,
        le=1.0,
        description="Sentiment score (-1=negative, 0=neutral, 1=positive)"
    )
    confidence: float | None = Field(default=None, ge=0.0, le=1.0)
    # ISO-8601 UTC timestamp of when the turn object was created.
    timestamp: str = Field(default_factory=utc_now_iso)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ConversationTuner:
    """
    Conversation tuner for natural language interaction.

    Allows an admin to configure conversation style, personality, and
    behavior for different contexts, users, or agents.
    """

    def __init__(self) -> None:
        # Admin-registered named styles plus a package default.
        self._styles: dict[str, ConversationStyle] = {}
        self._default_style = ConversationStyle()
        logger.info("ConversationTuner initialized")

    def register_style(self, name: str, style: ConversationStyle) -> None:
        """
        Register a named conversation style.

        Args:
            name: Style name (e.g., "customer_support", "technical_expert").
            style: Conversation style configuration.
        """
        self._styles[name] = style
        logger.info("Conversation style registered", extra={"name": name})

    def get_style(self, name: str) -> ConversationStyle | None:
        """Get a conversation style by name, or None if not registered."""
        return self._styles.get(name)

    def list_styles(self) -> list[str]:
        """List all registered style names."""
        return list(self._styles.keys())

    def set_default_style(self, style: ConversationStyle) -> None:
        """Set the default conversation style."""
        self._default_style = style
        logger.info("Default conversation style updated")

    def get_default_style(self) -> ConversationStyle:
        """Get the default conversation style."""
        return self._default_style

    def tune_for_context(
        self,
        base_style: ConversationStyle | None = None,
        domain: str | None = None,
        user_preferences: dict[str, Any] | None = None,
    ) -> ConversationStyle:
        """
        Tune conversation style for a specific context.

        Args:
            base_style: Base style to start from (uses default if None).
            domain: Domain/topic to optimize for.
            user_preferences: User-specific preferences to apply; only keys
                that are actual style fields are applied.

        Returns:
            A NEW tuned conversation style. Neither *base_style* nor the
            default style is modified.
        """
        # BUGFIX: previously only the default style was deep-copied
        # (`base_style or default.model_copy(...)`), so a caller-supplied
        # base_style — often a shared, registered style — was mutated in
        # place by the tuning below. Copy whichever source is used.
        style = (base_style or self._default_style).model_copy(deep=True)

        # Domain presets first, then user preferences override them.
        if domain:
            style = self._apply_domain_tuning(style, domain)

        if user_preferences:
            for key, value in user_preferences.items():
                if hasattr(style, key):
                    setattr(style, key, value)

        logger.info(
            "Conversation style tuned",
            extra={"domain": domain, "has_user_prefs": bool(user_preferences)}
        )
        return style

    def _apply_domain_tuning(self, style: ConversationStyle, domain: str) -> ConversationStyle:
        """
        Apply domain-specific preset values to *style* (mutated in place).

        Args:
            style: Conversation style to adjust.
            domain: Domain to tune for (matched case-insensitively);
                unknown domains leave the style unchanged.

        Returns:
            The same style object, for chaining.
        """
        domain_presets = {
            "technical": {
                "formality": "formal",
                "technical_depth": 0.9,
                "verbosity": "detailed",
                "humor_level": 0.1,
            },
            "customer_support": {
                "formality": "neutral",
                "empathy_level": 0.9,
                "proactivity": 0.8,
                "verbosity": "balanced",
            },
            "casual_chat": {
                "formality": "casual",
                "humor_level": 0.7,
                "empathy_level": 0.8,
                "technical_depth": 0.3,
            },
            "education": {
                "formality": "neutral",
                "verbosity": "detailed",
                "technical_depth": 0.6,
                "proactivity": 0.7,
            },
        }

        preset = domain_presets.get(domain.lower())
        if preset:
            for key, value in preset.items():
                setattr(style, key, value)

        return style
|
||||
|
||||
|
||||
class ConversationManager:
    """
    Conversation manager for maintaining conversation state and history.

    Manages conversation sessions, tracks turns, and provides context for
    natural language understanding and generation.
    """

    def __init__(self, tuner: ConversationTuner | None = None) -> None:
        """
        Initialize conversation manager.

        Args:
            tuner: Conversation tuner for style management (a fresh one is
                created when omitted).
        """
        self.tuner = tuner or ConversationTuner()
        self._sessions: dict[str, ConversationContext] = {}
        # Turn history per session; deliberately kept after end_session.
        self._history: dict[str, list[ConversationTurn]] = {}
        logger.info("ConversationManager initialized")

    def create_session(
        self,
        user_id: str | None = None,
        style_name: str | None = None,
        language: str = "en",
        domain: str | None = None,
    ) -> str:
        """
        Create a new conversation session.

        Args:
            user_id: Optional user identifier.
            style_name: Optional style name; falls back to the default
                style when None or unknown.
            language: Primary language code.
            domain: Domain/topic of conversation.

        Returns:
            Session ID.
        """
        # BUGFIX: an unknown style_name used to return None from
        # get_style and crash ConversationContext validation.
        named = self.tuner.get_style(style_name) if style_name else None
        style = named if named is not None else self.tuner.get_default_style()
        # Copy so per-session adjustments never mutate a shared/registered
        # style (or the tuner's default) in place.
        style = style.model_copy(deep=True)

        context = ConversationContext(
            user_id=user_id,
            style=style,
            language=language,
            domain=domain,
        )

        self._sessions[context.session_id] = context
        self._history[context.session_id] = []

        logger.info(
            "Conversation session created",
            extra={
                "session_id": context.session_id,
                "user_id": user_id,
                "domain": domain,
            }
        )
        return context.session_id

    def get_session(self, session_id: str) -> ConversationContext | None:
        """Get conversation context for a session."""
        return self._sessions.get(session_id)

    def add_turn(self, turn: ConversationTurn) -> None:
        """
        Add a turn to conversation history.

        Turns for unknown sessions are logged and dropped. History is
        trimmed to the session's configured ``history_length``.

        Args:
            turn: Conversation turn to add.
        """
        if turn.session_id not in self._history:
            logger.warning("Session not found", extra={"session_id": turn.session_id})
            return

        history = self._history[turn.session_id]
        history.append(turn)

        # Trim history to configured length.
        context = self._sessions.get(turn.session_id)
        if context and len(history) > context.history_length:
            self._history[turn.session_id] = history[-context.history_length:]

        logger.debug(
            "Turn added",
            extra={
                "session_id": turn.session_id,
                "speaker": turn.speaker,
                "content_length": len(turn.content),
            }
        )

    def get_history(self, session_id: str, limit: int | None = None) -> list[ConversationTurn]:
        """
        Get conversation history for a session.

        Args:
            session_id: Session identifier.
            limit: Optional cap on the number of most-recent turns;
                0 (or a negative value) returns an empty list.

        Returns:
            List of conversation turns (most recent last).
        """
        history = self._history.get(session_id, [])
        if limit is None:
            return history
        # BUGFIX: `history[-0:]` is the whole list, so `limit=0` used to
        # return the full history instead of no turns.
        return history[-limit:] if limit > 0 else []

    def get_style_for_session(self, session_id: str) -> ConversationStyle | None:
        """
        Get the conversation style for a session.

        Args:
            session_id: Session identifier.

        Returns:
            Conversation style for the session, or None if session not found.
        """
        context = self._sessions.get(session_id)
        return context.style if context else None

    def update_style(self, session_id: str, style: ConversationStyle) -> bool:
        """
        Update conversation style for a session.

        Args:
            session_id: Session identifier.
            style: New conversation style.

        Returns:
            True if updated, False if session not found.
        """
        context = self._sessions.get(session_id)
        if context:
            context.style = style
            logger.info("Session style updated", extra={"session_id": session_id})
            return True
        return False

    def end_session(self, session_id: str) -> bool:
        """
        End a conversation session.

        Args:
            session_id: Session identifier.

        Returns:
            True if ended, False if not found.
        """
        if session_id in self._sessions:
            del self._sessions[session_id]
            # Keep history for analytics; could be cleaned up later.
            logger.info("Session ended", extra={"session_id": session_id})
            return True
        return False

    def get_context_summary(self, session_id: str) -> dict[str, Any]:
        """
        Get a summary of conversation context for LLM prompting.

        Args:
            session_id: Session identifier.

        Returns:
            Dictionary with context summary (empty for unknown sessions).
        """
        context = self._sessions.get(session_id)
        history = self._history.get(session_id, [])

        if not context:
            return {}

        return {
            "session_id": session_id,
            "user_id": context.user_id,
            "language": context.language,
            "domain": context.domain,
            "style": context.style.model_dump(),
            "turn_count": len(history),
            "recent_turns": [
                {"speaker": t.speaker, "content": t.content, "intent": t.intent}
                for t in history[-5:]  # Last 5 turns
            ],
        }
|
||||
506
fusionagi/interfaces/multimodal_ui.py
Normal file
506
fusionagi/interfaces/multimodal_ui.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""Multi-modal user interface for full sensory experience with FusionAGI.
|
||||
|
||||
Supports:
|
||||
- Text (chat, commands)
|
||||
- Voice (speech input/output)
|
||||
- Visual (images, video, AR/VR)
|
||||
- Haptic (touch feedback)
|
||||
- Gesture (motion control)
|
||||
- Biometric (emotion detection, physiological signals)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Callable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now_iso
|
||||
from fusionagi.interfaces.base import (
|
||||
InterfaceAdapter,
|
||||
InterfaceMessage,
|
||||
ModalityType,
|
||||
)
|
||||
from fusionagi.interfaces.voice import VoiceInterface, VoiceLibrary
|
||||
from fusionagi.interfaces.conversation import ConversationManager, ConversationTurn
|
||||
from fusionagi.core import Orchestrator
|
||||
from fusionagi.schemas import Task, TaskState
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class UserSession(BaseModel):
    """User session with multi-modal interface.

    Created by MultiModalUI.create_session; holds the active modalities and
    a pointer to the backing ConversationManager session.
    """

    session_id: str = Field(default_factory=lambda: f"user_session_{uuid.uuid4().hex}")
    user_id: str | None = Field(default=None)
    # Id of the paired ConversationManager session, ended together with this one.
    conversation_session_id: str | None = Field(default=None)
    active_modalities: list[ModalityType] = Field(default_factory=list)
    preferences: dict[str, Any] = Field(default_factory=dict)
    accessibility_settings: dict[str, Any] = Field(default_factory=dict)
    # ISO-8601 UTC timestamps.
    started_at: str = Field(default_factory=utc_now_iso)
    last_activity_at: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
class MultiModalUI:
|
||||
"""
|
||||
Multi-modal user interface for FusionAGI.
|
||||
|
||||
Provides a unified interface that supports multiple sensory modalities
|
||||
simultaneously, allowing users to interact through their preferred
|
||||
combination of text, voice, visual, haptic, gesture, and biometric inputs.
|
||||
|
||||
Features:
|
||||
- Seamless switching between modalities
|
||||
- Simultaneous multi-modal input/output
|
||||
- Accessibility support
|
||||
- Context-aware modality selection
|
||||
- Real-time feedback across all active modalities
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orchestrator: Orchestrator,
|
||||
conversation_manager: ConversationManager,
|
||||
voice_interface: VoiceInterface | None = None,
|
||||
llm_process_callback: Callable[[str, str, dict[str, Any], Any], str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize multi-modal UI.
|
||||
|
||||
Args:
|
||||
orchestrator: FusionAGI orchestrator for task execution.
|
||||
conversation_manager: Conversation manager for natural language.
|
||||
voice_interface: Voice interface for speech interaction.
|
||||
llm_process_callback: Optional (session_id, user_input, context, style) -> response for converse().
|
||||
"""
|
||||
self.orchestrator = orchestrator
|
||||
self.conversation_manager = conversation_manager
|
||||
self.voice_interface = voice_interface
|
||||
self._llm_process_callback = llm_process_callback
|
||||
|
||||
self._sessions: dict[str, UserSession] = {}
|
||||
self._interface_adapters: dict[ModalityType, InterfaceAdapter] = {}
|
||||
self._receive_lock = asyncio.Lock()
|
||||
|
||||
# Register voice interface if provided
|
||||
if voice_interface:
|
||||
self._interface_adapters[ModalityType.VOICE] = voice_interface
|
||||
|
||||
logger.info("MultiModalUI initialized")
|
||||
|
||||
# ========== Session Management ==========
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
preferred_modalities: list[ModalityType] | None = None,
|
||||
accessibility_settings: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new user session.
|
||||
|
||||
Args:
|
||||
user_id: Optional user identifier.
|
||||
preferred_modalities: Preferred interaction modalities.
|
||||
accessibility_settings: Accessibility preferences.
|
||||
|
||||
Returns:
|
||||
Session ID.
|
||||
"""
|
||||
# Create conversation session
|
||||
conv_session_id = self.conversation_manager.create_session(user_id=user_id)
|
||||
|
||||
session = UserSession(
|
||||
user_id=user_id,
|
||||
conversation_session_id=conv_session_id,
|
||||
active_modalities=preferred_modalities or [ModalityType.TEXT],
|
||||
accessibility_settings=accessibility_settings or {},
|
||||
)
|
||||
|
||||
self._sessions[session.session_id] = session
|
||||
|
||||
logger.info(
|
||||
"User session created",
|
||||
extra={
|
||||
"session_id": session.session_id,
|
||||
"user_id": user_id,
|
||||
"modalities": [m.value for m in session.active_modalities],
|
||||
}
|
||||
)
|
||||
|
||||
return session.session_id
|
||||
|
||||
def get_session(self, session_id: str) -> UserSession | None:
|
||||
"""Get user session."""
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def active_session_count(self) -> int:
|
||||
"""Return number of active user sessions (for admin panel session_count_callback)."""
|
||||
return len(self._sessions)
|
||||
|
||||
def end_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
End a user session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
|
||||
Returns:
|
||||
True if ended, False if not found.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
# End conversation session
|
||||
if session.conversation_session_id:
|
||||
self.conversation_manager.end_session(session.conversation_session_id)
|
||||
|
||||
del self._sessions[session_id]
|
||||
logger.info("User session ended", extra={"session_id": session_id})
|
||||
return True
|
||||
|
||||
# ========== Modality Management ==========
|
||||
|
||||
def register_interface(self, modality: ModalityType, adapter: InterfaceAdapter) -> None:
|
||||
"""
|
||||
Register an interface adapter for a modality.
|
||||
|
||||
Args:
|
||||
modality: Modality type.
|
||||
adapter: Interface adapter implementation.
|
||||
"""
|
||||
self._interface_adapters[modality] = adapter
|
||||
logger.info("Interface adapter registered", extra={"modality": modality.value})
|
||||
|
||||
def enable_modality(self, session_id: str, modality: ModalityType) -> bool:
|
||||
"""
|
||||
Enable a modality for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
modality: Modality to enable.
|
||||
|
||||
Returns:
|
||||
True if enabled, False if session not found or modality unavailable.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
if modality not in self._interface_adapters:
|
||||
logger.warning(
|
||||
"Modality not available",
|
||||
extra={"modality": modality.value}
|
||||
)
|
||||
return False
|
||||
|
||||
if modality not in session.active_modalities:
|
||||
session.active_modalities.append(modality)
|
||||
logger.info(
|
||||
"Modality enabled",
|
||||
extra={"session_id": session_id, "modality": modality.value}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def disable_modality(self, session_id: str, modality: ModalityType) -> bool:
|
||||
"""
|
||||
Disable a modality for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
modality: Modality to disable.
|
||||
|
||||
Returns:
|
||||
True if disabled, False if session not found.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
if modality in session.active_modalities:
|
||||
session.active_modalities.remove(modality)
|
||||
logger.info(
|
||||
"Modality disabled",
|
||||
extra={"session_id": session_id, "modality": modality.value}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
# ========== User Interaction ==========
|
||||
|
||||
async def send_to_user(
|
||||
self,
|
||||
session_id: str,
|
||||
content: Any,
|
||||
modalities: list[ModalityType] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send content to user through active modalities.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
content: Content to send (will be adapted per modality).
|
||||
modalities: Specific modalities to use (uses active if None).
|
||||
metadata: Additional metadata for the message.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
logger.warning("Session not found", extra={"session_id": session_id})
|
||||
return
|
||||
|
||||
# Determine which modalities to use
|
||||
target_modalities = modalities or session.active_modalities
|
||||
|
||||
# Send through each active modality
|
||||
for modality in target_modalities:
|
||||
adapter = self._interface_adapters.get(modality)
|
||||
if not adapter:
|
||||
continue
|
||||
|
||||
# Create modality-specific message
|
||||
message = InterfaceMessage(
|
||||
id=f"msg_{uuid.uuid4().hex[:8]}",
|
||||
modality=modality,
|
||||
content=self._adapt_content(content, modality),
|
||||
metadata=metadata or {},
|
||||
session_id=session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await adapter.send(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to send through modality",
|
||||
extra={"modality": modality.value, "error": str(e)}
|
||||
)
|
||||
|
||||
async def receive_from_user(
|
||||
self,
|
||||
session_id: str,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> InterfaceMessage | None:
|
||||
"""
|
||||
Receive input from user through any active modality.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
timeout_seconds: Optional timeout for receiving.
|
||||
|
||||
Returns:
|
||||
Received message or None if timeout.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
# Listen on all active modalities (first to respond wins)
|
||||
# TODO: Implement proper async race condition handling
|
||||
for modality in session.active_modalities:
|
||||
adapter = self._interface_adapters.get(modality)
|
||||
if adapter:
|
||||
try:
|
||||
message = await adapter.receive(timeout_seconds)
|
||||
if message:
|
||||
# Update session activity
|
||||
session.last_activity_at = utc_now_iso()
|
||||
return message
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to receive from modality",
|
||||
extra={"modality": modality.value, "error": str(e)}
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# ========== Task Interaction ==========
|
||||
|
||||
async def submit_task_interactive(
|
||||
self,
|
||||
session_id: str,
|
||||
goal: str,
|
||||
constraints: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Submit a task and provide interactive feedback.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
goal: Task goal description.
|
||||
constraints: Optional task constraints.
|
||||
|
||||
Returns:
|
||||
Task ID.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
raise ValueError(f"Session not found: {session_id}")
|
||||
|
||||
# Submit task
|
||||
task_id = self.orchestrator.submit_task(
|
||||
goal=goal,
|
||||
constraints=constraints or {},
|
||||
)
|
||||
|
||||
# Send confirmation to user
|
||||
await self.send_to_user(
|
||||
session_id,
|
||||
f"Task submitted: {goal}",
|
||||
metadata={"task_id": task_id, "type": "task_confirmation"},
|
||||
)
|
||||
|
||||
# Subscribe to task events for real-time updates
|
||||
self._subscribe_to_task_updates(session_id, task_id)
|
||||
|
||||
logger.info(
|
||||
"Interactive task submitted",
|
||||
extra={"session_id": session_id, "task_id": task_id}
|
||||
)
|
||||
|
||||
return task_id
|
||||
|
||||
    def _subscribe_to_task_updates(self, session_id: str, task_id: str) -> None:
        """
        Subscribe to task updates and relay to user.

        Registers a synchronous closure on the orchestrator's event bus that
        filters events to *task_id* and forwards a formatted status line to
        the user's session.

        Args:
            session_id: Session identifier.
            task_id: Task identifier.
        """
        def on_task_update(event_type: str, data: dict[str, Any]) -> None:
            """Handle task update event."""
            # Events are bus-wide; ignore those for other tasks.
            if data.get("task_id") != task_id:
                return

            # Format update message
            if event_type == "task_state_changed":
                state = data.get("new_state")
                message = f"Task {task_id[:8]}: {state}"
            else:
                message = f"Task update: {event_type}"

            # Send to user (async in background)
            # NOTE(review): asyncio.create_task raises RuntimeError when no
            # event loop is running in this thread — that case is swallowed by
            # the except below, silently dropping the update. The task handle
            # is also not retained, so it may be garbage-collected before it
            # finishes; confirm both are acceptable.
            import asyncio
            try:
                asyncio.create_task(
                    self.send_to_user(
                        session_id,
                        message,
                        metadata={"task_id": task_id, "event_type": event_type},
                    )
                )
            except Exception as e:
                logger.error("Failed to send task update", extra={"error": str(e)})

        # Subscribe to events
        # NOTE(review): reaches into the orchestrator's private `_event_bus`
        # attribute; a public subscribe API would be preferable. There is also
        # no corresponding unsubscribe, so the closure lives for the bus's
        # lifetime even after the session ends.
        self.orchestrator._event_bus.subscribe("task_state_changed", on_task_update)
        self.orchestrator._event_bus.subscribe("task_step_completed", on_task_update)
|
||||
|
||||
# ========== Conversation Integration ==========
|
||||
|
||||
async def converse(
|
||||
self,
|
||||
session_id: str,
|
||||
user_input: str,
|
||||
) -> str:
|
||||
"""
|
||||
Handle conversational interaction.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
user_input: User's conversational input.
|
||||
|
||||
Returns:
|
||||
Agent's response.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session or not session.conversation_session_id:
|
||||
return "Session not found"
|
||||
|
||||
# Add user turn
|
||||
user_turn = ConversationTurn(
|
||||
session_id=session.conversation_session_id,
|
||||
speaker="user",
|
||||
content=user_input,
|
||||
)
|
||||
self.conversation_manager.add_turn(user_turn)
|
||||
|
||||
context = self.conversation_manager.get_context_summary(session.conversation_session_id)
|
||||
style = self.conversation_manager.get_style_for_session(session.conversation_session_id)
|
||||
if self._llm_process_callback is not None:
|
||||
response = self._llm_process_callback(session_id, user_input, context, style)
|
||||
else:
|
||||
response = f"I understand you said: {user_input}"
|
||||
|
||||
# Add agent turn
|
||||
agent_turn = ConversationTurn(
|
||||
session_id=session.conversation_session_id,
|
||||
speaker="agent",
|
||||
content=response,
|
||||
)
|
||||
self.conversation_manager.add_turn(agent_turn)
|
||||
|
||||
return response
|
||||
|
||||
# ========== Utility Methods ==========
|
||||
|
||||
def _adapt_content(self, content: Any, modality: ModalityType) -> Any:
|
||||
"""
|
||||
Adapt content for a specific modality.
|
||||
|
||||
Args:
|
||||
content: Original content.
|
||||
modality: Target modality.
|
||||
|
||||
Returns:
|
||||
Adapted content.
|
||||
"""
|
||||
# Convert content to appropriate format for modality
|
||||
if modality == ModalityType.TEXT:
|
||||
return str(content)
|
||||
elif modality == ModalityType.VOICE:
|
||||
# For voice, ensure it's text that can be synthesized
|
||||
return str(content)
|
||||
elif modality == ModalityType.VISUAL:
|
||||
# For visual, might need to generate images or format for display
|
||||
return {"type": "text", "content": str(content)}
|
||||
elif modality == ModalityType.HAPTIC:
|
||||
# For haptic, might need to generate vibration patterns
|
||||
return {"pattern": "notification", "intensity": 0.5}
|
||||
else:
|
||||
return content
|
||||
|
||||
def get_available_modalities(self) -> list[ModalityType]:
|
||||
"""Get list of available modalities."""
|
||||
return list(self._interface_adapters.keys())
|
||||
|
||||
def get_session_statistics(self, session_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier.
|
||||
|
||||
Returns:
|
||||
Dictionary with session statistics.
|
||||
"""
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return {}
|
||||
|
||||
# Get conversation history
|
||||
history = []
|
||||
if session.conversation_session_id:
|
||||
history = self.conversation_manager.get_history(session.conversation_session_id)
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"user_id": session.user_id,
|
||||
"active_modalities": [m.value for m in session.active_modalities],
|
||||
"conversation_turns": len(history),
|
||||
"started_at": session.started_at,
|
||||
"last_activity_at": session.last_activity_at,
|
||||
}
|
||||
338
fusionagi/interfaces/voice.py
Normal file
338
fusionagi/interfaces/voice.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Voice interface: speech-to-text, text-to-speech, voice library management."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now_iso
|
||||
from fusionagi.interfaces.base import InterfaceAdapter, InterfaceCapabilities, InterfaceMessage, ModalityType
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@runtime_checkable
class TTSAdapter(Protocol):
    """Protocol for TTS providers (ElevenLabs, Azure, system, etc.). Integrate by injecting an implementation."""

    # @runtime_checkable makes isinstance() checks against this structural type possible.
    async def synthesize(self, text: str, voice_id: str | None = None, **kwargs: Any) -> bytes | None:
        """Synthesize text to audio. Returns raw audio bytes or None if not available."""
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class STTAdapter(Protocol):
    """Protocol for STT providers (Whisper, Azure, Google, etc.). Integrate by injecting an implementation."""

    # audio_data=None means "capture from the provider's own input source";
    # VoiceInterface.receive calls it that way.
    async def transcribe(self, audio_data: bytes | None = None, timeout_seconds: float | None = None, **kwargs: Any) -> str | None:
        """Transcribe audio to text. Returns transcribed text or None if timeout/unavailable."""
        ...
|
||||
|
||||
|
||||
class VoiceProfile(BaseModel):
    """Voice profile for text-to-speech synthesis."""

    # Stable identifier; auto-generated "voice_<8 hex chars>" when not supplied.
    id: str = Field(default_factory=lambda: f"voice_{uuid.uuid4().hex[:8]}")
    name: str = Field(description="Human-readable voice name")
    language: str = Field(default="en-US", description="Language code (e.g., en-US, es-ES)")
    # Optional attributes used as filters by VoiceLibrary.list_voices.
    gender: Literal["male", "female", "neutral"] | None = Field(default=None)
    age_range: Literal["child", "young_adult", "adult", "senior"] | None = Field(default=None)
    style: str | None = Field(default=None, description="Voice style (e.g., friendly, professional, calm)")
    # Prosody controls; pydantic validates them into [0.5, 2.0] (ge/le).
    pitch: float = Field(default=1.0, ge=0.5, le=2.0, description="Pitch multiplier")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed multiplier")
    provider: str = Field(default="system", description="TTS provider (e.g., system, elevenlabs, azure)")
    provider_voice_id: str | None = Field(default=None, description="Provider-specific voice ID")
    metadata: dict[str, Any] = Field(default_factory=dict)
    # ISO-8601 UTC timestamp from fusionagi._time.utc_now_iso.
    created_at: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
class VoiceLibrary:
    """
    Registry of TTS voice profiles.

    Lets an administrator add, configure, and organize voice profiles for
    different agents, contexts, or user preferences. The first voice added
    becomes the default until one is chosen explicitly.
    """

    def __init__(self) -> None:
        self._voices: dict[str, VoiceProfile] = {}
        self._default_voice_id: str | None = None
        logger.info("VoiceLibrary initialized")

    def add_voice(self, profile: VoiceProfile) -> str:
        """
        Register a voice profile in the library.

        Args:
            profile: Voice profile to add.

        Returns:
            Voice ID.
        """
        self._voices[profile.id] = profile
        # The first registration seeds the default voice
        if self._default_voice_id is None:
            self._default_voice_id = profile.id
        logger.info("Voice added", extra={"voice_id": profile.id, "name": profile.name})
        return profile.id

    def remove_voice(self, voice_id: str) -> bool:
        """
        Drop a voice profile from the library.

        Args:
            voice_id: ID of voice to remove.

        Returns:
            True if removed, False if not found.
        """
        if voice_id not in self._voices:
            return False
        del self._voices[voice_id]
        # If the default was removed, promote an arbitrary remaining voice
        if self._default_voice_id == voice_id:
            self._default_voice_id = next(iter(self._voices.keys()), None)
        logger.info("Voice removed", extra={"voice_id": voice_id})
        return True

    def get_voice(self, voice_id: str) -> VoiceProfile | None:
        """Return the profile registered under *voice_id*, or None."""
        return self._voices.get(voice_id)

    def list_voices(
        self,
        language: str | None = None,
        gender: str | None = None,
        style: str | None = None,
    ) -> list[VoiceProfile]:
        """
        List voice profiles, optionally narrowed by attribute.

        Args:
            language: Filter by language code.
            gender: Filter by gender.
            style: Filter by style.

        Returns:
            List of matching voice profiles.
        """
        def matches(candidate: VoiceProfile) -> bool:
            # Each filter applies only when a value was provided
            if language and candidate.language != language:
                return False
            if gender and candidate.gender != gender:
                return False
            if style and candidate.style != style:
                return False
            return True

        return [v for v in self._voices.values() if matches(v)]

    def set_default_voice(self, voice_id: str) -> bool:
        """
        Choose the library-wide default voice.

        Args:
            voice_id: ID of voice to set as default.

        Returns:
            True if set, False if voice not found.
        """
        if voice_id not in self._voices:
            return False
        self._default_voice_id = voice_id
        logger.info("Default voice set", extra={"voice_id": voice_id})
        return True

    def get_default_voice(self) -> VoiceProfile | None:
        """Return the default voice profile, or None when unset."""
        if self._default_voice_id:
            return self._voices.get(self._default_voice_id)
        return None

    def update_voice(self, voice_id: str, updates: dict[str, Any]) -> bool:
        """
        Patch fields on an existing voice profile.

        Unknown field names are silently skipped.

        Args:
            voice_id: ID of voice to update.
            updates: Dictionary of fields to update.

        Returns:
            True if updated, False if not found.
        """
        profile = self._voices.get(voice_id)
        if profile is None:
            return False

        for field_name, new_value in updates.items():
            # Only touch attributes the model actually declares
            if hasattr(profile, field_name):
                setattr(profile, field_name, new_value)

        logger.info("Voice updated", extra={"voice_id": voice_id, "updates": list(updates.keys())})
        return True
|
||||
|
||||
|
||||
class VoiceInterface(InterfaceAdapter):
    """
    Voice interface adapter for speech interaction.

    Handles:
    - Speech-to-text (STT) for user input
    - Text-to-speech (TTS) for system output
    - Voice activity detection (planned; not implemented in this class)
    - Noise cancellation (planned; not implemented in this class)

    Actual synthesis/transcription is delegated to injected TTS/STT adapters;
    without them, send/receive degrade to logging stubs.
    """

    def __init__(
        self,
        name: str = "voice",
        voice_library: VoiceLibrary | None = None,
        stt_provider: str = "whisper",
        tts_provider: str = "system",
        tts_adapter: TTSAdapter | None = None,
        stt_adapter: STTAdapter | None = None,
    ) -> None:
        """
        Initialize voice interface.

        Args:
            name: Interface name.
            voice_library: Voice library for TTS profiles.
            stt_provider: Speech-to-text provider (whisper, azure, google, etc.).
            tts_provider: Text-to-speech provider (system, elevenlabs, azure, etc.).
            tts_adapter: Optional TTS adapter for synthesis (inject to integrate ElevenLabs, Azure, etc.).
            stt_adapter: Optional STT adapter for transcription (inject to integrate Whisper, Azure, etc.).
        """
        super().__init__(name)
        # Fall back to a private, empty library when none is supplied
        self.voice_library = voice_library or VoiceLibrary()
        # Provider names are informational (logged); behavior comes from the adapters
        self.stt_provider = stt_provider
        self.tts_provider = tts_provider
        self._tts_adapter = tts_adapter
        self._stt_adapter = stt_adapter
        # Per-interface voice selection; None means "use the library default"
        self._active_voice_id: str | None = None
        logger.info(
            "VoiceInterface initialized",
            extra={"stt_provider": stt_provider, "tts_provider": tts_provider}
        )

    def capabilities(self) -> InterfaceCapabilities:
        """Return voice interface capabilities."""
        return InterfaceCapabilities(
            supported_modalities=[ModalityType.VOICE],
            supports_streaming=True,
            supports_interruption=True,
            supports_multimodal=False,
            latency_ms=200.0,  # Typical voice latency
            max_concurrent_sessions=10,
        )

    async def send(self, message: InterfaceMessage) -> None:
        """
        Send voice output (text-to-speech).

        Resolves a voice profile (message metadata -> active voice -> library
        default), then synthesizes via the injected TTS adapter; without an
        adapter this only logs a stub entry.

        Args:
            message: Message with text content to synthesize.
        """
        if not self.validate_message(message):
            logger.warning("Invalid message for voice interface", extra={"modality": message.modality})
            return

        # Get voice profile
        voice_id = message.metadata.get("voice_id", self._active_voice_id)
        voice = None
        if voice_id:
            voice = self.voice_library.get_voice(voice_id)
        if not voice:
            # Unknown or missing voice_id: fall back to the library default
            voice = self.voice_library.get_default_voice()

        text = message.content if isinstance(message.content, str) else str(message.content)
        # voice_id is rebound here to the resolved profile's id (may be None)
        voice_id = voice.id if voice else None
        if self._tts_adapter is not None:
            try:
                audio_data = await self._tts_adapter.synthesize(text, voice_id=voice_id)
                if audio_data:
                    logger.info(
                        "TTS synthesis (adapter)",
                        extra={"text_length": len(text), "voice_id": voice_id, "bytes": len(audio_data)},
                    )
                    # Inject: await self._play_audio(audio_data)
            except Exception as e:
                logger.exception("TTS adapter failed", extra={"error": str(e)})
        else:
            logger.info(
                "TTS synthesis (stub; inject tts_adapter for ElevenLabs, Azure, etc.)",
                extra={"text_length": len(text), "voice_id": voice_id, "provider": self.tts_provider},
            )

    async def receive(self, timeout_seconds: float | None = None) -> InterfaceMessage | None:
        """
        Receive voice input (speech-to-text).

        Delegates to the injected STT adapter; without one, or on
        timeout/error, returns None.

        Args:
            timeout_seconds: Optional timeout for listening.

        Returns:
            Message with transcribed text or None if timeout.
        """
        logger.info("STT listening", extra={"timeout": timeout_seconds, "provider": self.stt_provider})
        if self._stt_adapter is not None:
            try:
                # audio_data=None: the adapter captures from its own input source
                text = await self._stt_adapter.transcribe(audio_data=None, timeout_seconds=timeout_seconds)
                if text:
                    return InterfaceMessage(
                        id=f"stt_{uuid.uuid4().hex[:8]}",
                        modality=ModalityType.VOICE,
                        content=text,
                        metadata={"provider": self.stt_provider},
                    )
            except Exception as e:
                logger.exception("STT adapter failed", extra={"error": str(e)})
        return None

    def set_active_voice(self, voice_id: str) -> bool:
        """
        Set the active voice for this interface session.

        Args:
            voice_id: ID of voice to use.

        Returns:
            True if voice exists, False otherwise.
        """
        if self.voice_library.get_voice(voice_id):
            self._active_voice_id = voice_id
            logger.info("Active voice set", extra={"voice_id": voice_id})
            return True
        return False

    async def _synthesize_speech(self, text: str, voice: VoiceProfile | None) -> bytes:
        """
        Synthesize speech from text (to be implemented with actual provider).

        Currently unimplemented and not called by send(); kept as the
        integration point for a built-in provider backend.

        Args:
            text: Text to synthesize.
            voice: Voice profile to use.

        Returns:
            Audio data as bytes.

        Raises:
            NotImplementedError: Always, until a provider is integrated.
        """
        # Integrate with TTS provider based on self.tts_provider
        # - system: Use OS TTS (pyttsx3, etc.)
        # - elevenlabs: Use ElevenLabs API
        # - azure: Use Azure Cognitive Services
        # - google: Use Google Cloud TTS
        raise NotImplementedError("TTS provider integration required")

    async def _transcribe_speech(self, audio_data: bytes) -> str:
        """
        Transcribe speech to text (to be implemented with actual provider).

        Currently unimplemented and not called by receive(); kept as the
        integration point for a built-in provider backend.

        Args:
            audio_data: Audio data to transcribe.

        Returns:
            Transcribed text.

        Raises:
            NotImplementedError: Always, until a provider is integrated.
        """
        # Integrate with STT provider based on self.stt_provider
        # - whisper: Use OpenAI Whisper (local or API)
        # - azure: Use Azure Cognitive Services
        # - google: Use Google Cloud Speech-to-Text
        # - deepgram: Use Deepgram API
        raise NotImplementedError("STT provider integration required")
|
||||
14
fusionagi/maa/__init__.py
Normal file
14
fusionagi/maa/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Manufacturing Authority Add-On: sovereign validation layer for physical-world manufacturing."""
|
||||
|
||||
from fusionagi.maa.gate import MAAGate
|
||||
from fusionagi.maa.schemas.mpc import ManufacturingProofCertificate, MPCId
|
||||
from fusionagi.maa.gap_detection import check_gaps, GapReport, GapClass
|
||||
|
||||
__all__ = [
|
||||
"MAAGate",
|
||||
"ManufacturingProofCertificate",
|
||||
"MPCId",
|
||||
"check_gaps",
|
||||
"GapReport",
|
||||
"GapClass",
|
||||
]
|
||||
35
fusionagi/maa/audit.py
Normal file
35
fusionagi/maa/audit.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Audit and reporting: export MPC and root-cause report format."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.maa.schemas.mpc import ManufacturingProofCertificate
|
||||
from fusionagi.maa.gap_detection import GapReport
|
||||
|
||||
|
||||
def export_mpc_for_audit(cert: ManufacturingProofCertificate) -> dict[str, Any]:
    """
    Flatten a Manufacturing Proof Certificate into an audit-friendly dict.

    Optional sections (simulation proof, process justification, machine
    declaration) appear only when present on the certificate.
    """
    lineage = [
        {"node_id": entry.node_id, "family": entry.family, "outcome": entry.outcome}
        for entry in cert.decision_lineage
    ]
    risks = [{"risk_id": risk.risk_id, "severity": risk.severity} for risk in cert.risk_register]

    report: dict[str, Any] = {
        "mpc_id": cert.mpc_id.value,
        "mpc_version": cert.mpc_id.version,
        "decision_lineage": lineage,
        "risk_register": risks,
        "metadata": cert.metadata,
    }

    sim = cert.simulation_proof
    if sim:
        report["simulation_proof"] = {"proof_id": sim.proof_id}
    justification = cert.process_justification
    if justification:
        report["process_justification"] = {"process_type": justification.process_type, "eligible": justification.eligible}
    machine = cert.machine_declaration
    if machine:
        report["machine_declaration"] = {"machine_id": machine.machine_id}
    return report
|
||||
|
||||
|
||||
def format_root_cause_report(gaps: list[GapReport], tool_name: str = "", context_ref: str = "") -> dict[str, Any]:
    """Human-readable root-cause report for gap/tool rejections."""
    gap_rows = []
    for gap in gaps:
        gap_rows.append(
            {
                "gap_class": gap.gap_class.value,
                "description": gap.description,
                "required_resolution": gap.required_resolution,
            }
        )
    return {
        "report_type": "maa_root_cause",
        "tool_name": tool_name,
        "context_ref": context_ref,
        "gaps": gap_rows,
        "summary": f"{len(gaps)} gap(s) triggered halt.",
    }
|
||||
87
fusionagi/maa/gap_detection.py
Normal file
87
fusionagi/maa/gap_detection.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Gap detection: active gap classes; any gap triggers halt + root-cause report (no warnings)."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GapClass(str, Enum):
    """Active gap classes that trigger immediate halt."""

    # Parameters lack explicit numeric bounds (check_gaps: numeric_bounds / require_numeric_bounds)
    MISSING_NUMERIC_BOUNDS = "missing_numeric_bounds"
    # Tolerances were not declared explicitly (require_explicit_tolerances without tolerances)
    IMPLICIT_TOLERANCES = "implicit_tolerances"
    # Datums required but not defined in the context (require_datums without datums)
    UNDEFINED_DATUMS = "undefined_datums"
    # Process type was assumed rather than declared (require_process_type without process_type)
    ASSUMED_PROCESSES = "assumed_processes"
    # A toolpath reference lacks geometry lineage tying it back to intent
    TOOLPATH_ORPHANING = "toolpath_orphaning"
|
||||
|
||||
|
||||
class GapReport(BaseModel):
    """Single gap report: class, root-cause, required resolution."""

    # Which gap class fired; always required.
    gap_class: GapClass = Field(...)
    description: str = Field(..., description="Human-readable root cause")
    # Reference back to the offending context element, when known.
    context_ref: str | None = Field(default=None)
    # Actionable instruction for resolving the gap, when one can be suggested.
    required_resolution: str | None = Field(default=None)
|
||||
|
||||
|
||||
def check_gaps(context: dict[str, Any]) -> list[GapReport]:
    """
    Run gap checks on *context*; any gap triggers halt.

    Returns a list of gap reports; an empty list means no gaps were found.
    """
    found: list[GapReport] = []

    def flag(gap_class: GapClass, description: str, resolution: str) -> None:
        # Shared constructor for the report entries below
        found.append(
            GapReport(
                gap_class=gap_class,
                description=description,
                required_resolution=resolution,
            )
        )

    # Numeric bounds: present-but-empty/non-dict is a gap; absent is a gap
    # only when the context explicitly requires bounds.
    if "numeric_bounds" in context:
        bounds = context["numeric_bounds"]
        if not (isinstance(bounds, dict) and bounds):
            flag(
                GapClass.MISSING_NUMERIC_BOUNDS,
                "Numeric bounds missing or empty",
                "Provide bounded numeric parameters",
            )
    elif context.get("require_numeric_bounds"):
        flag(
            GapClass.MISSING_NUMERIC_BOUNDS,
            "Numeric bounds required but absent",
            "Provide numeric_bounds in context",
        )

    if context.get("require_explicit_tolerances") and not context.get("tolerances"):
        flag(
            GapClass.IMPLICIT_TOLERANCES,
            "Tolerances must be explicit",
            "Declare tolerances in context",
        )

    if context.get("require_datums") and not context.get("datums"):
        flag(
            GapClass.UNDEFINED_DATUMS,
            "Datums required but undefined",
            "Define datums in context",
        )

    if context.get("require_process_type") and not context.get("process_type"):
        flag(
            GapClass.ASSUMED_PROCESSES,
            "Process type must be declared",
            "Set process_type (additive, subtractive, hybrid)",
        )

    # A toolpath without geometry lineage is orphaned, but only when lineage
    # tracking is demanded by the context.
    if context.get("toolpath_ref") and not context.get("geometry_lineage") and context.get("require_lineage"):
        flag(
            GapClass.TOOLPATH_ORPHANING,
            "Toolpath must trace to geometry and intent",
            "Provide geometry_lineage and intent_ref",
        )

    return found
|
||||
85
fusionagi/maa/gate.py
Normal file
85
fusionagi/maa/gate.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""MAA Gate: governance integration; MPC check and tool classification for manufacturing tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.maa.gap_detection import check_gaps, GapReport
|
||||
from fusionagi.maa.layers.mpc_authority import MPCAuthority
|
||||
from fusionagi.maa.layers.dlt_engine import DLTEngine
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
# Default manufacturing tool names that require MPC
# (frozenset: the allowlist is immutable and membership tests are O(1))
DEFAULT_MANUFACTURING_TOOLS = frozenset({"cnc_emit", "am_slice", "machine_bind"})
|
||||
|
||||
|
||||
class MAAGate:
    """
    Gate for manufacturing tools: (tool_name, args) -> (allowed, sanitized_args | error_message).
    Compatible with Guardrails.add_check. Manufacturing tools require valid MPC and no gaps.
    """

    def __init__(
        self,
        mpc_authority: MPCAuthority,
        dlt_engine: DLTEngine | None = None,
        manufacturing_tools: set[str] | frozenset[str] | None = None,
    ) -> None:
        """
        Args:
            mpc_authority: Authority used to verify mpc_id values from tool args.
            dlt_engine: Optional DLT engine; a fresh DLTEngine is created when omitted.
            manufacturing_tools: Override for the manufacturing-tool allowlist.
        """
        self._mpc = mpc_authority
        self._dlt = dlt_engine or DLTEngine()
        # NOTE(review): an explicitly passed *empty* set falls back to the
        # defaults because of `or`; confirm that is intended.
        self._manufacturing_tools = manufacturing_tools or DEFAULT_MANUFACTURING_TOOLS

    def is_manufacturing(self, tool_name: str, tool_def: Any = None) -> bool:
        """Return True if tool is classified as manufacturing (allowlist or ToolDef scope)."""
        # A tool_def carrying a truthy `manufacturing` attribute wins over the allowlist
        if tool_def is not None and getattr(tool_def, "manufacturing", False):
            return True
        return tool_name in self._manufacturing_tools

    def check(self, tool_name: str, args: dict[str, Any]) -> tuple[bool, dict[str, Any] | str]:
        """
        Pre-check for Guardrails: (tool_name, args) -> (allowed, sanitized_args or error_message).
        Non-manufacturing tools pass through. Manufacturing tools require mpc_id, valid MPC, no gaps.
        """
        # NOTE(review): tool_def is not threaded through here, so ToolDef-scope
        # classification from is_manufacturing never applies to check().
        if not self.is_manufacturing(tool_name, None):
            logger.debug("MAA check pass-through (non-manufacturing)", extra={"tool_name": tool_name})
            return True, args

        # Accept either arg spelling for the certificate id
        mpc_id_value = args.get("mpc_id") or args.get("mpc_id_value")
        if not mpc_id_value:
            logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "missing mpc_id"})
            return False, "MAA: manufacturing tool requires mpc_id in args"

        cert = self._mpc.verify(mpc_id_value)
        if cert is None:
            logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "invalid or unknown MPC"})
            return False, f"MAA: invalid or unknown MPC: {mpc_id_value}"

        # Gap detection runs over the caller args enriched with the verified
        # certificate's id/version
        context: dict[str, Any] = {
            **args,
            "mpc_id": mpc_id_value,
            "mpc_version": cert.mpc_id.version,
        }
        gaps = check_gaps(context)
        if gaps:
            root_cause = _format_root_cause(gaps)
            logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "gaps", "gap_count": len(gaps)})
            return False, root_cause

        # Optional DLT evaluation when dlt_contract_id and dlt_context are in args
        dlt_contract_id = args.get("dlt_contract_id")
        if dlt_contract_id:
            # Missing/empty dlt_context falls back to the enriched context above
            dlt_context = args.get("dlt_context") or context
            ok, cause = self._dlt.evaluate(dlt_contract_id, dlt_context)
            if not ok:
                logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "dlt_failed"})
                return False, f"MAA DLT: {cause}"

        logger.debug("MAA check allowed", extra={"tool_name": tool_name})
        return True, args
|
||||
|
||||
|
||||
def _format_root_cause(gaps: list[GapReport]) -> str:
    """Collapse gap reports into a single pipe-delimited root-cause message."""
    messages = []
    for report in gaps:
        messages.append(f"MAA gap: {report.gap_class.value} — {report.description}")
    # Aggregate all declared resolutions into one trailing segment.
    resolutions = [r.required_resolution for r in gaps if r.required_resolution]
    if resolutions:
        messages.append("Required resolution: " + "; ".join(resolutions))
    return " | ".join(messages)
|
||||
25
fusionagi/maa/layers/__init__.py
Normal file
25
fusionagi/maa/layers/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""MAA layers: DLT, intent, geometry, physics, process, machine, toolpath, MPC."""
|
||||
|
||||
from fusionagi.maa.layers.dlt_engine import DLTEngine
|
||||
from fusionagi.maa.layers.mpc_authority import MPCAuthority
|
||||
from fusionagi.maa.layers.intent_engine import IntentEngine
|
||||
from fusionagi.maa.layers.geometry_kernel import GeometryAuthorityInterface, InMemoryGeometryKernel
|
||||
from fusionagi.maa.layers.physics_authority import PhysicsAuthorityInterface, StubPhysicsAuthority
|
||||
from fusionagi.maa.layers.process_authority import ProcessAuthority
|
||||
from fusionagi.maa.layers.machine_binding import MachineBinding, MachineProfile
|
||||
from fusionagi.maa.layers.toolpath_engine import ToolpathEngine, ToolpathArtifact
|
||||
|
||||
# Public API of fusionagi.maa.layers; mirrors the imports above.
__all__ = [
    "DLTEngine",
    "MPCAuthority",
    "IntentEngine",
    "GeometryAuthorityInterface",
    "InMemoryGeometryKernel",
    "PhysicsAuthorityInterface",
    "StubPhysicsAuthority",
    "ProcessAuthority",
    "MachineBinding",
    "MachineProfile",
    "ToolpathEngine",
    "ToolpathArtifact",
]
|
||||
68
fusionagi/maa/layers/dlt_engine.py
Normal file
68
fusionagi/maa/layers/dlt_engine.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Deterministic Decision Logic Tree Engine: store and evaluate DLTs; fail-closed."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.maa.schemas.dlt import DLTContract, DLTNode
|
||||
|
||||
|
||||
class DLTEngine:
    """Store and evaluate Deterministic Decision Logic Trees; immutable, versioned contracts.

    Contracts are stored under a composite key "<contract_id>@v<version>" so every
    registered version remains addressable. Evaluation is deterministic and
    fail-closed: unknown contracts, nodes, conditions, and cyclic node graphs all
    produce a (False, root_cause) result rather than raising.
    """

    def __init__(self) -> None:
        # Composite key "<contract_id>@v<version>" -> immutable contract.
        self._contracts: dict[str, DLTContract] = {}

    def register(self, contract: DLTContract) -> None:
        """Register an immutable DLT contract under its id and version."""
        key = f"{contract.contract_id}@v{contract.version}"
        self._contracts[key] = contract

    def get(self, contract_id: str, version: int | None = None) -> DLTContract | None:
        """Return contract by id; optional version (latest if omitted)."""
        if version is not None:
            return self._contracts.get(f"{contract_id}@v{version}")
        best: DLTContract | None = None
        # Linear scan for the highest registered version of this contract_id.
        for candidate in self._contracts.values():
            if candidate.contract_id == contract_id and (best is None or candidate.version > best.version):
                best = candidate
        return best

    def evaluate(
        self,
        contract_id: str,
        context: dict[str, Any],
        version: int | None = None,
    ) -> tuple[bool, str]:
        """Evaluate DLT from root; deterministic, fail-closed. Return (True, "") or (False, root_cause)."""
        contract = self.get(contract_id, version)
        if not contract:
            return False, f"DLT contract not found: {contract_id}"
        return self._evaluate_node(contract, contract.root_id, context, frozenset())

    def _evaluate_node(
        self,
        contract: DLTContract,
        node_id: str,
        context: dict[str, Any],
        path: frozenset[str] = frozenset(),
    ) -> tuple[bool, str]:
        """
        Recursively evaluate a node and then its children.

        `path` holds the node ids on the current root-to-node path; revisiting one
        means the contract contains a cycle, which is rejected fail-closed instead
        of recursing forever (a DAG reached via distinct paths is still allowed).
        """
        if node_id in path:
            return False, f"DLT cycle detected at node: {node_id}"
        node = contract.nodes.get(node_id)
        if not node:
            return False, f"DLT node not found: {node_id}"
        passed = self._check_condition(node, context)
        if not passed:
            # A failing non-fail_closed node is advisory only; children still run.
            if node.fail_closed:
                return False, f"DLT node failed (fail-closed): {node_id} condition={node.condition}"
        child_path = path | {node_id}
        for child_id in node.children:
            ok, cause = self._evaluate_node(contract, child_id, context, child_path)
            if not ok:
                return False, cause
        return True, ""

    def _check_condition(self, node: DLTNode, context: dict[str, Any]) -> bool:
        """Evaluate condition; unknown conditions are fail-closed (False)."""
        # "required:<key>" — key must be present and non-None in the context.
        if node.condition.startswith("required:"):
            key = node.condition.split(":", 1)[1].strip()
            return key in context and context[key] is not None
        if node.condition == "always":
            return True
        # Unknown condition: fail-closed
        return False
|
||||
81
fusionagi/maa/layers/geometry_kernel.py
Normal file
81
fusionagi/maa/layers/geometry_kernel.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Layer 3 — Geometry Authority Kernel: implicit geometry, constraint solvers, feature lineage."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FeatureLineageEntry(BaseModel):
    """Single feature lineage entry: feature -> intent node, physics justification, process eligibility."""

    # Identifier of the geometric feature this record describes.
    feature_id: str = Field(...)
    # Intent node the feature realizes; required, so no stored feature is orphaned.
    intent_node_id: str = Field(...)
    # Optional reference to the physics proof justifying the feature.
    physics_justification_ref: str | None = Field(default=None)
    # Whether the feature has been cleared for process planning.
    process_eligible: bool = Field(default=False)
|
||||
|
||||
|
||||
class GeometryAuthorityInterface(ABC):
    """
    Interface for implicit geometry, constraint solvers, feature lineage.
    Every geometric feature must map to intent node, physics justification, process eligibility.
    Orphan geometry is prohibited.
    """

    @abstractmethod
    def add_feature(
        self,
        feature_id: str,
        intent_node_id: str,
        physics_justification_ref: str | None = None,
        process_eligible: bool = False,
        payload: dict[str, Any] | None = None,
    ) -> FeatureLineageEntry:
        """Register a feature with lineage; orphan (no intent) prohibited.

        Args:
            feature_id: Identifier of the geometric feature.
            intent_node_id: Intent node this feature realizes.
            physics_justification_ref: Optional physics proof reference.
            process_eligible: Whether the feature may enter process planning.
            payload: Implementation-specific extra data (e.g. kernel geometry).

        Returns:
            The stored FeatureLineageEntry.
        """
        ...

    @abstractmethod
    def get_lineage(self, feature_id: str) -> FeatureLineageEntry | None:
        """Return lineage for feature or None."""
        ...

    @abstractmethod
    def validate_no_orphans(self) -> list[str]:
        """Return list of feature ids with no valid lineage (orphans); must be empty for MPC."""
        ...
|
||||
|
||||
|
||||
class InMemoryGeometryKernel(GeometryAuthorityInterface):
    """
    In-memory lineage model; no concrete CAD kernel.

    Only tracks features registered via add_feature; validate_no_orphans returns []
    since every stored feature has lineage. For a kernel that tracks all feature ids
    separately, override validate_no_orphans to return ids not in lineage.
    """

    def __init__(self) -> None:
        # feature_id -> lineage entry
        self._lineage: dict[str, FeatureLineageEntry] = {}

    def add_feature(
        self,
        feature_id: str,
        intent_node_id: str,
        physics_justification_ref: str | None = None,
        process_eligible: bool = False,
        payload: dict[str, Any] | None = None,
    ) -> FeatureLineageEntry:
        """Register a feature with lineage; orphan (no intent) prohibited.

        Args:
            feature_id: Identifier of the geometric feature; must be non-empty.
            intent_node_id: Intent node this feature realizes; must be non-empty
                (the interface prohibits orphan geometry, so this is enforced
                here rather than silently accepted).
            physics_justification_ref: Optional physics proof reference.
            process_eligible: Whether the feature may enter process planning.
            payload: Accepted for interface compatibility; this in-memory stub
                does not store it.

        Returns:
            The stored FeatureLineageEntry.

        Raises:
            ValueError: If feature_id or intent_node_id is empty.
        """
        if not feature_id:
            raise ValueError("feature_id is required")
        if not intent_node_id:
            # Enforce the documented "orphan geometry is prohibited" contract
            # instead of recording a feature with no intent lineage.
            raise ValueError(f"orphan geometry prohibited: feature {feature_id!r} has no intent node")
        entry = FeatureLineageEntry(
            feature_id=feature_id,
            intent_node_id=intent_node_id,
            physics_justification_ref=physics_justification_ref,
            process_eligible=process_eligible,
        )
        self._lineage[feature_id] = entry
        return entry

    def get_lineage(self, feature_id: str) -> FeatureLineageEntry | None:
        """Return lineage for feature or None if never registered."""
        return self._lineage.get(feature_id)

    def validate_no_orphans(self) -> list[str]:
        """Return []; this stub only tracks registered features, so none are orphans."""
        return []
|
||||
431
fusionagi/maa/layers/intent_engine.py
Normal file
431
fusionagi/maa/layers/intent_engine.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""Layer 1 — Intent Formalization Engine.
|
||||
|
||||
Responsible for:
|
||||
1. Intent decomposition - breaking natural language into structured requirements
|
||||
2. Requirement typing - classifying requirements (dimensional, load, environmental, process)
|
||||
3. Load case enumeration - identifying operational scenarios
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.maa.schemas.intent import EngineeringIntentGraph, IntentNode, LoadCase, RequirementType
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class IntentIncompleteError(Exception):
|
||||
"""Raised when intent formalization cannot be completed due to missing information."""
|
||||
|
||||
def __init__(self, message: str, missing_fields: list[str] | None = None):
|
||||
self.missing_fields = missing_fields or []
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class IntentEngine:
    """
    Intent decomposition, requirement typing, and load case enumeration.

    Features:
    - Pattern-based requirement extraction from natural language
    - Automatic requirement type classification
    - Load case identification
    - Environmental bounds extraction
    - LLM-assisted formalization (optional)
    """

    # Patterns for dimensional requirements (measurements, tolerances)
    DIMENSIONAL_PATTERNS = [
        r"(\d+(?:\.\d+)?)\s*(mm|cm|m|in|inch|inches|ft|feet)\b",
        r"tolerance[s]?\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"±\s*(\d+(?:\.\d+)?)",
        r"(\d+(?:\.\d+)?)\s*×\s*(\d+(?:\.\d+)?)",
        r"diameter\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"radius\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"thickness\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"length\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"width\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"height\s*(?:of\s*)?(\d+(?:\.\d+)?)",
    ]

    # Patterns for load requirements (forces, pressures, stresses)
    LOAD_PATTERNS = [
        r"(\d+(?:\.\d+)?)\s*(N|kN|MN|lb|lbf|kg|kgf)\b",
        r"(\d+(?:\.\d+)?)\s*(MPa|GPa|Pa|psi|ksi)\b",
        r"load\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"force\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"stress\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"pressure\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"factor\s*of\s*safety\s*(?:of\s*)?(\d+(?:\.\d+)?)",
        r"yield\s*strength",
        r"tensile\s*strength",
        r"fatigue\s*(?:life|limit|strength)",
    ]

    # Patterns for environmental requirements
    ENVIRONMENTAL_PATTERNS = [
        r"(\d+(?:\.\d+)?)\s*(?:°|deg|degrees?)?\s*(C|F|K|Celsius|Fahrenheit|Kelvin)\b",
        r"temperature\s*(?:range|of)?\s*(\d+)",
        r"humidity\s*(?:of\s*)?(\d+)",
        r"corrosion\s*resist",
        r"UV\s*resist",
        r"water\s*(?:proof|resist)",
        r"chemical\s*resist",
        r"outdoor",
        r"marine",
        r"aerospace",
    ]

    # Patterns for process requirements
    PROCESS_PATTERNS = [
        r"CNC|machining|milling|turning|drilling",
        r"3D\s*print|additive|FDM|SLA|SLS|DMLS",
        r"cast|injection\s*mold|die\s*cast",
        r"weld|braze|solder",
        r"heat\s*treat|anneal|harden|temper",
        r"surface\s*finish|polish|anodize|plate",
        r"assembly|sub-assembly",
        r"material:\s*(\w+)",
        r"aluminum|steel|titanium|plastic|composite",
    ]

    # Load case indicator patterns
    LOAD_CASE_PATTERNS = [
        r"(?:during|under|in)\s+(\w+(?:\s+\w+)?)\s+(?:conditions?|operation|mode)",
        r"(\w+)\s+load\s+case",
        r"(?:static|dynamic|cyclic|impact|thermal)\s+load",
        r"(?:normal|extreme|emergency|failure)\s+(?:operation|conditions?|mode)",
        r"operating\s+(?:at|under|in)",
    ]

    def __init__(self, llm_adapter: Any | None = None):
        """
        Initialize the IntentEngine.

        Args:
            llm_adapter: Optional LLM adapter for enhanced natural language processing.
        """
        self._llm = llm_adapter

    def formalize(
        self,
        intent_id: str,
        natural_language: str | None = None,
        file_refs: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        use_llm: bool = True,
    ) -> EngineeringIntentGraph:
        """
        Formalize engineering intent from natural language and file references.

        Args:
            intent_id: Unique identifier for this intent.
            natural_language: Natural language description of requirements.
            file_refs: References to CAD files, specifications, etc.
            metadata: Additional metadata (attached to the pattern-based graph;
                the LLM path returns its own metadata).
            use_llm: Whether to use LLM for enhanced processing (if available).

        Returns:
            EngineeringIntentGraph with extracted requirements.

        Raises:
            IntentIncompleteError: If required information is missing.
        """
        if not intent_id:
            raise IntentIncompleteError("intent_id required", ["intent_id"])

        if not natural_language and not file_refs:
            raise IntentIncompleteError(
                "At least one of natural_language or file_refs required",
                ["natural_language", "file_refs"],
            )

        nodes: list[IntentNode] = []
        load_cases: list[LoadCase] = []
        environmental_bounds: dict[str, Any] = {}

        # Process natural language if provided
        if natural_language:
            # Use LLM if available and requested
            if use_llm and self._llm:
                llm_result = self._formalize_with_llm(intent_id, natural_language)
                if llm_result:
                    return llm_result

            # Fall back to pattern-based extraction
            extracted = self._extract_requirements(intent_id, natural_language)
            nodes.extend(extracted["nodes"])
            load_cases.extend(extracted["load_cases"])
            environmental_bounds.update(extracted["environmental_bounds"])

        # Process file references
        if file_refs:
            for ref in file_refs:
                nodes.append(
                    IntentNode(
                        node_id=f"{intent_id}_file_{uuid.uuid4().hex[:8]}",
                        requirement_type=RequirementType.OTHER,
                        description=f"Reference: {ref}",
                        metadata={"file_ref": ref},
                    )
                )

        # If no nodes were extracted, create a general requirement
        if not nodes and natural_language:
            nodes.append(
                IntentNode(
                    node_id=f"{intent_id}_general_0",
                    requirement_type=RequirementType.OTHER,
                    description=natural_language[:500],
                )
            )

        logger.info(
            "Intent formalized",
            extra={
                "intent_id": intent_id,
                "num_nodes": len(nodes),
                "num_load_cases": len(load_cases),
            },
        )

        return EngineeringIntentGraph(
            intent_id=intent_id,
            nodes=nodes,
            load_cases=load_cases,
            environmental_bounds=environmental_bounds,
            metadata=metadata or {},
        )

    def _extract_requirements(
        self,
        intent_id: str,
        text: str,
    ) -> dict[str, Any]:
        """
        Extract requirements from text using pattern matching.

        Returns dict with nodes, load_cases, and environmental_bounds.
        """
        nodes: list[IntentNode] = []
        load_cases: list[LoadCase] = []
        environmental_bounds: dict[str, Any] = {}

        # Split into sentences for processing
        sentences = re.split(r'[.!?]+', text)

        node_counter = 0
        load_case_counter = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            # Check for dimensional requirements
            for pattern in self.DIMENSIONAL_PATTERNS:
                if re.search(pattern, sentence, re.IGNORECASE):
                    nodes.append(
                        IntentNode(
                            node_id=f"{intent_id}_dim_{node_counter}",
                            requirement_type=RequirementType.DIMENSIONAL,
                            description=sentence,
                            metadata={"pattern": "dimensional"},
                        )
                    )
                    node_counter += 1
                    break

            # Check for load requirements
            for pattern in self.LOAD_PATTERNS:
                if re.search(pattern, sentence, re.IGNORECASE):
                    nodes.append(
                        IntentNode(
                            node_id=f"{intent_id}_load_{node_counter}",
                            requirement_type=RequirementType.LOAD,
                            description=sentence,
                            metadata={"pattern": "load"},
                        )
                    )
                    node_counter += 1
                    break

            # Check for environmental requirements
            for pattern in self.ENVIRONMENTAL_PATTERNS:
                match = re.search(pattern, sentence, re.IGNORECASE)
                if match:
                    nodes.append(
                        IntentNode(
                            node_id=f"{intent_id}_env_{node_counter}",
                            requirement_type=RequirementType.ENVIRONMENTAL,
                            description=sentence,
                            metadata={"pattern": "environmental"},
                        )
                    )
                    node_counter += 1

                    # Extract specific bounds if possible
                    if "temperature" in sentence.lower():
                        temp_match = re.search(r"(-?\d+(?:\.\d+)?)", sentence)
                        if temp_match:
                            environmental_bounds["temperature"] = float(temp_match.group(1))
                    break

            # Check for process requirements
            for pattern in self.PROCESS_PATTERNS:
                if re.search(pattern, sentence, re.IGNORECASE):
                    nodes.append(
                        IntentNode(
                            node_id=f"{intent_id}_proc_{node_counter}",
                            requirement_type=RequirementType.PROCESS,
                            description=sentence,
                            metadata={"pattern": "process"},
                        )
                    )
                    node_counter += 1
                    break

            # Check for load cases
            for pattern in self.LOAD_CASE_PATTERNS:
                match = re.search(pattern, sentence, re.IGNORECASE)
                if match:
                    load_case_desc = match.group(0) if match.group(0) else sentence
                    load_cases.append(
                        LoadCase(
                            load_case_id=f"{intent_id}_lc_{load_case_counter}",
                            description=load_case_desc,
                            metadata={"source_sentence": sentence},
                        )
                    )
                    load_case_counter += 1
                    break

        return {
            "nodes": nodes,
            "load_cases": load_cases,
            "environmental_bounds": environmental_bounds,
        }

    def _formalize_with_llm(
        self,
        intent_id: str,
        natural_language: str,
    ) -> EngineeringIntentGraph | None:
        """
        Use LLM to extract structured requirements from natural language.

        Returns None if LLM processing fails (falls back to pattern matching).
        """
        if not self._llm:
            return None

        import json

        prompt = f"""Extract engineering requirements from the following text.
Return a JSON object with:
- "nodes": list of requirements, each with:
  - "requirement_type": one of "dimensional", "load", "environmental", "process", "other"
  - "description": the requirement text
- "load_cases": list of operational scenarios, each with:
  - "description": the scenario description
- "environmental_bounds": dict of environmental limits (e.g., {{"temperature_max": 85, "humidity_max": 95}})

Text: {natural_language[:2000]}

Return only valid JSON, no markdown."""

        try:
            messages = [
                {"role": "system", "content": "You are an engineering requirements extraction system."},
                {"role": "user", "content": prompt},
            ]

            # Try structured output if available
            if hasattr(self._llm, "complete_structured"):
                result = self._llm.complete_structured(messages)
                if result:
                    return self._parse_llm_result(intent_id, result)

            # Fall back to text completion
            raw = self._llm.complete(messages)
            if raw:
                # Strip a markdown code fence if the model emitted one anyway.
                if raw.startswith("```"):
                    raw = raw.split("```")[1]
                    if raw.startswith("json"):
                        raw = raw[4:]
                result = json.loads(raw)
                return self._parse_llm_result(intent_id, result)

        except Exception as e:
            # Best-effort: any LLM/JSON failure falls back to pattern extraction.
            logger.warning(f"LLM formalization failed: {e}")

        return None

    def _parse_llm_result(
        self,
        intent_id: str,
        result: dict[str, Any],
    ) -> EngineeringIntentGraph:
        """Parse LLM result into EngineeringIntentGraph."""
        nodes = []
        for i, node_data in enumerate(result.get("nodes", [])):
            req_type_str = node_data.get("requirement_type", "other")
            try:
                req_type = RequirementType(req_type_str)
            except ValueError:
                # Unknown classification from the model: degrade to OTHER.
                req_type = RequirementType.OTHER

            nodes.append(
                IntentNode(
                    node_id=f"{intent_id}_llm_{i}",
                    requirement_type=req_type,
                    description=node_data.get("description", ""),
                    metadata={"source": "llm"},
                )
            )

        load_cases = []
        for i, lc_data in enumerate(result.get("load_cases", [])):
            load_cases.append(
                LoadCase(
                    load_case_id=f"{intent_id}_lc_llm_{i}",
                    description=lc_data.get("description", ""),
                    metadata={"source": "llm"},
                )
            )

        environmental_bounds = result.get("environmental_bounds", {})

        return EngineeringIntentGraph(
            intent_id=intent_id,
            nodes=nodes,
            load_cases=load_cases,
            environmental_bounds=environmental_bounds,
            metadata={"formalization_source": "llm"},
        )

    def validate_completeness(self, graph: EngineeringIntentGraph) -> tuple[bool, list[str]]:
        """
        Validate that an intent graph has sufficient information.

        Returns:
            Tuple of (is_complete, list_of_missing_items)
        """
        missing = []

        if not graph.nodes:
            missing.append("No requirements extracted")

        # Check for at least one dimensional or load requirement for manufacturing
        has_dimensional = any(n.requirement_type == RequirementType.DIMENSIONAL for n in graph.nodes)
        has_load = any(n.requirement_type == RequirementType.LOAD for n in graph.nodes)

        # Fix: has_load was computed but never consulted; per the check above,
        # either kind of requirement satisfies the manufacturing minimum.
        if not has_dimensional and not has_load:
            missing.append("No dimensional or load requirements specified")

        # Load cases are recommended but not required
        if not graph.load_cases:
            logger.info("No load cases specified for intent", extra={"intent_id": graph.intent_id})

        return len(missing) == 0, missing
|
||||
33
fusionagi/maa/layers/machine_binding.py
Normal file
33
fusionagi/maa/layers/machine_binding.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Layer 6 — Machine Binding & Personality Profiles."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MachineProfile(BaseModel):
    """Machine personality profile: limits, historical deviation models."""

    machine_id: str = Field(..., description="Bound machine id")
    # Reference to the machine's declared operating limits (external document/record).
    limits_ref: str | None = Field(default=None)
    # Reference to the machine's historical deviation model.
    deviation_model_ref: str | None = Field(default=None)
    # Free-form additional profile data.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class MachineBinding:
    """Each design binds to a specific machine with known limits. No abstraction without binding."""

    def __init__(self) -> None:
        # machine_id -> registered personality profile
        self._profiles: dict[str, MachineProfile] = {}

    def register(self, profile: MachineProfile) -> None:
        """Register (or replace) a machine profile keyed by its machine_id."""
        self._profiles[profile.machine_id] = profile

    def get(self, machine_id: str) -> MachineProfile | None:
        """Return profile for machine or None."""
        return self._profiles.get(machine_id)

    def resolve(self, machine_id: str) -> MachineProfile | None:
        """Resolve machine binding; None marks an unknown machine, which callers must reject."""
        return self._profiles.get(machine_id)
|
||||
65
fusionagi/maa/layers/mpc_authority.py
Normal file
65
fusionagi/maa/layers/mpc_authority.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""MPC Authority: issue and verify Manufacturing Proof Certificates; immutable, versioned."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.maa.schemas.mpc import (
|
||||
ManufacturingProofCertificate,
|
||||
MPCId,
|
||||
DecisionLineageEntry,
|
||||
SimulationProof,
|
||||
ProcessJustification,
|
||||
MachineDeclaration,
|
||||
RiskRegisterEntry,
|
||||
)
|
||||
from fusionagi.maa.versioning import VersionStore
|
||||
|
||||
|
||||
class MPCAuthority:
    """Central issue and verify MPCs; immutable, versioned."""

    def __init__(self) -> None:
        self._store = VersionStore()
        # Latest issued certificate per mpc_id.value, for version-less lookups.
        self._by_value: dict[str, ManufacturingProofCertificate] = {}

    def issue(
        self,
        mpc_id_value: str,
        decision_lineage: list[DecisionLineageEntry] | None = None,
        simulation_proof: SimulationProof | None = None,
        process_justification: ProcessJustification | None = None,
        machine_declaration: MachineDeclaration | None = None,
        risk_register: list[RiskRegisterEntry] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> ManufacturingProofCertificate:
        """Issue a new MPC; version auto-incremented."""
        previous = self._store.get_latest_version(mpc_id_value)
        next_version = (previous or 0) + 1
        certificate = ManufacturingProofCertificate(
            mpc_id=MPCId(value=mpc_id_value, version=next_version),
            decision_lineage=decision_lineage or [],
            simulation_proof=simulation_proof,
            process_justification=process_justification,
            machine_declaration=machine_declaration,
            risk_register=risk_register or [],
            metadata=metadata or {},
        )
        self._store.put(mpc_id_value, next_version, certificate)
        self._by_value[mpc_id_value] = certificate
        return certificate

    def verify(self, mpc_id: str | MPCId, version: int | None = None) -> ManufacturingProofCertificate | None:
        """Verify and return MPC if valid; None if not found or invalid."""
        value = mpc_id.value if isinstance(mpc_id, MPCId) else mpc_id
        if version is not None:
            return self._store.get(value, version)
        certificate = self._by_value.get(value)
        if certificate is not None:
            return certificate
        # Cache miss: fall back to the latest version in the store.
        return self._store.get(value, self._store.get_latest_version(value))

    def get(self, mpc_id_value: str, version: int | None = None) -> ManufacturingProofCertificate | None:
        """Return stored MPC by value and optional version."""
        if version is not None:
            return self._store.get(mpc_id_value, version)
        certificate = self._by_value.get(mpc_id_value)
        if certificate is not None:
            return certificate
        return self._store.get(mpc_id_value, self._store.get_latest_version(mpc_id_value))
|
||||
449
fusionagi/maa/layers/physics_authority.py
Normal file
449
fusionagi/maa/layers/physics_authority.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""Layer 4 — Physics Closure & Simulation Authority.
|
||||
|
||||
Responsible for:
|
||||
- Governing equation selection (structural, thermal, fluid)
|
||||
- Boundary condition enforcement
|
||||
- Safety factor calculation and validation
|
||||
- Failure mode completeness analysis
|
||||
- Simulation binding (simulations are binding, not illustrative)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class PhysicsUnderdefinedError(Exception):
|
||||
"""Failure state: physics not fully defined."""
|
||||
|
||||
def __init__(self, message: str, missing_data: list[str] | None = None):
|
||||
self.missing_data = missing_data or []
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ProofResult(str, Enum):
    """Result of physics validation.

    str-mixin so members compare equal to their string values and serialize
    directly (e.g. into JSON).
    """

    # Validation succeeded and produced a binding proof.
    PROOF = "proof"
    # NOTE(review): member name says UNDEFINED but the value is "underdefined";
    # both refer to the PhysicsUnderdefinedError state — do not change either
    # side independently, as the value may already be persisted elsewhere.
    PHYSICS_UNDEFINED = "physics_underdefined"
    # Physics was defined but the design failed validation.
    VALIDATION_FAILED = "validation_failed"
|
||||
|
||||
|
||||
class PhysicsProof(BaseModel):
    """Binding simulation proof reference."""

    # Unique identifier of this proof.
    proof_id: str = Field(...)
    # Description/reference of the governing equations used.
    governing_equations: str | None = Field(default=None)
    # Reference to the boundary-condition specification.
    boundary_conditions_ref: str | None = Field(default=None)
    # Worst-case (minimum) safety factor established by the validation.
    safety_factor: float | None = Field(default=None)
    # Failure modes the validation covered.
    failure_modes_covered: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Free-text validation state; defaults to "validated".
    validation_status: str = Field(default="validated")
    # Non-fatal issues surfaced during validation.
    warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PhysicsAuthorityInterface(ABC):
    """
    Abstract interface for physics validation.

    Governing equation selection, boundary condition enforcement, safety factor declaration,
    failure-mode completeness. Simulations are binding, not illustrative.
    """

    @abstractmethod
    def validate_physics(
        self,
        design_ref: str,
        load_cases: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> PhysicsProof | None:
        """
        Validate physics for design; return Proof or None (PhysicsUnderdefined).
        Raises PhysicsUnderdefinedError if required data missing.

        Args:
            design_ref: Reference to the design under validation.
            load_cases: Load cases to validate against.
            **kwargs: Implementation-specific inputs (material, dimensions, ...).
        """
        ...
|
||||
|
||||
|
||||
# Common material properties database (simplified)
# Units (per key suffix): strengths in MPa, moduli in GPa, density in kg/m^3,
# thermal expansion per °C, service temperature in °C.
# NOTE(review): values look like nominal handbook figures — confirm against a
# certified source before relying on them for load-bearing validation.
MATERIAL_PROPERTIES: dict[str, dict[str, float]] = {
    "aluminum_6061": {
        "yield_strength_mpa": 276,
        "ultimate_strength_mpa": 310,
        "elastic_modulus_gpa": 68.9,
        "density_kg_m3": 2700,
        "poisson_ratio": 0.33,
        "thermal_expansion_per_c": 23.6e-6,
        "max_service_temp_c": 150,
    },
    "steel_4140": {
        "yield_strength_mpa": 655,
        "ultimate_strength_mpa": 1020,
        "elastic_modulus_gpa": 205,
        "density_kg_m3": 7850,
        "poisson_ratio": 0.29,
        "thermal_expansion_per_c": 12.3e-6,
        "max_service_temp_c": 400,
    },
    "titanium_ti6al4v": {
        "yield_strength_mpa": 880,
        "ultimate_strength_mpa": 950,
        "elastic_modulus_gpa": 113.8,
        "density_kg_m3": 4430,
        "poisson_ratio": 0.34,
        "thermal_expansion_per_c": 8.6e-6,
        "max_service_temp_c": 350,
    },
    "pla_plastic": {
        "yield_strength_mpa": 60,
        "ultimate_strength_mpa": 65,
        "elastic_modulus_gpa": 3.5,
        "density_kg_m3": 1240,
        "poisson_ratio": 0.36,
        "thermal_expansion_per_c": 68e-6,
        "max_service_temp_c": 55,
    },
    "abs_plastic": {
        "yield_strength_mpa": 40,
        "ultimate_strength_mpa": 44,
        "elastic_modulus_gpa": 2.3,
        "density_kg_m3": 1050,
        "poisson_ratio": 0.35,
        "thermal_expansion_per_c": 90e-6,
        "max_service_temp_c": 85,
    },
}
|
||||
|
||||
# Standard failure modes to check
# Baseline coverage list; PhysicsAuthority extends it with custom_failure_modes.
STANDARD_FAILURE_MODES = [
    "yield_failure",
    "ultimate_failure",
    "buckling",
    "fatigue",
    "creep",
    "thermal_distortion",
    "vibration_resonance",
]
|
||||
|
||||
|
||||
@dataclass
class LoadCaseResult:
    """Result of validating a single load case."""

    # Identifier of the load case that was evaluated.
    load_case_id: str
    # Peak stress computed for this load case, in MPa.
    max_stress_mpa: float
    # Achieved safety factor for this load case.
    safety_factor: float
    # Whether the load case met the required safety factor.
    passed: bool
    # Name of the governing failure mode when failed; None when passed.
    failure_mode: str | None = None
    # Optional extra diagnostics from the evaluation.
    details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PhysicsAuthority(PhysicsAuthorityInterface):
    """
    Physics validation authority with actual validation logic.

    Features:
    - Material property validation
    - Load case analysis
    - Safety factor calculation
    - Failure mode coverage analysis
    - Governing equation selection based on load types
    """

    def __init__(
        self,
        required_safety_factor: float = 2.0,
        material_db: dict[str, dict[str, float]] | None = None,
        custom_failure_modes: list[str] | None = None,
    ):
        """
        Initialize the PhysicsAuthority.

        Args:
            required_safety_factor: Minimum required safety factor (default 2.0).
            material_db: Custom material properties database.
            custom_failure_modes: Additional failure modes to check.
        """
        self._required_sf = required_safety_factor
        # Shallow-copy the database so add_material() never mutates the shared
        # module-level MATERIAL_PROPERTIES (or the caller's dict) in place.
        self._materials = dict(material_db or MATERIAL_PROPERTIES)
        # Copy so extending one instance's modes leaves the module default intact.
        self._failure_modes = list(STANDARD_FAILURE_MODES)
        if custom_failure_modes:
            self._failure_modes.extend(custom_failure_modes)

    @staticmethod
    def _material_key(name: str) -> str:
        """Normalize a material name to its database key ("6061 T6" -> "6061_t6")."""
        return name.lower().replace(" ", "_")

    def validate_physics(
        self,
        design_ref: str,
        load_cases: list[dict[str, Any]] | None = None,
        material: str | None = None,
        dimensions: dict[str, float] | None = None,
        boundary_conditions: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> PhysicsProof | None:
        """
        Validate physics for a design.

        Args:
            design_ref: Reference to the design being validated.
            load_cases: List of load cases to validate against.
            material: Material identifier (must be in material database).
            dimensions: Key dimensions for stress calculation.
            boundary_conditions: Boundary condition specification.
            **kwargs: Additional parameters.

        Returns:
            PhysicsProof if validation passes, None if physics underdefined.

        Raises:
            PhysicsUnderdefinedError: If critical data is missing.
        """
        # Fail closed: collect everything missing so the caller sees it all at once.
        missing_data = []

        if not design_ref:
            missing_data.append("design_ref")
        if not material:
            missing_data.append("material")
        if not load_cases:
            missing_data.append("load_cases")

        if missing_data:
            raise PhysicsUnderdefinedError(
                f"Physics validation requires: {', '.join(missing_data)}",
                missing_data=missing_data,
            )

        # Get material properties
        mat_props = self._materials.get(self._material_key(material))
        if not mat_props:
            raise PhysicsUnderdefinedError(
                f"Unknown material: {material}. Available: {list(self._materials.keys())}",
                missing_data=["material_properties"],
            )

        # Validate each load case
        load_case_results: list[LoadCaseResult] = []
        min_safety_factor = float("inf")
        warnings: list[str] = []
        failure_modes_covered: list[str] = []

        for lc in load_cases:
            result = self._validate_load_case(lc, mat_props, dimensions)
            load_case_results.append(result)

            # The proof reports the WORST (lowest) safety factor across load cases.
            if result.safety_factor < min_safety_factor:
                min_safety_factor = result.safety_factor

            if not result.passed:
                warnings.append(
                    f"Load case '{result.load_case_id}' failed: {result.failure_mode}"
                )

            # Track failure modes analyzed
            if result.failure_mode and result.failure_mode not in failure_modes_covered:
                failure_modes_covered.append(result.failure_mode)

        # Determine governing equations based on load types
        governing_equations = self._select_governing_equations(load_cases)

        # Check minimum required failure modes
        required_modes = ["yield_failure", "ultimate_failure"]
        for mode in required_modes:
            if mode not in failure_modes_covered:
                failure_modes_covered.append(mode)  # Basic checks are always done

        # Generate proof ID based on inputs (deterministic for identical inputs).
        proof_hash = hashlib.sha256(
            f"{design_ref}:{material}:{load_cases}".encode()
        ).hexdigest()[:16]
        proof_id = f"proof_{design_ref}_{proof_hash}"

        # Determine validation status; a hard load-case failure outranks a
        # merely insufficient safety factor.
        validation_status = "validated"
        if min_safety_factor < self._required_sf:
            validation_status = "insufficient_safety_factor"
            warnings.append(
                f"Safety factor {min_safety_factor:.2f} < required {self._required_sf}"
            )

        if any(not r.passed for r in load_case_results):
            validation_status = "load_case_failure"

        logger.info(
            "Physics validation completed",
            extra={
                "design_ref": design_ref,
                "material": material,
                "min_sf": min_safety_factor,
                "status": validation_status,
                "num_load_cases": len(load_cases),
            },
        )

        return PhysicsProof(
            proof_id=proof_id,
            governing_equations=governing_equations,
            boundary_conditions_ref=str(boundary_conditions) if boundary_conditions else None,
            safety_factor=min_safety_factor if min_safety_factor != float("inf") else None,
            failure_modes_covered=failure_modes_covered,
            metadata={
                "material": material,
                "material_properties": mat_props,
                "load_case_results": [
                    {
                        "id": r.load_case_id,
                        "max_stress_mpa": r.max_stress_mpa,
                        "sf": r.safety_factor,
                        "passed": r.passed,
                    }
                    for r in load_case_results
                ],
                "required_safety_factor": self._required_sf,
            },
            validation_status=validation_status,
            warnings=warnings,
        )

    def _validate_load_case(
        self,
        load_case: dict[str, Any],
        mat_props: dict[str, float],
        dimensions: dict[str, float] | None,
    ) -> LoadCaseResult:
        """Validate a single load case against material limits.

        Uses a simplified rectangular-section model; N/mm^2 == MPa, so
        force [N] / area [mm^2] gives stress directly in MPa.
        """
        lc_id = load_case.get("id", str(uuid.uuid4())[:8])

        # Extract load parameters (all optional; absent loads contribute zero stress).
        force_n = load_case.get("force_n", 0)
        moment_nm = load_case.get("moment_nm", 0)
        pressure_mpa = load_case.get("pressure_mpa", 0)
        temperature_c = load_case.get("temperature_c", 25)

        # Get material limits (conservative fallbacks when a property is absent).
        yield_strength = mat_props.get("yield_strength_mpa", 100)
        ultimate_strength = mat_props.get("ultimate_strength_mpa", 150)
        max_temp = mat_props.get("max_service_temp_c", 100)

        # Calculate stress (simplified - assumes basic geometry)
        area_mm2 = 100.0  # Default cross-sectional area when no dimensions given
        if dimensions:
            width = dimensions.get("width_mm", 10)
            height = dimensions.get("height_mm", 10)
            area_mm2 = width * height

        # Basic stress calculation
        axial_stress = force_n / area_mm2 if area_mm2 > 0 else 0
        bending_stress = 0
        if moment_nm and dimensions:
            # Simplified bending: M*c/I where c = height/2, I = width*height^3/12
            height = dimensions.get("height_mm", 10)
            width = dimensions.get("width_mm", 10)
            c = height / 2
            i = width * (height ** 3) / 12
            # moment_nm * 1000 converts N*m -> N*mm to keep units in MPa.
            bending_stress = (moment_nm * 1000 * c) / i if i > 0 else 0

        # Combined stress (von Mises simplified for 1D)
        max_stress = abs(axial_stress) + abs(bending_stress) + pressure_mpa

        # Calculate safety factors (unloaded case -> infinite SF)
        yield_sf = yield_strength / max_stress if max_stress > 0 else float("inf")
        ultimate_sf = ultimate_strength / max_stress if max_stress > 0 else float("inf")

        # Check temperature limits
        temp_ok = temperature_c <= max_temp

        # Determine if load case passes
        passed = (
            yield_sf >= self._required_sf
            and ultimate_sf >= self._required_sf
            and temp_ok
        )

        # Report the first failing check, in severity order: yield, ultimate, thermal.
        failure_mode = None
        if yield_sf < self._required_sf:
            failure_mode = "yield_failure"
        elif ultimate_sf < self._required_sf:
            failure_mode = "ultimate_failure"
        elif not temp_ok:
            failure_mode = "thermal_failure"

        return LoadCaseResult(
            load_case_id=lc_id,
            max_stress_mpa=max_stress,
            safety_factor=min(yield_sf, ultimate_sf),
            passed=passed,
            failure_mode=failure_mode,
            details={
                "axial_stress_mpa": axial_stress,
                "bending_stress_mpa": bending_stress,
                "yield_sf": yield_sf,
                "ultimate_sf": ultimate_sf,
                "temperature_ok": temp_ok,
            },
        )

    def _select_governing_equations(self, load_cases: list[dict[str, Any]]) -> str:
        """Select appropriate governing equations based on load types.

        Returns a single "; "-joined string; falls back to linear elasticity
        when no load case declares a recognizable type.
        """
        equations = []

        # Check load types (presence of a nonzero field implies that load family).
        has_static = any(lc.get("type") == "static" or lc.get("force_n") for lc in load_cases)
        has_thermal = any(lc.get("temperature_c") for lc in load_cases)
        has_dynamic = any(lc.get("type") == "dynamic" or lc.get("frequency_hz") for lc in load_cases)
        has_pressure = any(lc.get("pressure_mpa") for lc in load_cases)

        if has_static:
            equations.append("Linear elasticity (Hooke's Law)")
        if has_thermal:
            equations.append("Thermal expansion (α·ΔT)")
        if has_dynamic:
            equations.append("Modal analysis (eigenvalue)")
        if has_pressure:
            equations.append("Pressure vessel (hoop stress)")

        if not equations:
            equations.append("Linear elasticity (default)")

        return "; ".join(equations)

    def get_material_properties(self, material: str) -> dict[str, float] | None:
        """Get properties for a material (name is normalized), or None if unknown."""
        return self._materials.get(self._material_key(material))

    def list_materials(self) -> list[str]:
        """List available materials."""
        return list(self._materials.keys())

    def add_material(self, name: str, properties: dict[str, float]) -> None:
        """Add a custom material to this instance's database (never the module default)."""
        self._materials[self._material_key(name)] = properties
|
||||
|
||||
|
||||
class StubPhysicsAuthority(PhysicsAuthorityInterface):
    """
    Stub implementation for testing.

    Any non-empty design_ref yields a canned proof; an empty design_ref
    raises PhysicsUnderdefinedError (fail closed).

    Note: This is a stub for testing. Use PhysicsAuthority for real validation.
    """

    def validate_physics(
        self,
        design_ref: str,
        load_cases: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> PhysicsProof | None:
        """Return a canned proof for *design_ref*; raise when it is missing."""
        if not design_ref:
            raise PhysicsUnderdefinedError("design_ref required")
        canned = PhysicsProof(
            proof_id=f"stub_proof_{design_ref}",
            failure_modes_covered=["stub"],
            validation_status="stub_validated",
            warnings=["This is a stub validation - not for production use"],
        )
        return canned
|
||||
32
fusionagi/maa/layers/process_authority.py
Normal file
32
fusionagi/maa/layers/process_authority.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Layer 5 — Manufacturing Process Authority."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProcessEligibilityResult(BaseModel):
    """Verdict from ProcessAuthority.process_eligible."""

    # Whether the design may use the requested process.
    eligible: bool = Field(...)
    # Echo of the requested process type ("unknown" when the request omitted it).
    process_type: str = Field(...)
    # Human-readable rejection reason; None when eligible.
    reason: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ProcessAuthority:
    """Evaluates eligibility for additive, subtractive, hybrid."""

    def process_eligible(
        self,
        design_ref: str,
        process_type: str,
        context: dict[str, Any] | None = None,
    ) -> ProcessEligibilityResult:
        """Return an eligibility verdict for *design_ref* / *process_type*."""
        # Guard: both identifiers are mandatory.
        if not (design_ref and process_type):
            return ProcessEligibilityResult(
                eligible=False,
                process_type=process_type or "unknown",
                reason="design_ref and process_type required",
            )
        normalized = process_type.lower()
        if normalized in ("additive", "subtractive", "hybrid"):
            return ProcessEligibilityResult(eligible=True, process_type=normalized, reason=None)
        # Unrecognized family — echo the caller's original spelling back.
        return ProcessEligibilityResult(eligible=False, process_type=process_type, reason="Unknown process_type")
|
||||
63
fusionagi/maa/layers/toolpath_engine.py
Normal file
63
fusionagi/maa/layers/toolpath_engine.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Layer 7 — Toolpath Determinism Engine: toolpath -> geometry -> intent -> requirement."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolpathLineage(BaseModel):
    """Lineage: toolpath traces to geometry, geometry to intent, intent to requirement."""

    # Id of the toolpath artifact this lineage belongs to.
    toolpath_ref: str = Field(...)
    # Geometry the toolpath was derived from.
    geometry_ref: str = Field(...)
    # Engineering intent the geometry realizes.
    intent_ref: str = Field(...)
    # Originating requirement; optional when not yet traced that far.
    requirement_ref: str | None = Field(default=None)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolpathArtifact(BaseModel):
    """Toolpath artifact + lineage (G-code or AM slice)."""

    artifact_id: str = Field(...)
    artifact_type: str = Field(..., description="cnc_gcode or am_slice")
    # Pointer to the stored instruction payload; None when not materialized.
    content_ref: str | None = Field(default=None)
    # Full provenance chain back to the originating requirement.
    lineage: ToolpathLineage = Field(...)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ToolpathEngine:
    """Every toolpath traces to geometry -> intent -> requirement. Generates only after full closure."""

    def __init__(self) -> None:
        # In-memory registry: artifact_id -> ToolpathArtifact.
        self._artifacts: dict[str, ToolpathArtifact] = {}

    def generate(
        self,
        artifact_id: str,
        artifact_type: str,
        geometry_ref: str,
        intent_ref: str,
        requirement_ref: str | None = None,
        content_ref: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> ToolpathArtifact:
        """Generate toolpath artifact with lineage; only after full closure (caller ensures)."""
        record = ToolpathArtifact(
            artifact_id=artifact_id,
            artifact_type=artifact_type,
            content_ref=content_ref,
            lineage=ToolpathLineage(
                toolpath_ref=artifact_id,
                geometry_ref=geometry_ref,
                intent_ref=intent_ref,
                requirement_ref=requirement_ref,
            ),
            metadata=metadata or {},
        )
        self._artifacts[artifact_id] = record
        return record

    def get(self, artifact_id: str) -> ToolpathArtifact | None:
        """Look up a previously generated artifact by id; None when absent."""
        return self._artifacts.get(artifact_id)
|
||||
17
fusionagi/maa/schemas/__init__.py
Normal file
17
fusionagi/maa/schemas/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""MAA schemas: MPC, DLT, intent."""
|
||||
|
||||
from fusionagi.maa.schemas.mpc import ManufacturingProofCertificate, MPCId
|
||||
from fusionagi.maa.schemas.dlt import DLTNode, DLTContract, DLTFamily
|
||||
from fusionagi.maa.schemas.intent import EngineeringIntentGraph, IntentNode, LoadCase, RequirementType
|
||||
|
||||
__all__ = [
|
||||
"ManufacturingProofCertificate",
|
||||
"MPCId",
|
||||
"DLTNode",
|
||||
"DLTContract",
|
||||
"DLTFamily",
|
||||
"EngineeringIntentGraph",
|
||||
"IntentNode",
|
||||
"LoadCase",
|
||||
"RequirementType",
|
||||
]
|
||||
41
fusionagi/maa/schemas/dlt.py
Normal file
41
fusionagi/maa/schemas/dlt.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Deterministic Decision Logic Tree schema: node, contract, families."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DLTFamily(str, Enum):
    """DLT families: intent, geometry, physics, process, machine."""

    INT = "DLT-INT"    # intent
    GEO = "DLT-GEO"    # geometry
    PHY = "DLT-PHY"    # physics
    PROC = "DLT-PROC"  # process
    MACH = "DLT-MACH"  # machine
|
||||
|
||||
|
||||
class DLTNode(BaseModel):
    """Single node in a DLT: deterministic, evidence-backed, fail-closed."""

    node_id: str = Field(..., description="Unique node id within tree")
    family: DLTFamily = Field(...)
    condition: str = Field(..., description="Deterministic condition expression or ref")
    # Pointer to the artifact backing this decision; None when not yet attached.
    evidence_ref: str | None = Field(default=None)
    fail_closed: bool = Field(default=True, description="On failure, reject (fail closed)")
    children: list[str] = Field(default_factory=list, description="Child node ids")
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DLTContract(BaseModel):
    """Immutable, versioned DLT contract.

    NOTE(review): the docstring says "immutable" but ``model_config`` sets
    ``frozen: False``, so instances are mutable at runtime — confirm intent
    (compare ManufacturingProofCertificate, which is frozen).
    """

    contract_id: str = Field(..., description="Contract identifier")
    version: int = Field(default=1)
    family: DLTFamily = Field(...)
    root_id: str = Field(..., description="Root node id")
    # node_id -> DLTNode; root_id is expected to be a key of this map.
    nodes: dict[str, DLTNode] = Field(default_factory=dict)
    metadata: dict[str, Any] = Field(default_factory=dict)

    model_config = {"frozen": False}
|
||||
38
fusionagi/maa/schemas/intent.py
Normal file
38
fusionagi/maa/schemas/intent.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Intent formalization schema: intent graph, requirement types, load cases."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RequirementType(str, Enum):
    """Closed set of requirement categories for intent nodes."""

    DIMENSIONAL = "dimensional"
    LOAD = "load"
    ENVIRONMENTAL = "environmental"
    PROCESS = "process"
    OTHER = "other"
|
||||
|
||||
|
||||
class IntentNode(BaseModel):
    """One formalized requirement in the engineering intent graph."""

    node_id: str = Field(..., description="Unique intent node id")
    requirement_type: RequirementType = Field(...)
    description: str = Field(...)
    # Reference to the bounds/spec artifact for this requirement, if any.
    bounds_ref: str | None = Field(default=None)
    # Ids of LoadCase entries (in the enclosing graph) that exercise this node.
    load_case_ids: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class LoadCase(BaseModel):
    """Named load case referenced by intent nodes via load_case_ids."""

    load_case_id: str = Field(...)
    description: str = Field(...)
    # Pointer to the boundary-condition artifact; None when not specified.
    boundary_conditions_ref: str | None = Field(default=None)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EngineeringIntentGraph(BaseModel):
    """Container for formalized intent: nodes, load cases, environmental bounds."""

    intent_id: str = Field(...)
    nodes: list[IntentNode] = Field(default_factory=list)
    load_cases: list[LoadCase] = Field(default_factory=list)
    # Global environmental envelope (free-form key/value bounds).
    environmental_bounds: dict[str, Any] = Field(default_factory=dict)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
79
fusionagi/maa/schemas/mpc.py
Normal file
79
fusionagi/maa/schemas/mpc.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Manufacturing Proof Certificate schema: decision lineage, simulation proof, process, machine, risk."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MPCId(BaseModel):
    """Immutable MPC identifier: content-addressed or versioned."""

    value: str = Field(..., description="Unique MPC id (e.g. hash or versioned id)")
    version: int = Field(default=1, description="Certificate version")

    def __str__(self) -> str:
        # Canonical display form: "<value>@v<version>".
        return f"{self.value}@v{self.version}"
|
||||
|
||||
|
||||
class DecisionLineageEntry(BaseModel):
    """Single entry in decision lineage."""

    node_id: str = Field(..., description="DLT or decision node id")
    family: str = Field(..., description="DLT family: INT, GEO, PHY, PROC, MACH")
    evidence_ref: str | None = Field(default=None, description="Reference to evidence artifact")
    outcome: str = Field(..., description="Outcome: pass, fail_closed, etc.")
|
||||
|
||||
|
||||
class SimulationProof(BaseModel):
    """Binding simulation proof reference (mirrors the PhysicsProof fields)."""

    proof_id: str = Field(..., description="Proof artifact id")
    governing_equations: str | None = Field(default=None)
    boundary_conditions_ref: str | None = Field(default=None)
    # Worst-case safety factor across load cases; None when not computed.
    safety_factor: float | None = Field(default=None)
    failure_modes_covered: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ProcessJustification(BaseModel):
    """Process eligibility justification."""

    process_type: str = Field(..., description="additive, subtractive, hybrid")
    eligible: bool = Field(...)
    checks_ref: str | None = Field(default=None)
    # Individual check outcomes; None means the check was not evaluated.
    tool_access: bool | None = None
    thermal_distortion: bool | None = None
    overhangs: bool | None = None
    datum_survivability: bool | None = None
|
||||
|
||||
|
||||
class MachineDeclaration(BaseModel):
    """Machine binding declaration."""

    machine_id: str = Field(..., description="Bound machine id")
    # References to the machine's capability profile, limits, and deviation model.
    profile_ref: str | None = Field(default=None)
    limits_ref: str | None = Field(default=None)
    deviation_model_ref: str | None = Field(default=None)
|
||||
|
||||
|
||||
class RiskRegisterEntry(BaseModel):
    """Single risk register entry."""

    risk_id: str = Field(...)
    description: str = Field(...)
    severity: str = Field(..., description="e.g. low, medium, high")
    # Pointer to the mitigation plan/artifact; None when unmitigated.
    mitigation_ref: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ManufacturingProofCertificate(BaseModel):
    """Manufacturing Proof Certificate: immutable, versioned; required for manufacturing execution."""

    mpc_id: MPCId = Field(..., description="Certificate identifier")
    decision_lineage: list[DecisionLineageEntry] = Field(default_factory=list)
    simulation_proof: SimulationProof | None = Field(default=None)
    process_justification: ProcessJustification | None = Field(default=None)
    machine_declaration: MachineDeclaration | None = Field(default=None)
    risk_register: list[RiskRegisterEntry] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)

    # frozen=True enforces the "immutable" contract: instances reject mutation.
    model_config = {"frozen": True}
|
||||
393
fusionagi/maa/tools.py
Normal file
393
fusionagi/maa/tools.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Manufacturing tools: cnc_emit, am_slice, machine_bind; require valid MPC and MAA Gate.
|
||||
|
||||
These tools generate actual manufacturing instructions:
|
||||
- cnc_emit: Generates G-code for CNC machining operations
|
||||
- am_slice: Generates slice data for additive manufacturing
|
||||
- machine_bind: Binds a design to a specific machine with capability validation
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now_iso
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class GCodeOutput(BaseModel):
    """G-code output from CNC emission."""

    # MPC the program was emitted under.
    mpc_id: str
    # Target CNC machine.
    machine_id: str
    # Toolpath the program was derived from.
    toolpath_ref: str
    # Full program as one newline-joined string.
    gcode: str
    # Extras such as line_count / estimated_runtime_minutes / tool_changes.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # UTC ISO timestamp captured at model creation.
    generated_at: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
class SliceOutput(BaseModel):
    """Slice output from AM slicing."""

    mpc_id: str
    machine_id: str
    # Geometry/slice data the output was derived from.
    slice_ref: str
    # Total number of layers (may exceed len(slice_data["layers"]) — only a sample is embedded).
    layer_count: int
    # Structured slice payload: settings, sampled layers, statistics.
    slice_data: dict[str, Any]
    metadata: dict[str, Any] = Field(default_factory=dict)
    # UTC ISO timestamp captured at model creation.
    generated_at: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
class MachineBindOutput(BaseModel):
    """Machine binding output."""

    mpc_id: str
    machine_id: str
    # Unique id for this binding event (embeds mpc_id, machine_id and a random suffix).
    binding_id: str
    # Binding state, e.g. "bound".
    status: str
    # Whether the capability checks all passed.
    capabilities_validated: bool
    metadata: dict[str, Any] = Field(default_factory=dict)
    # UTC ISO timestamp captured at model creation.
    bound_at: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
def _generate_gcode_header(machine_id: str, mpc_id: str) -> list[str]:
    """Generate standard G-code header.

    Emits provenance comments (MPC, machine, timestamp) followed by the
    modal setup: absolute positioning, metric units, XY plane.
    """
    return [
        # Plain string: no interpolation needed (was an f-string with no fields).
        "; G-code generated by FusionAGI MAA",
        f"; MPC: {mpc_id}",
        f"; Machine: {machine_id}",
        f"; Generated: {utc_now_iso()}",
        "",
        "G90 ; Absolute positioning",
        "G21 ; Metric units (mm)",
        "G17 ; XY plane selection",
        "",
    ]
|
||||
|
||||
|
||||
def _generate_gcode_footer() -> list[str]:
|
||||
"""Generate standard G-code footer."""
|
||||
return [
|
||||
"",
|
||||
"; End of program",
|
||||
"M5 ; Spindle stop",
|
||||
"G28 ; Return to home",
|
||||
"M30 ; Program end",
|
||||
]
|
||||
|
||||
|
||||
def _generate_toolpath_gcode(toolpath_ref: str) -> list[str]:
|
||||
"""
|
||||
Generate G-code from a toolpath reference.
|
||||
|
||||
In a real implementation, this would:
|
||||
1. Load the toolpath data from storage
|
||||
2. Convert toolpath segments to G-code commands
|
||||
3. Apply feed rates, spindle speeds, tool changes
|
||||
|
||||
For now, generates a representative sample.
|
||||
"""
|
||||
# Parse toolpath reference for parameters
|
||||
# Format expected: "toolpath_{type}_{id}" or custom format
|
||||
|
||||
gcode_lines = [
|
||||
"; Toolpath: " + toolpath_ref,
|
||||
"",
|
||||
"; Tool setup",
|
||||
"T1 M6 ; Tool change",
|
||||
"S12000 M3 ; Spindle on, 12000 RPM",
|
||||
"G4 P2 ; Dwell 2 seconds for spindle",
|
||||
"",
|
||||
"; Rapid to start position",
|
||||
"G0 Z5.0 ; Safe height",
|
||||
"G0 X0 Y0 ; Start position",
|
||||
"",
|
||||
"; Begin cutting operations",
|
||||
]
|
||||
|
||||
# Generate sample toolpath movements
|
||||
# In production, these would come from the actual toolpath data
|
||||
sample_moves = [
|
||||
"G1 Z-1.0 F100 ; Plunge",
|
||||
"G1 X50.0 F500 ; Cut along X",
|
||||
"G1 Y50.0 ; Cut along Y",
|
||||
"G1 X0 ; Return X",
|
||||
"G1 Y0 ; Return Y",
|
||||
"G0 Z5.0 ; Retract",
|
||||
]
|
||||
|
||||
gcode_lines.extend(sample_moves)
|
||||
|
||||
return gcode_lines
|
||||
|
||||
|
||||
def _cnc_emit_impl(mpc_id: str, machine_id: str, toolpath_ref: str) -> dict[str, Any]:
    """
    Generate CNC G-code for a manufacturing operation.

    Args:
        mpc_id: Manufacturing Proof Certificate ID.
        machine_id: Target CNC machine identifier.
        toolpath_ref: Reference to toolpath data.

    Returns:
        Dictionary with G-code and metadata.
    """
    logger.info(
        "CNC emit started",
        extra={"mpc_id": mpc_id, "machine_id": machine_id, "toolpath_ref": toolpath_ref},
    )

    # Assemble the full program: header, toolpath-derived body, footer.
    lines = [
        *_generate_gcode_header(machine_id, mpc_id),
        *_generate_toolpath_gcode(toolpath_ref),
        *_generate_gcode_footer(),
    ]

    result = GCodeOutput(
        mpc_id=mpc_id,
        machine_id=machine_id,
        toolpath_ref=toolpath_ref,
        gcode="\n".join(lines),
        metadata={
            "line_count": len(lines),
            "estimated_runtime_minutes": 5.0,  # Would be calculated from toolpath
            "tool_changes": 1,
        },
    )

    logger.info(
        "CNC emit completed",
        extra={"mpc_id": mpc_id, "line_count": len(lines)},
    )

    return result.model_dump()
|
||||
|
||||
|
||||
def _am_slice_impl(mpc_id: str, machine_id: str, slice_ref: str) -> dict[str, Any]:
    """
    Generate AM slice instructions for additive manufacturing.

    Args:
        mpc_id: Manufacturing Proof Certificate ID.
        machine_id: Target AM machine identifier.
        slice_ref: Reference to slice/geometry data.

    Returns:
        Dictionary with slice data and metadata (a SliceOutput model dump).
    """
    logger.info(
        "AM slice started",
        extra={"mpc_id": mpc_id, "machine_id": machine_id, "slice_ref": slice_ref},
    )

    # In production, this would:
    # 1. Load the geometry from slice_ref
    # 2. Apply slicing algorithm with machine-specific parameters
    # 3. Generate layer-by-layer toolpaths
    # 4. Calculate support structures if needed

    # Generate representative slice data
    layer_height_mm = 0.2
    num_layers = 100  # Would be calculated from geometry height

    # Placeholder slicing profile; values below are representative defaults,
    # not derived from the actual geometry referenced by slice_ref.
    slice_data = {
        "format_version": "1.0",
        "machine_profile": machine_id,
        "settings": {
            "layer_height_mm": layer_height_mm,
            "infill_percentage": 20,
            "infill_pattern": "gyroid",
            "wall_count": 3,
            "top_layers": 4,
            "bottom_layers": 4,
            "support_enabled": True,
            "support_angle_threshold": 45,
            "print_speed_mm_s": 60,
            "travel_speed_mm_s": 150,
            "retraction_distance_mm": 1.0,
            "retraction_speed_mm_s": 45,
        },
        "layers": [
            {
                "index": i,
                "z_mm": i * layer_height_mm,
                "perimeters": 3,
                # Solid top/bottom shells: no infill in the outermost layers.
                "infill_present": i > 3 and i < num_layers - 3,
                "support_present": i < 20,
            }
            for i in range(min(num_layers, 10))  # Sample first 10 layers
        ],
        "statistics": {
            "total_layers": num_layers,
            "estimated_material_g": 45.2,
            "estimated_time_minutes": 120,
            "bounding_box_mm": {"x": 50, "y": 50, "z": num_layers * layer_height_mm},
        },
    }

    output = SliceOutput(
        mpc_id=mpc_id,
        machine_id=machine_id,
        slice_ref=slice_ref,
        layer_count=num_layers,
        slice_data=slice_data,
        metadata={
            # Surface the headline statistics in metadata for quick access.
            "estimated_material_g": slice_data["statistics"]["estimated_material_g"],
            "estimated_time_minutes": slice_data["statistics"]["estimated_time_minutes"],
        },
    )

    logger.info(
        "AM slice completed",
        extra={"mpc_id": mpc_id, "layer_count": num_layers},
    )

    return output.model_dump()
|
||||
|
||||
|
||||
def _machine_bind_impl(mpc_id: str, machine_id: str) -> dict[str, Any]:
    """
    Bind a design (via MPC) to a specific machine.

    Args:
        mpc_id: Manufacturing Proof Certificate ID.
        machine_id: Target machine identifier.

    Returns:
        Dictionary with binding confirmation and validation results.
    """
    logger.info(
        "Machine bind started",
        extra={"mpc_id": mpc_id, "machine_id": machine_id},
    )

    # In production, this would:
    # 1. Load the MPC to get design requirements
    # 2. Load the machine profile
    # 3. Validate machine capabilities against design requirements
    # 4. Check envelope, tolerances, material compatibility
    # 5. Record the binding in the system

    binding_id = f"binding_{mpc_id}_{machine_id}_{uuid.uuid4().hex[:8]}"

    # Simulated capability validation: every check passes unconditionally.
    capabilities_validated = True
    check_details = {
        "envelope_check": "Design fits within machine envelope",
        "tolerance_check": "Machine can achieve required tolerances",
        "material_check": "Machine supports specified material",
        "feature_check": "Machine can produce required features",
    }
    validation_results = {
        name: {"status": "pass", "details": details}
        for name, details in check_details.items()
    }

    result = MachineBindOutput(
        mpc_id=mpc_id,
        machine_id=machine_id,
        binding_id=binding_id,
        status="bound",
        capabilities_validated=capabilities_validated,
        metadata={
            "validation_results": validation_results,
        },
    )

    logger.info(
        "Machine bind completed",
        extra={"binding_id": binding_id, "validated": capabilities_validated},
    )

    return result.model_dump()
|
||||
|
||||
|
||||
def cnc_emit_tool() -> ToolDef:
    """
    CNC G-code emission tool.

    Generates G-code for CNC machining operations based on:
    - MPC: Manufacturing Proof Certificate with validated design
    - Machine: Target CNC machine configuration
    - Toolpath: Reference to toolpath data

    Returns structured output with G-code and metadata.
    """
    return ToolDef(
        name="cnc_emit",
        description="Emit CNC G-code for bound machine; requires valid MPC",
        fn=_cnc_emit_impl,
        # JSON Schema for the tool's arguments; all three are required.
        parameters_schema={
            "type": "object",
            "properties": {
                "mpc_id": {"type": "string", "description": "Manufacturing Proof Certificate ID"},
                "machine_id": {"type": "string", "description": "Target CNC machine ID"},
                "toolpath_ref": {"type": "string", "description": "Reference to toolpath data"},
            },
            "required": ["mpc_id", "machine_id", "toolpath_ref"],
        },
        # Gated behind the manufacturing permission scope.
        permission_scope=["manufacturing"],
        timeout_seconds=60.0,
        # Marks this as a manufacturing tool (subject to the MAA gate).
        manufacturing=True,
    )
|
||||
|
||||
|
||||
def am_slice_tool() -> ToolDef:
    """
    AM slice instruction tool.

    Generates slice data for additive manufacturing operations:
    - Layer-by-layer toolpaths
    - Infill patterns
    - Support structure calculations
    - Machine-specific settings

    Returns structured output with slice data and metadata.
    """
    return ToolDef(
        name="am_slice",
        description="Emit AM slice instructions; requires valid MPC",
        fn=_am_slice_impl,
        # JSON Schema for the tool's arguments; all three are required.
        parameters_schema={
            "type": "object",
            "properties": {
                "mpc_id": {"type": "string", "description": "Manufacturing Proof Certificate ID"},
                "machine_id": {"type": "string", "description": "Target AM machine ID"},
                "slice_ref": {"type": "string", "description": "Reference to geometry/slice data"},
            },
            "required": ["mpc_id", "machine_id", "slice_ref"],
        },
        # Gated behind the manufacturing permission scope.
        permission_scope=["manufacturing"],
        timeout_seconds=60.0,
        # Marks this as a manufacturing tool (subject to the MAA gate).
        manufacturing=True,
    )
|
||||
|
||||
|
||||
def machine_bind_tool() -> ToolDef:
    """
    Build the machine-binding declaration tool definition.

    Binding a design (via its MPC) to a specific machine validates the
    machine's capabilities against the design requirements (envelope,
    tolerances, material compatibility) and records the binding for the
    audit trail.

    Returns:
        ToolDef wired to ``_machine_bind_impl``; mpc_id and machine_id
        are both required string parameters.
    """
    schema = {
        "type": "object",
        "properties": {
            "mpc_id": {"type": "string", "description": "Manufacturing Proof Certificate ID"},
            "machine_id": {"type": "string", "description": "Target machine ID"},
        },
        "required": ["mpc_id", "machine_id"],
    }
    return ToolDef(
        name="machine_bind",
        description="Bind design to machine; requires valid MPC",
        fn=_machine_bind_impl,
        parameters_schema=schema,
        permission_scope=["manufacturing"],
        timeout_seconds=10.0,
        manufacturing=True,
    )
|
||||
43
fusionagi/maa/versioning.py
Normal file
43
fusionagi/maa/versioning.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Logic tree and MPC versioning; changes require re-certification; historical preserved."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class VersionStore:
|
||||
"""Immutable versioned store: logic trees and MPCs; historical read-only."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._versions: dict[str, list[tuple[int, Any]]] = {} # id -> [(version, payload), ...]
|
||||
|
||||
def put(self, id_key: str, version: int, payload: Any) -> None:
|
||||
"""Store a new version; versions must be monotonic."""
|
||||
if id_key not in self._versions:
|
||||
self._versions[id_key] = []
|
||||
existing = self._versions[id_key]
|
||||
if existing and existing[-1][0] >= version:
|
||||
raise ValueError(f"Version must be greater than {existing[-1][0]}")
|
||||
existing.append((version, payload))
|
||||
|
||||
def get(self, id_key: str, version: int | None = None) -> Any | None:
|
||||
"""Return payload for id; optional version (latest if omitted)."""
|
||||
if id_key not in self._versions:
|
||||
return None
|
||||
versions = self._versions[id_key]
|
||||
if not versions:
|
||||
return None
|
||||
if version is None:
|
||||
return versions[-1][1]
|
||||
for v, payload in versions:
|
||||
if v == version:
|
||||
return payload
|
||||
return None
|
||||
|
||||
def get_latest_version(self, id_key: str) -> int | None:
|
||||
"""Return latest version number for id or None."""
|
||||
if id_key not in self._versions or not self._versions[id_key]:
|
||||
return None
|
||||
return self._versions[id_key][-1][0]
|
||||
|
||||
def history(self, id_key: str) -> list[tuple[int, Any]]:
|
||||
"""Return full version history (read-only)."""
|
||||
return list(self._versions.get(id_key, []))
|
||||
43
fusionagi/memory/__init__.py
Normal file
43
fusionagi/memory/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Memory system: working, episodic, reflective, semantic, procedural, trust, consolidation."""
|
||||
|
||||
from fusionagi.memory.working import WorkingMemory
|
||||
from fusionagi.memory.episodic import EpisodicMemory
|
||||
from fusionagi.memory.reflective import ReflectiveMemory
|
||||
from fusionagi.memory.semantic import SemanticMemory
|
||||
from fusionagi.memory.procedural import ProceduralMemory
|
||||
from fusionagi.memory.trust import TrustMemory
|
||||
from fusionagi.memory.consolidation import ConsolidationJob
|
||||
from fusionagi.memory.service import MemoryService, VectorMemory
|
||||
from fusionagi.memory.vector_pgvector import create_vector_memory_pgvector, VectorMemoryPgvector
|
||||
from fusionagi.memory.postgres_backend import (
|
||||
MemoryBackend,
|
||||
InMemoryBackend,
|
||||
create_postgres_backend,
|
||||
)
|
||||
from fusionagi.memory.semantic_graph import SemanticGraphMemory
|
||||
from fusionagi.memory.sharding import Shard, shard_context
|
||||
from fusionagi.memory.scratchpad import LatentScratchpad, ThoughtState
|
||||
|
||||
# Public API of fusionagi.memory. Every name here must be imported above.
# "ThoughtVersioning" and "ThoughtStateSnapshot" were removed: they are not
# imported in this module (fusionagi.memory.scratchpad defines only
# LatentScratchpad and ThoughtState), so listing them made
# `from fusionagi.memory import *` raise AttributeError.
__all__ = [
    "WorkingMemory",
    "EpisodicMemory",
    "ReflectiveMemory",
    "SemanticMemory",
    "ProceduralMemory",
    "TrustMemory",
    "ConsolidationJob",
    "MemoryService",
    "VectorMemory",
    "create_vector_memory_pgvector",
    "VectorMemoryPgvector",
    "MemoryBackend",
    "InMemoryBackend",
    "create_postgres_backend",
    "SemanticGraphMemory",
    "Shard",
    "shard_context",
    "LatentScratchpad",
    "ThoughtState",
]
|
||||
87
fusionagi/memory/consolidation.py
Normal file
87
fusionagi/memory/consolidation.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Consolidation: distillation of experiences into knowledge; write/forget rules for AGI."""
|
||||
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class EpisodicLike(Protocol):
    """Structural interface for an episodic-memory source used by ConsolidationJob.

    Both methods return recent entries as plain dicts (see EpisodicMemory for
    the concrete shape).
    """

    def get_lessons(self, limit: int) -> list[dict[str, Any]]: ...
    def get_recent(self, limit: int) -> list[dict[str, Any]]: ...
|
||||
|
||||
|
||||
class ReflectiveLike(Protocol):
    """Structural interface for a reflective-memory source of lessons learned."""

    def get_lessons(self, limit: int) -> list[dict[str, Any]]: ...
|
||||
|
||||
|
||||
class SemanticLike(Protocol):
    """Structural interface for a semantic-memory sink that accepts distilled facts."""

    def add_fact(self, fact_id: str, statement: str, source: str, domain: str, metadata: dict | None) -> None: ...
|
||||
|
||||
|
||||
class ConsolidationJob:
    """
    Periodic distillation: take recent episodic/reflective lessons and
    write summarized facts into semantic memory. Write/forget rules
    are applied by the distiller callback.
    """

    def __init__(
        self,
        episodic: EpisodicLike | None = None,
        reflective: ReflectiveLike | None = None,
        semantic: SemanticLike | None = None,
        distiller: Callable[[list[dict[str, Any]]], list[dict[str, Any]]] | None = None,
    ) -> None:
        """
        Args:
            episodic: Optional source of recent episodic entries.
            reflective: Optional source of lessons learned.
            semantic: Optional sink for distilled facts.
            distiller: Callback turning lessons into fact dicts; defaults to
                _default_distiller (one summary fact per lesson).
        """
        self._episodic = episodic
        self._reflective = reflective
        self._semantic = semantic
        self._distiller = distiller or _default_distiller

    def run(self, episodic_limit: int = 100, reflective_limit: int = 50) -> int:
        """
        Run consolidation: gather recent lessons, distill, write to semantic.

        Sources and the semantic sink are best-effort: a failing source is
        logged and skipped rather than aborting the run (previously these
        failures were swallowed silently with `pass`).

        Args:
            episodic_limit: Maximum episodic entries to gather.
            reflective_limit: Maximum reflective lessons to gather.

        Returns:
            Number of facts written to semantic memory.
        """
        lessons: list[dict[str, Any]] = []
        if self._episodic:
            try:
                # hasattr guard: some episodic-like sources may lack get_recent.
                lessons.extend(self._episodic.get_recent(episodic_limit) if hasattr(self._episodic, "get_recent") else [])
            except Exception:
                # Best-effort: a broken source must not abort consolidation,
                # but the failure should not be invisible either.
                logger.exception("Consolidation: episodic source failed")
        if self._reflective:
            try:
                lessons.extend(self._reflective.get_lessons(reflective_limit))
            except Exception:
                logger.exception("Consolidation: reflective source failed")
        if not lessons:
            return 0
        facts = self._distiller(lessons)
        written = 0
        if self._semantic and facts:
            # Cap writes per run to bound semantic-memory growth.
            for i, f in enumerate(facts[:50]):
                fact_id = f.get("fact_id", f"consolidated_{i}")
                statement = f.get("statement", str(f))
                source = f.get("source", "consolidation")
                domain = f.get("domain", "general")
                try:
                    self._semantic.add_fact(fact_id, statement, source=source, domain=domain, metadata=f)
                    written += 1
                except Exception:
                    logger.exception("Consolidation: failed to add fact", extra={"fact_id": fact_id})
        logger.info("Consolidation run", extra={"lessons": len(lessons), "facts_written": written})
        return written
|
||||
|
||||
|
||||
def _default_distiller(lessons: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Default: turn each lesson into one fact (summary)."""
|
||||
out = []
|
||||
for i, le in enumerate(lessons[-100:]):
|
||||
outcome = le.get("outcome", le.get("result", ""))
|
||||
task_id = le.get("task_id", "")
|
||||
out.append({
|
||||
"fact_id": f"cons_{task_id}_{i}",
|
||||
"statement": f"Task {task_id} outcome: {outcome}",
|
||||
"source": "consolidation",
|
||||
"domain": "general",
|
||||
})
|
||||
return out
|
||||
226
fusionagi/memory/episodic.py
Normal file
226
fusionagi/memory/episodic.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Episodic memory: append-only log of task/step outcomes; query by task_id or time range.
|
||||
|
||||
Episodic memory stores historical records of agent actions and outcomes:
|
||||
- Task execution traces
|
||||
- Step outcomes (success/failure)
|
||||
- Tool invocation results
|
||||
- Decision points and their outcomes
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi._time import utc_now_iso
|
||||
|
||||
|
||||
class EpisodicMemory:
    """
    Append-only log of task and step outcomes.

    Features:
    - Time-stamped event logging
    - Query by task ID
    - Query by time range
    - Query by event type
    - Statistical summaries
    - Memory size limits with optional archival
    """

    def __init__(self, max_entries: int = 10000) -> None:
        """
        Initialize episodic memory.

        Args:
            max_entries: Maximum entries before oldest are archived/removed.
        """
        self._entries: list[dict[str, Any]] = []
        self._by_task: dict[str, list[int]] = {}  # task_id -> indices into _entries
        self._by_type: dict[str, list[int]] = {}  # event_type -> indices
        self._max_entries = max_entries
        self._archived_count = 0  # total entries dropped by _archive_oldest

    def append(
        self,
        task_id: str,
        event: dict[str, Any],
        event_type: str | None = None,
    ) -> int:
        """
        Append an episodic entry.

        Args:
            task_id: Task identifier this event belongs to.
            event: Event data dictionary.
            event_type: Optional event type for categorization (e.g., "step_done", "tool_call").

        Returns:
            Index of the appended entry.
        """
        # Enforce size limits before inserting, dropping ~10% of the oldest
        # entries at a time. max(1, ...) fixes the max_entries < 10 case,
        # where max_entries // 10 == 0 meant nothing was ever archived and
        # the cap was never enforced.
        if len(self._entries) >= self._max_entries:
            self._archive_oldest(max(1, self._max_entries // 10))

        # Add metadata: "timestamp" is monotonic (in-process ordering/range
        # queries only — not comparable across processes); "datetime" is the
        # wall-clock UTC ISO string for display.
        entry = {
            **event,
            "task_id": task_id,
            "timestamp": event.get("timestamp", time.monotonic()),
            "datetime": event.get("datetime", utc_now_iso()),
        }

        if event_type:
            entry["event_type"] = event_type

        idx = len(self._entries)
        self._entries.append(entry)

        # Index by task
        self._by_task.setdefault(task_id, []).append(idx)

        # Index by type if provided (explicitly or embedded in the event)
        etype = event_type or event.get("type") or event.get("event_type")
        if etype:
            self._by_type.setdefault(etype, []).append(idx)

        return idx

    def get_by_task(self, task_id: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return all entries for a task (copies), optionally limited to the most recent."""
        indices = self._by_task.get(task_id, [])
        if limit:
            indices = indices[-limit:]
        return [self._entries[i].copy() for i in indices]

    def get_by_type(self, event_type: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return entries of a specific type (copies), optionally limited to the most recent."""
        indices = self._by_type.get(event_type, [])
        if limit:
            indices = indices[-limit:]
        return [self._entries[i].copy() for i in indices]

    def get_recent(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return most recent entries (copies), oldest first."""
        return [e.copy() for e in self._entries[-limit:]]

    def get_by_time_range(
        self,
        start_timestamp: float | None = None,
        end_timestamp: float | None = None,
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """
        Return entries within a time range (using monotonic timestamps).

        Args:
            start_timestamp: Start of range (inclusive).
            end_timestamp: End of range (inclusive).
            limit: Maximum entries to return.
        """
        results = []
        for entry in self._entries:
            ts = entry.get("timestamp", 0)
            # "is not None" so a boundary of 0 / 0.0 is honored; the previous
            # truthiness test silently disabled a zero bound.
            if start_timestamp is not None and ts < start_timestamp:
                continue
            if end_timestamp is not None and ts > end_timestamp:
                continue
            results.append(entry.copy())
            if limit and len(results) >= limit:
                break
        return results

    def query(
        self,
        filter_fn: Callable[[dict[str, Any]], bool],
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """
        Query entries using a custom filter function.

        Args:
            filter_fn: Function that returns True for entries to include.
            limit: Maximum entries to return.
        """
        results = []
        for entry in self._entries:
            if filter_fn(entry):
                results.append(entry.copy())
            if limit and len(results) >= limit:
                break
        return results

    def get_task_summary(self, task_id: str) -> dict[str, Any]:
        """
        Get a summary of episodes for a task.

        Returns statistics like count, first/last timestamps, event types.
        """
        entries = self.get_by_task(task_id)
        if not entries:
            return {"task_id": task_id, "count": 0}

        event_types: dict[str, int] = {}
        success_count = 0
        failure_count = 0

        for entry in entries:
            etype = entry.get("event_type") or entry.get("type") or "unknown"
            event_types[etype] = event_types.get(etype, 0) + 1

            # Entries with neither a truthy "success" nor an "error" are
            # counted as neither success nor failure.
            if entry.get("success"):
                success_count += 1
            elif entry.get("error") or entry.get("success") is False:
                failure_count += 1

        return {
            "task_id": task_id,
            "count": len(entries),
            "first_timestamp": entries[0].get("datetime"),
            "last_timestamp": entries[-1].get("datetime"),
            "event_types": event_types,
            "success_count": success_count,
            "failure_count": failure_count,
        }

    def get_statistics(self) -> dict[str, Any]:
        """Get overall memory statistics."""
        return {
            "total_entries": len(self._entries),
            "archived_entries": self._archived_count,
            "task_count": len(self._by_task),
            "event_type_count": len(self._by_type),
            "event_types": list(self._by_type.keys()),
        }

    def _archive_oldest(self, count: int) -> None:
        """Archive/remove the oldest `count` entries to enforce size limits.

        A no-op when count is non-positive or would empty the whole log.
        """
        if count <= 0 or count >= len(self._entries):
            return

        logger.info(
            "Archiving episodic memory entries",
            extra={"count": count, "total": len(self._entries)},
        )

        # Remove oldest entries
        self._entries = self._entries[count:]
        self._archived_count += count

        # Rebuild indices (all remaining entries shifted left by `count`)
        self._by_task = {}
        self._by_type = {}
        for idx, entry in enumerate(self._entries):
            task_id = entry.get("task_id")
            if task_id:
                self._by_task.setdefault(task_id, []).append(idx)

            etype = entry.get("event_type") or entry.get("type")
            if etype:
                self._by_type.setdefault(etype, []).append(idx)

    def clear(self) -> None:
        """Clear all entries (for tests)."""
        self._entries.clear()
        self._by_task.clear()
        self._by_type.clear()
        self._archived_count = 0
|
||||
231
fusionagi/memory/postgres_backend.py
Normal file
231
fusionagi/memory/postgres_backend.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Optional Postgres backend for memory. Requires: pip install fusionagi[memory]."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class MemoryBackend(ABC):
    """Abstract backend for persistent memory storage.

    Items are flat dict records keyed by id and scoped by
    tenant/user/session; the concrete backends in this module overwrite
    on id conflict.
    """

    @abstractmethod
    def store(
        self,
        id: str,
        tenant_id: str,
        user_id: str,
        session_id: str,
        type: str,
        content: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        retention_policy: str = "session",
    ) -> None:
        """Store a memory item.

        Args:
            id: Unique item id.
            tenant_id: Owning tenant.
            user_id: Owning user.
            session_id: Owning session.
            type: Item type tag, used for filtering in query().
            content: Item payload.
            metadata: Optional extra metadata (treated as empty when None).
            retention_policy: Retention tag; defaults to "session".
        """
        ...

    @abstractmethod
    def get(self, id: str) -> dict[str, Any] | None:
        """Get a memory item by id; None when not found."""
        ...

    @abstractmethod
    def query(
        self,
        tenant_id: str,
        user_id: str | None = None,
        session_id: str | None = None,
        type: str | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        """Query memory items.

        Args:
            tenant_id: Required tenant filter.
            user_id: Optional user filter.
            session_id: Optional session filter.
            type: Optional type filter.
            limit: Maximum number of items returned.
        """
        ...
|
||||
|
||||
|
||||
class InMemoryBackend(MemoryBackend):
    """Dict-backed MemoryBackend for development; no persistence."""

    def __init__(self) -> None:
        # id -> flat item record mirroring the Postgres row shape
        self._store: dict[str, dict[str, Any]] = {}

    def store(
        self,
        id: str,
        tenant_id: str,
        user_id: str,
        session_id: str,
        type: str,
        content: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        retention_policy: str = "session",
    ) -> None:
        """Store (or overwrite) an item record keyed by id."""
        record = {
            "id": id,
            "tenant_id": tenant_id,
            "user_id": user_id,
            "session_id": session_id,
            "type": type,
            "content": content,
            "metadata": metadata or {},
            "retention_policy": retention_policy,
        }
        self._store[id] = record

    def get(self, id: str) -> dict[str, Any] | None:
        """Return the stored item record, or None when unknown."""
        return self._store.get(id)

    def query(
        self,
        tenant_id: str,
        user_id: str | None = None,
        session_id: str | None = None,
        type: str | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        """Return up to `limit` items matching tenant plus any optional filters."""
        matches: list[dict[str, Any]] = []
        optional_filters = (("user_id", user_id), ("session_id", session_id), ("type", type))
        for item in self._store.values():
            if item["tenant_id"] != tenant_id:
                continue
            # Truthiness (not `is not None`) matches the original contract:
            # an empty-string filter is treated as "no filter".
            if any(wanted and item[field] != wanted for field, wanted in optional_filters):
                continue
            matches.append(item)
            if len(matches) >= limit:
                break
        return matches
|
||||
|
||||
|
||||
def create_postgres_backend(connection_string: str) -> MemoryBackend | None:
    """Create a Postgres-backed MemoryBackend when psycopg is available.

    Returns None (after a debug log) when the optional psycopg driver is
    missing, letting callers fall back to another backend.
    """
    try:
        import psycopg  # noqa: F401 -- availability probe only
    except ImportError:
        logger.debug("psycopg not installed; use pip install fusionagi[memory]")
        return None
    else:
        return PostgresMemoryBackend(connection_string)
|
||||
|
||||
|
||||
class PostgresMemoryBackend(MemoryBackend):
    """Postgres-backed memory storage.

    Opens a fresh psycopg connection per operation (no pooling); the
    backing table is created on construction if it does not exist.
    """

    def __init__(self, connection_string: str) -> None:
        # Connection string is kept and reused for every per-call connect.
        self._conn_str = connection_string
        self._init_schema()

    def _init_schema(self) -> None:
        """Create the memory_items table if it does not already exist."""
        import psycopg

        with psycopg.connect(self._conn_str) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS memory_items (
                        id TEXT PRIMARY KEY,
                        tenant_id TEXT NOT NULL,
                        user_id TEXT NOT NULL,
                        session_id TEXT NOT NULL,
                        type TEXT NOT NULL,
                        content JSONB NOT NULL,
                        metadata JSONB DEFAULT '{}',
                        retention_policy TEXT DEFAULT 'session',
                        created_at TIMESTAMPTZ DEFAULT NOW()
                    )
                    """
                )
            conn.commit()

    def store(
        self,
        id: str,
        tenant_id: str,
        user_id: str,
        session_id: str,
        type: str,
        content: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        retention_policy: str = "session",
    ) -> None:
        """Insert a memory item; on id conflict only content/metadata are updated."""
        import json
        import psycopg

        with psycopg.connect(self._conn_str) as conn:
            with conn.cursor() as cur:
                # Upsert: scoping columns (tenant/user/session/type) and the
                # retention policy are NOT changed on conflict.
                cur.execute(
                    """
                    INSERT INTO memory_items (id, tenant_id, user_id, session_id, type, content, metadata, retention_policy)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (id) DO UPDATE SET content = EXCLUDED.content, metadata = EXCLUDED.metadata
                    """,
                    (id, tenant_id, user_id, session_id, type, json.dumps(content), json.dumps(metadata or {}), retention_policy),
                )
            conn.commit()

    def get(self, id: str) -> dict[str, Any] | None:
        """Fetch one item by id as a flat dict, or None when not found."""
        import json
        import psycopg

        with psycopg.connect(self._conn_str) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT id, tenant_id, user_id, session_id, type, content, metadata, retention_policy FROM memory_items WHERE id = %s",
                    (id,),
                )
                row = cur.fetchone()
                if not row:
                    return None
                # NOTE(review): assumes JSONB columns come back as strings;
                # psycopg 3 may already decode JSONB to dicts, in which case
                # json.loads would fail — verify against the driver config.
                return {
                    "id": row[0],
                    "tenant_id": row[1],
                    "user_id": row[2],
                    "session_id": row[3],
                    "type": row[4],
                    "content": json.loads(row[5]) if row[5] else {},
                    "metadata": json.loads(row[6]) if row[6] else {},
                    "retention_policy": row[7],
                }

    def query(
        self,
        tenant_id: str,
        user_id: str | None = None,
        session_id: str | None = None,
        type: str | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        """Return up to `limit` items for a tenant, newest first, with optional filters."""
        import json
        import psycopg

        # SQL is assembled from fixed fragments only; all values go through
        # %s placeholders, so no injection risk from the filter arguments.
        q = "SELECT id, tenant_id, user_id, session_id, type, content, metadata, retention_policy FROM memory_items WHERE tenant_id = %s"
        params: list[Any] = [tenant_id]
        if user_id:
            q += " AND user_id = %s"
            params.append(user_id)
        if session_id:
            q += " AND session_id = %s"
            params.append(session_id)
        if type:
            q += " AND type = %s"
            params.append(type)
        q += " ORDER BY created_at DESC LIMIT %s"
        params.append(limit)

        with psycopg.connect(self._conn_str) as conn:
            with conn.cursor() as cur:
                cur.execute(q, params)
                rows = cur.fetchall()
                # See the JSONB decoding note in get() above.
                return [
                    {
                        "id": r[0],
                        "tenant_id": r[1],
                        "user_id": r[2],
                        "session_id": r[3],
                        "type": r[4],
                        "content": json.loads(r[5]) if r[5] else {},
                        "metadata": json.loads(r[6]) if r[6] else {},
                        "retention_policy": r[7],
                    }
                    for r in rows
                ]
|
||||
55
fusionagi/memory/procedural.py
Normal file
55
fusionagi/memory/procedural.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Procedural memory: reusable skills/workflows for AGI."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.skill import Skill
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class ProceduralMemory:
    """
    Skill store: reusable workflows and procedures. Invokable by name;
    write/update rules enforced by caller.
    """

    def __init__(self, max_skills: int = 5000) -> None:
        self._skills: dict[str, Skill] = {}
        # Latest skill_id registered under each human-readable name.
        self._by_name: dict[str, str] = {}
        self._max_skills = max_skills

    def add_skill(self, skill: Skill) -> None:
        """Register a skill (overwrites the same skill_id in place)."""
        is_new = skill.skill_id not in self._skills
        if is_new and len(self._skills) >= self._max_skills:
            self._evict_one()
        self._skills[skill.skill_id] = skill
        self._by_name[skill.name] = skill.skill_id
        logger.debug("Procedural memory: skill added", extra={"skill_id": skill.skill_id, "name": skill.name})

    def get_skill(self, skill_id: str) -> Skill | None:
        """Return skill by id or None."""
        return self._skills.get(skill_id)

    def get_skill_by_name(self, name: str) -> Skill | None:
        """Return the latest skill registered under this name, or None."""
        skill_id = self._by_name.get(name)
        if not skill_id:
            return None
        return self._skills.get(skill_id)

    def list_skills(self, limit: int = 200) -> list[Skill]:
        """Return up to `limit` most recently registered skills (e.g. for planner)."""
        registered = list(self._skills.values())
        return registered[-limit:]

    def remove_skill(self, skill_id: str) -> bool:
        """Remove a skill; returns True if it existed."""
        removed = self._skills.pop(skill_id, None)
        if removed is None:
            return False
        # Only drop the name mapping when it still points at this skill.
        if self._by_name.get(removed.name) == skill_id:
            del self._by_name[removed.name]
        return True

    def _evict_one(self) -> None:
        # Dicts preserve insertion order, so this drops the oldest skill.
        if not self._skills:
            return
        oldest_id = next(iter(self._skills))
        self.remove_skill(oldest_id)
|
||||
31
fusionagi/memory/reflective.py
Normal file
31
fusionagi/memory/reflective.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Reflective memory: simple store for lessons learned / heuristics (used by Phase 3)."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ReflectiveMemory:
    """Simple store for lessons and heuristics: an append-only list plus a key-value map."""

    def __init__(self) -> None:
        self._lessons: list[dict[str, Any]] = []
        self._heuristics: dict[str, Any] = {}

    def add_lesson(self, lesson: dict[str, Any]) -> None:
        """Append a lesson (e.g. from Critic)."""
        self._lessons.append(lesson)

    def get_lessons(self, limit: int = 50) -> list[dict[str, Any]]:
        """Return up to `limit` most recent lessons as shallow copies."""
        recent = self._lessons[-limit:]
        return [dict(lesson) for lesson in recent]

    def set_heuristic(self, key: str, value: Any) -> None:
        """Set a heuristic (e.g. strategy hint)."""
        self._heuristics[key] = value

    def get_heuristic(self, key: str) -> Any:
        """Return the heuristic stored under `key`, or None."""
        return self._heuristics.get(key)

    def get_all_heuristics(self) -> dict[str, Any]:
        """Return all heuristics as a shallow copy."""
        return dict(self._heuristics)
|
||||
70
fusionagi/memory/scratchpad.py
Normal file
70
fusionagi/memory/scratchpad.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Latent scratchpad: internal reasoning buffers for hypotheses and discarded paths."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class ThoughtState:
    """Internal reasoning state: hypotheses, partial conclusions, discarded paths."""

    # Candidate lines of reasoning currently under consideration.
    hypotheses: list[str] = field(default_factory=list)
    # Conclusions reached so far that are not yet final.
    partial_conclusions: list[str] = field(default_factory=list)
    # Reasoning paths that were explored and abandoned.
    discarded_paths: list[str] = field(default_factory=list)
    # Free-form extra context attached by the scratchpad.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class LatentScratchpad:
    """
    Internal buffer for intermediate reasoning; not exposed to user.

    Stores hypotheses, discarded paths, partial conclusions for meta-tracking.
    Each buffer is a bounded deque: appending past capacity silently drops
    the oldest entry.
    """

    def __init__(
        self,
        max_hypotheses: int = 100,
        max_discarded: int = 50,
        max_partial: int = 50,
    ) -> None:
        """
        Args:
            max_hypotheses: Capacity of the hypothesis buffer.
            max_discarded: Capacity of the discarded-path buffer.
            max_partial: Capacity of the partial-conclusion buffer
                (previously hard-coded to 50; the default preserves that).
        """
        self._hypotheses: deque[str] = deque(maxlen=max_hypotheses)
        self._partial: deque[str] = deque(maxlen=max_partial)
        self._discarded: deque[str] = deque(maxlen=max_discarded)
        self._metadata: dict[str, Any] = {}

    def append_hypothesis(self, hypothesis: str) -> None:
        """Append a reasoning hypothesis."""
        self._hypotheses.append(hypothesis)
        logger.debug("Scratchpad: hypothesis appended", extra={"len": len(self._hypotheses)})

    def append_discarded(self, path: str) -> None:
        """Append a discarded reasoning path."""
        self._discarded.append(path)

    def append_partial(self, conclusion: str) -> None:
        """Append a partial conclusion."""
        self._partial.append(conclusion)

    def get_intermediate(self) -> ThoughtState:
        """Get a snapshot of the current intermediate state (copies, not views)."""
        return ThoughtState(
            hypotheses=list(self._hypotheses),
            partial_conclusions=list(self._partial),
            discarded_paths=list(self._discarded),
            metadata=dict(self._metadata),
        )

    def clear(self) -> None:
        """Clear all buffers and metadata."""
        self._hypotheses.clear()
        self._partial.clear()
        self._discarded.clear()
        self._metadata.clear()

    def set_metadata(self, key: str, value: Any) -> None:
        """Set metadata entry."""
        self._metadata[key] = value

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Get metadata entry, or `default` when absent."""
        return self._metadata.get(key, default)
|
||||
55
fusionagi/memory/semantic.py
Normal file
55
fusionagi/memory/semantic.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Semantic memory: facts, policies, domain knowledge for AGI."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class SemanticMemory:
    """Stores facts, policies, domain knowledge. Queryable and updatable."""

    def __init__(self, max_facts: int = 50000) -> None:
        self._facts: dict[str, dict[str, Any]] = {}
        # domain -> [fact_id, ...] in insertion order, no duplicates
        self._by_key: dict[str, list[str]] = {}
        self._max_facts = max_facts

    def add_fact(self, fact_id: str, statement: str, source: str = "", domain: str = "", metadata: dict[str, Any] | None = None) -> None:
        """Add or overwrite a fact.

        Fixes two index bugs in the previous version: overwriting a fact no
        longer appends a duplicate id to the domain index, and re-adding a
        fact under a new domain removes the stale entry from the old domain
        (previously query(old_domain) kept returning it).
        """
        previous = self._facts.get(fact_id)
        if previous is None:
            if len(self._facts) >= self._max_facts:
                self._evict_one()
        else:
            old_domain = previous.get("domain", "")
            if old_domain and old_domain != domain:
                self._remove_from_index(old_domain, fact_id)
        entry = {"statement": statement, "source": source, "domain": domain, "metadata": metadata or {}}
        self._facts[fact_id] = entry
        if domain:
            ids = self._by_key.setdefault(domain, [])
            if fact_id not in ids:  # avoid duplicate index entries on overwrite
                ids.append(fact_id)
        logger.debug("Semantic memory: fact added", extra={"fact_id": fact_id, "domain": domain})

    def get_fact(self, fact_id: str) -> dict[str, Any] | None:
        """Return the fact entry or None."""
        return self._facts.get(fact_id)

    def query(self, domain: str | None = None, limit: int = 100) -> list[dict[str, Any]]:
        """Return up to `limit` most recent facts, optionally restricted to a domain."""
        if domain:
            ids = self._by_key.get(domain, [])[-limit:]
            return [self._facts[id] for id in ids if id in self._facts]
        return list(self._facts.values())[-limit:]

    def update_fact(self, fact_id: str, **kwargs: Any) -> bool:
        """Update existing fields of a fact in place; returns True if it exists.

        Only keys already present in the entry are updated; changing
        "domain" also re-homes the fact in the domain index.
        """
        entry = self._facts.get(fact_id)
        if entry is None:
            return False
        for k, v in kwargs.items():
            if k not in entry:
                continue
            if k == "domain" and v != entry["domain"]:
                # Keep the domain index in sync with the new domain.
                if entry["domain"]:
                    self._remove_from_index(entry["domain"], fact_id)
                if v:
                    ids = self._by_key.setdefault(v, [])
                    if fact_id not in ids:
                        ids.append(fact_id)
            entry[k] = v
        return True

    def forget(self, fact_id: str) -> bool:
        """Remove a fact and its domain index entry; returns True if it existed."""
        if fact_id not in self._facts:
            return False
        entry = self._facts.pop(fact_id)
        domain = entry.get("domain", "")
        if domain:
            self._remove_from_index(domain, fact_id)
        return True

    def _remove_from_index(self, domain: str, fact_id: str) -> None:
        """Drop fact_id from a domain's index list, if present."""
        ids = self._by_key.get(domain)
        if ids and fact_id in ids:
            self._by_key[domain] = [x for x in ids if x != fact_id]

    def _evict_one(self) -> None:
        # Dicts preserve insertion order, so this drops the oldest fact.
        if not self._facts:
            return
        rid = next(iter(self._facts))
        self.forget(rid)
|
||||
106
fusionagi/memory/semantic_graph.py
Normal file
106
fusionagi/memory/semantic_graph.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Semantic memory graph: nodes = AtomicSemanticUnit, edges = SemanticRelation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.atomic import (
|
||||
AtomicSemanticUnit,
|
||||
AtomicUnitType,
|
||||
SemanticRelation,
|
||||
)
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class SemanticGraphMemory:
|
||||
"""
|
||||
Graph-backed semantic memory: nodes = atomic units, edges = relations.
|
||||
|
||||
Supports add_unit, add_relation, query_units, query_neighbors, query_by_type.
|
||||
In-memory implementation with dict + adjacency list.
|
||||
"""
|
||||
|
||||
    def __init__(self, max_units: int = 50000) -> None:
        """Initialize the graph store.

        Args:
            max_units: Maximum units held before the oldest is evicted.
        """
        # unit_id -> unit (insertion-ordered; oldest evicted first)
        self._units: dict[str, AtomicSemanticUnit] = {}
        # unit type -> [unit_id, ...] in insertion order
        self._by_type: dict[AtomicUnitType, list[str]] = defaultdict(list)
        # adjacency lists: relations leaving / entering each unit_id
        self._outgoing: dict[str, list[SemanticRelation]] = defaultdict(list)
        self._incoming: dict[str, list[SemanticRelation]] = defaultdict(list)
        self._max_units = max_units
|
||||
|
||||
def add_unit(self, unit: AtomicSemanticUnit) -> None:
|
||||
"""Add an atomic semantic unit."""
|
||||
if len(self._units) >= self._max_units and unit.unit_id not in self._units:
|
||||
self._evict_one()
|
||||
self._units[unit.unit_id] = unit
|
||||
self._by_type[unit.type].append(unit.unit_id)
|
||||
logger.debug("Semantic graph: unit added", extra={"unit_id": unit.unit_id, "type": unit.type.value})
|
||||
|
||||
def add_relation(self, relation: SemanticRelation) -> None:
|
||||
"""Add a relation between units."""
|
||||
if relation.from_id in self._units and relation.to_id in self._units:
|
||||
self._outgoing[relation.from_id].append(relation)
|
||||
self._incoming[relation.to_id].append(relation)
|
||||
|
||||
def get_unit(self, unit_id: str) -> AtomicSemanticUnit | None:
|
||||
"""Get unit by ID."""
|
||||
return self._units.get(unit_id)
|
||||
|
||||
def query_units(
|
||||
self,
|
||||
unit_ids: list[str] | None = None,
|
||||
unit_type: AtomicUnitType | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[AtomicSemanticUnit]:
|
||||
"""Query units by IDs or type."""
|
||||
if unit_ids:
|
||||
return [self._units[uid] for uid in unit_ids if uid in self._units][:limit]
|
||||
if unit_type:
|
||||
ids = self._by_type.get(unit_type, [])[-limit:]
|
||||
return [self._units[uid] for uid in ids if uid in self._units]
|
||||
return list(self._units.values())[-limit:]
|
||||
|
||||
def query_neighbors(
|
||||
self,
|
||||
unit_id: str,
|
||||
direction: str = "outgoing",
|
||||
relation_type: str | None = None,
|
||||
) -> list[tuple[AtomicSemanticUnit, SemanticRelation]]:
|
||||
"""Get neighboring units and relations."""
|
||||
edges = self._outgoing[unit_id] if direction == "outgoing" else self._incoming[unit_id]
|
||||
results: list[tuple[AtomicSemanticUnit, SemanticRelation]] = []
|
||||
for rel in edges:
|
||||
if relation_type and rel.relation_type.value != relation_type:
|
||||
continue
|
||||
other_id = rel.to_id if direction == "outgoing" else rel.from_id
|
||||
other = self._units.get(other_id)
|
||||
if other:
|
||||
results.append((other, rel))
|
||||
return results
|
||||
|
||||
def query_by_type(self, unit_type: AtomicUnitType, limit: int = 100) -> list[AtomicSemanticUnit]:
|
||||
"""Query units by type."""
|
||||
return self.query_units(unit_type=unit_type, limit=limit)
|
||||
|
||||
def ingest_decomposition(
|
||||
self,
|
||||
units: list[AtomicSemanticUnit],
|
||||
relations: list[SemanticRelation],
|
||||
) -> None:
|
||||
"""Ingest a DecompositionResult into the graph."""
|
||||
for u in units:
|
||||
self.add_unit(u)
|
||||
for r in relations:
|
||||
self.add_relation(r)
|
||||
|
||||
def _evict_one(self) -> None:
|
||||
"""Evict oldest unit (simple FIFO on first key)."""
|
||||
if not self._units:
|
||||
return
|
||||
uid = next(iter(self._units))
|
||||
unit = self._units.pop(uid, None)
|
||||
if unit:
|
||||
self._by_type[unit.type] = [x for x in self._by_type[unit.type] if x != uid]
|
||||
self._outgoing.pop(uid, None)
|
||||
self._incoming.pop(uid, None)
|
||||
logger.debug("Semantic graph: evicted unit", extra={"unit_id": uid})
|
||||
97
fusionagi/memory/service.py
Normal file
97
fusionagi/memory/service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Unified memory service: session, episodic, semantic, vector with tenant isolation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.memory.working import WorkingMemory
|
||||
from fusionagi.memory.episodic import EpisodicMemory
|
||||
from fusionagi.memory.semantic import SemanticMemory
|
||||
|
||||
|
||||
def _scoped_key(tenant_id: str, user_id: str, base: str) -> str:
|
||||
"""Scope key by tenant and user."""
|
||||
parts = [tenant_id or "default", user_id or "anonymous", base]
|
||||
return ":".join(parts)
|
||||
|
||||
|
||||
class VectorMemory:
    """
    In-memory stand-in for an embeddings store.

    Stub implementation; replace with pgvector or Pinecone adapter for production.
    """

    def __init__(self, max_entries: int = 10000) -> None:
        # Insertion-ordered records; oldest is evicted first at capacity.
        self._store: list[dict[str, Any]] = []
        self._max_entries = max_entries

    def add(self, id: str, embedding: list[float], metadata: dict[str, Any] | None = None) -> None:
        """Store one embedding record, evicting the oldest entry at capacity."""
        if len(self._store) >= self._max_entries:
            self._store.pop(0)
        record = {"id": id, "embedding": embedding, "metadata": metadata or {}}
        self._store.append(record)

    def search(self, query_embedding: list[float], top_k: int = 10) -> list[dict[str, Any]]:
        """Similarity-search placeholder; always returns an empty list."""
        return []
|
||||
|
||||
|
||||
class MemoryService:
    """
    Unified memory service with tenant isolation.

    Bundles WorkingMemory (session), EpisodicMemory, SemanticMemory and
    VectorMemory behind one facade; session keys and episodic metadata are
    scoped/tagged with the tenant and user.
    """

    def __init__(
        self,
        tenant_id: str = "default",
        user_id: str | None = None,
    ) -> None:
        self._tenant_id = tenant_id
        self._user_id = user_id or "anonymous"
        self._working = WorkingMemory()
        self._episodic = EpisodicMemory()
        self._semantic = SemanticMemory()
        self._vector = VectorMemory()

    @property
    def session(self) -> WorkingMemory:
        """Short-term session memory."""
        return self._working

    @property
    def episodic(self) -> EpisodicMemory:
        """Episodic memory (what happened, decisions, outcomes)."""
        return self._episodic

    @property
    def semantic(self) -> SemanticMemory:
        """Semantic memory (facts, preferences)."""
        return self._semantic

    @property
    def vector(self) -> VectorMemory:
        """Vector memory (embeddings for retrieval)."""
        return self._vector

    def scope_session(self, session_id: str) -> str:
        """Return the tenant/user scoped key for a session id."""
        return _scoped_key(self._tenant_id, self._user_id, session_id)

    def get(self, session_id: str, key: str, default: Any = None) -> Any:
        """Read a value from scoped session memory."""
        return self._working.get(self.scope_session(session_id), key, default)

    def set(self, session_id: str, key: str, value: Any) -> None:
        """Write a value into scoped session memory."""
        self._working.set(self.scope_session(session_id), key, value)

    def append_episode(self, task_id: str, event: dict[str, Any], event_type: str | None = None) -> int:
        """Append an event to episodic memory, tagging tenant/user in metadata.

        The incoming event dict (and its metadata) is copied, never mutated.
        """
        tagged = dict(event)
        meta = dict(tagged.get("metadata") or {})
        meta["tenant_id"] = self._tenant_id
        meta["user_id"] = self._user_id
        tagged["metadata"] = meta
        return self._episodic.append(task_id, tagged, event_type)
|
||||
79
fusionagi/memory/sharding.py
Normal file
79
fusionagi/memory/sharding.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Context sharding: cluster atomic units by semantic similarity or domain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit
|
||||
|
||||
|
||||
@dataclass
class Shard:
    """A cluster of atomic units with optional summary and embedding."""

    # Random, unique id of the form "shard_<12 hex chars>".
    shard_id: str = field(default_factory=lambda: f"shard_{uuid.uuid4().hex[:12]}")
    # unit_ids of the AtomicSemanticUnit members of this cluster.
    unit_ids: list[str] = field(default_factory=list)
    # Short human-readable preview of the cluster's content.
    summary: str = ""
    # Optional embedding of the shard as a whole; None until computed.
    embedding: list[float] | None = None
    # Free-form extras (shard_context stores "unit_count" here).
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _extract_keywords(text: str) -> set[str]:
|
||||
"""Extract keywords for clustering."""
|
||||
content = " ".join(text.lower().split())
|
||||
return set(re.findall(r"\b[a-z0-9]{3,}\b", content))
|
||||
|
||||
|
||||
def _keyword_similarity(a: set[str], b: set[str]) -> float:
|
||||
"""Jaccard similarity between keyword sets."""
|
||||
if not a and not b:
|
||||
return 1.0
|
||||
inter = len(a & b)
|
||||
union = len(a | b)
|
||||
return inter / union if union else 0.0
|
||||
|
||||
|
||||
def _cluster_by_keywords(
    units: list[AtomicSemanticUnit],
    max_cluster_size: int,
) -> list[list[AtomicSemanticUnit]]:
    """Greedily cluster units by keyword overlap.

    Each unassigned unit seeds a cluster and absorbs later unassigned
    units whose keyword Jaccard similarity to the seed exceeds 0.1,
    up to max_cluster_size members per cluster.

    Args:
        units: Units to cluster; returned unchanged as one cluster when
            they already fit in max_cluster_size.
        max_cluster_size: Hard cap on members per cluster.

    Returns:
        List of clusters, each a list of the original units.
    """
    if not units:
        return []
    if len(units) <= max_cluster_size:
        return [units]
    unit_keywords: list[set[str]] = [_extract_keywords(u.content) for u in units]
    clusters: list[list[int]] = []
    assigned: set[int] = set()
    for i in range(len(units)):
        if i in assigned:
            continue
        cluster = [i]
        assigned.add(i)
        for j in range(i + 1, len(units)):
            if len(cluster) >= max_cluster_size:
                break  # cluster is full; scanning more candidates is wasted work
            if j in assigned:
                continue
            sim = _keyword_similarity(unit_keywords[i], unit_keywords[j])
            if sim > 0.1:
                cluster.append(j)
                assigned.add(j)
        clusters.append(cluster)
    return [[units[idx] for idx in c] for c in clusters]
|
||||
|
||||
|
||||
def shard_context(
    units: list[AtomicSemanticUnit],
    max_cluster_size: int = 20,
) -> list[Shard]:
    """Partition atomic units into Shard clusters by keyword similarity.

    Each shard gets a preview summary built from the first three members
    (80 chars each) and a "unit_count" metadata entry.
    """
    shards: list[Shard] = []
    for cluster in _cluster_by_keywords(units, max_cluster_size):
        member_ids = [member.unit_id for member in cluster]
        preview = "; ".join(member.content[:80] for member in cluster[:3])
        if len(cluster) > 3:
            preview += "..."
        shards.append(Shard(unit_ids=member_ids, summary=preview, metadata={"unit_count": len(cluster)}))
    return shards
|
||||
134
fusionagi/memory/thought_versioning.py
Normal file
134
fusionagi/memory/thought_versioning.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Versioned thought states: snapshots, rollback, branching."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.memory.scratchpad import ThoughtState
|
||||
from fusionagi.reasoning.tot import ThoughtNode
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class ThoughtStateSnapshot:
    """Snapshot of reasoning state: tree + scratchpad."""

    # Random, unique version id of the form "v_<12 hex chars>".
    version_id: str = field(default_factory=lambda: f"v_{uuid.uuid4().hex[:12]}")
    # Serialized ThoughtNode tree (see _serialize_tree); {} or None when absent.
    tree_state: dict[str, Any] | None = None
    # Stored by reference, not serialized — callers must not mutate it afterwards.
    scratchpad_state: ThoughtState | None = None
    # Monotonic clock value: meaningful only for ordering within this process,
    # not a wall-clock time.
    timestamp: float = field(default_factory=time.monotonic)
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _serialize_tree(node: ThoughtNode | None) -> dict[str, Any]:
|
||||
"""Serialize ThoughtNode to dict."""
|
||||
if node is None:
|
||||
return {}
|
||||
return {
|
||||
"node_id": node.node_id,
|
||||
"parent_id": node.parent_id,
|
||||
"thought": node.thought,
|
||||
"trace": node.trace,
|
||||
"score": node.score,
|
||||
"depth": node.depth,
|
||||
"unit_refs": node.unit_refs,
|
||||
"metadata": node.metadata,
|
||||
"children": [_serialize_tree(c) for c in node.children],
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_tree(data: dict) -> ThoughtNode | None:
    """Rebuild a ThoughtNode tree from its serialized dict form.

    Returns None for an empty/falsy payload. All container fields are
    copied so the rebuilt tree never aliases mutable state inside *data*
    (previously `trace` was passed through by reference, unlike the
    sibling fields).
    """
    if not data:
        return None
    node = ThoughtNode(
        node_id=data.get("node_id", ""),
        parent_id=data.get("parent_id"),
        thought=data.get("thought", ""),
        trace=list(data.get("trace", [])),  # copy for consistency with unit_refs/metadata
        score=float(data.get("score", 0)),
        depth=int(data.get("depth", 0)),
        unit_refs=list(data.get("unit_refs", [])),
        metadata=dict(data.get("metadata", {})),
    )
    for child_data in data.get("children", []):
        child = _deserialize_tree(child_data)
        if child:
            node.children.append(child)
    return node
|
||||
|
||||
|
||||
class ThoughtVersioning:
    """Save, load, rollback, branch thought states.

    Keeps at most max_snapshots ThoughtStateSnapshot objects; when over
    capacity, the snapshot with the smallest (oldest) monotonic timestamp
    is evicted.
    """

    def __init__(self, max_snapshots: int = 50) -> None:
        # version_id -> snapshot; bounded by _max_snapshots.
        self._snapshots: dict[str, ThoughtStateSnapshot] = {}
        self._max_snapshots = max_snapshots

    def save_snapshot(
        self,
        tree: ThoughtNode | None,
        scratchpad: ThoughtState | None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Save snapshot; return version_id.

        The tree is deep-serialized to a dict; the scratchpad is kept by
        reference (see branch_from for a copying variant).
        """
        snapshot = ThoughtStateSnapshot(
            tree_state=_serialize_tree(tree) if tree else {},
            scratchpad_state=scratchpad,
            metadata=metadata or {},
        )
        self._snapshots[snapshot.version_id] = snapshot
        # Evict the oldest snapshot (smallest monotonic timestamp) when full.
        if len(self._snapshots) > self._max_snapshots:
            oldest = min(self._snapshots.keys(), key=lambda k: self._snapshots[k].timestamp)
            del self._snapshots[oldest]
        logger.debug("Thought snapshot saved", extra={"version_id": snapshot.version_id})
        return snapshot.version_id

    def load_snapshot(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Load snapshot; return (tree, scratchpad).

        Returns (None, None) for an unknown version_id. The tree is rebuilt
        from its serialized form; the scratchpad is the stored reference.
        """
        snap = self._snapshots.get(version_id)
        if not snap:
            return None, None
        tree = _deserialize_tree(snap.tree_state or {}) if snap.tree_state else None
        return tree, snap.scratchpad_state

    def list_snapshots(self) -> list[dict[str, Any]]:
        """List available snapshots (id, timestamp, metadata)."""
        return [
            {
                "version_id": v.version_id,
                "timestamp": v.timestamp,
                "metadata": v.metadata,
            }
            for v in self._snapshots.values()
        ]

    def rollback_to(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Load and return snapshot (alias for load_snapshot)."""
        return self.load_snapshot(version_id)

    def branch_from(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Branch from snapshot (returns copy for further edits).

        The tree is deep-copied via a serialize/deserialize round-trip; the
        scratchpad is rebuilt with shallow copies of its list/dict fields.
        """
        tree, scratchpad = self.load_snapshot(version_id)
        if tree:
            tree = _deserialize_tree(_serialize_tree(tree))
        if scratchpad:
            scratchpad = ThoughtState(
                hypotheses=list(scratchpad.hypotheses),
                partial_conclusions=list(scratchpad.partial_conclusions),
                discarded_paths=list(scratchpad.discarded_paths),
                metadata=dict(scratchpad.metadata),
            )
        return tree, scratchpad
|
||||
64
fusionagi/memory/trust.py
Normal file
64
fusionagi/memory/trust.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Trust memory: verified vs unverified, provenance, confidence decay for AGI."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class TrustMemory:
    """
    Tracks verified vs unverified claims, source provenance, and optional
    confidence decay over time.
    """

    def __init__(self, decay_enabled: bool = False) -> None:
        # claim_id -> {verified, source, confidence, created_at, metadata}
        self._entries: dict[str, dict[str, Any]] = {}
        self._decay_enabled = decay_enabled

    def add(
        self,
        claim_id: str,
        verified: bool,
        source: str = "",
        confidence: float = 1.0,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Record a claim with verification status and provenance.

        Re-adding an existing claim_id overwrites the previous entry and
        resets its created_at timestamp.
        """
        self._entries[claim_id] = {
            "verified": verified,
            "source": source,
            "confidence": confidence,
            "created_at": _utc_now(),
            "metadata": metadata or {},
        }
        logger.debug("Trust memory: claim added", extra={"claim_id": claim_id, "verified": verified})

    def get(self, claim_id: str) -> dict[str, Any] | None:
        """Return trust entry or None. Applies decay if enabled.

        With decay enabled, a copy of the entry is returned whose confidence
        is linearly reduced by 0.01 per day of age (floored at 0.0); the
        stored entry is never mutated.
        """
        entry = self._entries.get(claim_id)
        if not entry:
            return None
        if self._decay_enabled:
            # Simple placeholder decay: 0.01 confidence per day of age.
            # (Previously imported datetime.timedelta here without using it.)
            age_days = (_utc_now() - entry["created_at"]).total_seconds() / 86400
            entry = dict(entry)
            entry["confidence"] = max(0.0, entry["confidence"] - 0.01 * age_days)
        return entry

    def is_verified(self, claim_id: str) -> bool:
        """Return True if claim is known and marked verified."""
        entry = self._entries.get(claim_id)
        return entry.get("verified", False) if entry else False

    def set_verified(self, claim_id: str, verified: bool) -> bool:
        """Update verified status. Returns True if the claim existed."""
        if claim_id not in self._entries:
            return False
        self._entries[claim_id]["verified"] = verified
        return True
|
||||
101
fusionagi/memory/vector_pgvector.py
Normal file
101
fusionagi/memory/vector_pgvector.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""pgvector-backed VectorMemory adapter. Requires: pip install fusionagi[vector]."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
def create_vector_memory_pgvector(
    connection_string: str,
    table_name: str = "embeddings",
    dimension: int = 1536,
) -> Any:
    """
    Create pgvector-backed VectorMemory when pgvector is installed.
    Returns None if pgvector/database unavailable.
    """
    # Availability probe only — the imported names are intentionally unused
    # here; VectorMemoryPgvector re-imports what it needs at call time.
    try:
        import pgvector
        from pgvector.psycopg import register_vector
    except ImportError:
        logger.debug("pgvector not installed; use pip install fusionagi[vector]")
        return None

    # Separate probe so the log message names the missing package precisely.
    try:
        import psycopg
    except ImportError:
        logger.debug("psycopg not installed; use pip install fusionagi[memory]")
        return None

    # NOTE(review): only missing packages return None; a database/connection
    # failure inside the constructor propagates as an exception — confirm
    # whether callers expect None in that case too.
    return VectorMemoryPgvector(connection_string, table_name, dimension)
|
||||
|
||||
|
||||
class VectorMemoryPgvector:
    """VectorMemory implementation using pgvector.

    Opens a fresh connection per operation. Construction ensures the
    `vector` extension and the target table exist.
    """

    def __init__(
        self,
        connection_string: str,
        table_name: str = "embeddings",
        dimension: int = 1536,
    ) -> None:
        """
        Args:
            connection_string: libpq/psycopg connection string.
            table_name: Table to create/use; must be a plain SQL identifier
                because it is interpolated into SQL text below.
            dimension: Dimension of the `vector` column.

        Raises:
            ValueError: If table_name is not a plain SQL identifier.
        """
        import re

        # Bug fix: psycopg was used below without being imported (the
        # previous body imported only pgvector/register_vector here).
        import psycopg
        from pgvector.psycopg import register_vector

        # table_name is interpolated into SQL strings; restrict it to a safe
        # identifier so configuration values cannot inject SQL.
        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", table_name):
            raise ValueError(f"invalid table name: {table_name!r}")

        self._conn_str = connection_string
        self._table = table_name
        self._dim = dimension

        with psycopg.connect(connection_string) as conn:
            # NOTE(review): register_vector runs before CREATE EXTENSION; on a
            # database that has never had the extension this may fail — confirm.
            register_vector(conn)
            with conn.cursor() as cur:
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
                cur.execute(
                    f"""
                    CREATE TABLE IF NOT EXISTS {table_name} (
                        id TEXT PRIMARY KEY,
                        embedding vector({dimension}),
                        metadata JSONB,
                        created_at TIMESTAMPTZ DEFAULT NOW()
                    )
                    """
                )
            conn.commit()

    def add(self, id: str, embedding: list[float], metadata: dict[str, Any] | None = None) -> None:
        """Insert or update one embedding row (upsert on id)."""
        import json
        import psycopg
        from pgvector.psycopg import register_vector

        with psycopg.connect(self._conn_str) as conn:
            register_vector(conn)
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    INSERT INTO {self._table} (id, embedding, metadata)
                    VALUES (%s, %s, %s)
                    ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, metadata = EXCLUDED.metadata
                    """,
                    (id, embedding, json.dumps(metadata or {})),
                )
            conn.commit()

    def search(self, query_embedding: list[float], top_k: int = 10) -> list[dict[str, Any]]:
        """Return up to top_k rows nearest to query_embedding (`<->` distance)."""
        import json
        import psycopg
        from pgvector.psycopg import register_vector

        with psycopg.connect(self._conn_str) as conn:
            register_vector(conn)
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    SELECT id, metadata
                    FROM {self._table}
                    ORDER BY embedding <-> %s
                    LIMIT %s
                    """,
                    (query_embedding, top_k),
                )
                rows = cur.fetchall()
        return [{"id": r[0], "metadata": json.loads(r[1]) if r[1] else {}} for r in rows]
|
||||
150
fusionagi/memory/working.py
Normal file
150
fusionagi/memory/working.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Working memory: in-memory key-value / list per task/session.
|
||||
|
||||
Working memory provides short-term storage for active tasks:
|
||||
- Key-value storage per session/task
|
||||
- List append operations for accumulating results
|
||||
- Context retrieval for reasoning
|
||||
- Session lifecycle management
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterator
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi._time import utc_now
|
||||
|
||||
|
||||
class WorkingMemory:
    """
    Short-term working memory per task/session.

    Features:
    - Key-value get/set operations
    - List append with automatic coercion of scalars to lists
    - Context summary for LLM prompts
    - Session management and cleanup
    - Size accounting with a warning when a session grows past the limit
    """

    def __init__(self, max_entries_per_session: int = 1000) -> None:
        """
        Initialize working memory.

        Args:
            max_entries_per_session: Soft cap on items per session; exceeding
                it logs a warning — entries are NOT evicted automatically
                (the previous docstring incorrectly claimed removal).
        """
        self._store: dict[str, dict[str, Any]] = defaultdict(dict)
        # Last-write timestamp per session (UTC, timezone-aware).
        self._timestamps: dict[str, datetime] = {}
        self._max_entries = max_entries_per_session

    def get(self, session_id: str, key: str, default: Any = None) -> Any:
        """Get value for session and key; returns default if not found.

        Uses .get() so probing an unknown session does not materialize an
        empty entry in the defaultdict-backed store (a read used to create
        phantom sessions).
        """
        return self._store.get(session_id, {}).get(key, default)

    def set(self, session_id: str, key: str, value: Any) -> None:
        """Set value for session and key."""
        self._store[session_id][key] = value
        self._timestamps[session_id] = utc_now()
        self._enforce_limits(session_id)

    def append(self, session_id: str, key: str, value: Any) -> None:
        """Append to list for session and key (creates list if needed).

        A pre-existing scalar under the key is coerced into a one-element list.
        """
        if key not in self._store[session_id]:
            self._store[session_id][key] = []
        lst = self._store[session_id][key]
        if not isinstance(lst, list):
            lst = [lst]
            self._store[session_id][key] = lst
        lst.append(value)
        self._timestamps[session_id] = utc_now()
        self._enforce_limits(session_id)

    def get_list(self, session_id: str, key: str) -> list[Any]:
        """Return list for session and key (copy); a scalar is wrapped.

        Reads via .get() to avoid creating an empty session entry.
        """
        val = self._store.get(session_id, {}).get(key)
        if isinstance(val, list):
            return list(val)
        return [val] if val is not None else []

    def has(self, session_id: str, key: str) -> bool:
        """Check if a key exists in session."""
        return key in self._store.get(session_id, {})

    def keys(self, session_id: str) -> list[str]:
        """Return all keys for a session."""
        return list(self._store.get(session_id, {}).keys())

    def delete(self, session_id: str, key: str) -> bool:
        """Delete a key from session. Returns True if it existed."""
        if session_id in self._store and key in self._store[session_id]:
            del self._store[session_id][key]
            return True
        return False

    def clear_session(self, session_id: str) -> None:
        """Clear all data (and the timestamp) for a session."""
        self._store.pop(session_id, None)
        self._timestamps.pop(session_id, None)

    def get_context_summary(self, session_id: str, max_items: int = 10) -> dict[str, Any]:
        """
        Get a summary of working memory for context injection.

        Useful for including relevant context in LLM prompts. At most
        max_items keys are summarized; lists report count + last 3 items,
        dicts report up to 10 keys, long strings are truncated to 200 chars.
        """
        session_data = self._store.get(session_id, {})
        summary = {}

        for key, value in list(session_data.items())[:max_items]:
            if isinstance(value, list):
                # For lists, include count and last few items
                summary[key] = {
                    "type": "list",
                    "count": len(value),
                    "recent": value[-3:] if len(value) > 3 else value,
                }
            elif isinstance(value, dict):
                # For dicts, include keys
                summary[key] = {
                    "type": "dict",
                    "keys": list(value.keys())[:10],
                }
            else:
                # For scalars, include the value (truncated if string)
                if isinstance(value, str) and len(value) > 200:
                    summary[key] = value[:200] + "..."
                else:
                    summary[key] = value

        return summary

    def get_all(self, session_id: str) -> dict[str, Any]:
        """Return all data for a session (shallow copy)."""
        return dict(self._store.get(session_id, {}))

    def session_exists(self, session_id: str) -> bool:
        """Check if a session has any data."""
        return session_id in self._store and bool(self._store[session_id])

    def active_sessions(self) -> list[str]:
        """Return list of sessions with data."""
        return [sid for sid, data in self._store.items() if data]

    def session_count(self) -> int:
        """Return number of non-empty sessions."""
        return len([s for s in self._store.values() if s])

    def _enforce_limits(self, session_id: str) -> None:
        """Warn when a session's total item count exceeds the soft limit.

        Entries are not evicted — callers decide how to trim.
        """
        session_data = self._store.get(session_id, {})
        total_items = sum(
            len(v) if isinstance(v, (list, dict)) else 1
            for v in session_data.values()
        )

        if total_items > self._max_entries:
            logger.warning(
                "Working memory size limit exceeded",
                extra={"session_id": session_id, "items": total_items},
            )
|
||||
41
fusionagi/multi_agent/__init__.py
Normal file
41
fusionagi/multi_agent/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Multi-agent: parallel, delegation, pooling, coordinator, adversarial reviewer, consensus."""
|
||||
|
||||
from fusionagi.multi_agent.parallel import (
|
||||
execute_steps_parallel,
|
||||
execute_steps_parallel_wave,
|
||||
ParallelStepResult,
|
||||
)
|
||||
from fusionagi.multi_agent.pool import AgentPool, PooledExecutorRouter
|
||||
from fusionagi.multi_agent.supervisor import SupervisorAgent
|
||||
from fusionagi.multi_agent.delegation import (
|
||||
delegate_sub_tasks,
|
||||
DelegationConfig,
|
||||
SubTask,
|
||||
SubTaskResult,
|
||||
)
|
||||
from fusionagi.multi_agent.coordinator import CoordinatorAgent
|
||||
from fusionagi.multi_agent.consensus import consensus_vote, arbitrate
|
||||
from fusionagi.multi_agent.consensus_engine import (
|
||||
run_consensus,
|
||||
collect_claims,
|
||||
CollectedClaim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"execute_steps_parallel",
|
||||
"execute_steps_parallel_wave",
|
||||
"ParallelStepResult",
|
||||
"AgentPool",
|
||||
"PooledExecutorRouter",
|
||||
"SupervisorAgent",
|
||||
"delegate_sub_tasks",
|
||||
"DelegationConfig",
|
||||
"SubTask",
|
||||
"SubTaskResult",
|
||||
"CoordinatorAgent",
|
||||
"consensus_vote",
|
||||
"arbitrate",
|
||||
"run_consensus",
|
||||
"collect_claims",
|
||||
"CollectedClaim",
|
||||
]
|
||||
15
fusionagi/multi_agent/consensus.py
Normal file
15
fusionagi/multi_agent/consensus.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Any
|
||||
from collections import Counter
|
||||
from fusionagi._logger import logger
|
||||
|
||||
def consensus_vote(answers: list, key: str | None = None) -> Any:
    """Majority vote over *answers*.

    Args:
        answers: Candidate answers. When *key* is given, dict answers are
            projected through it (non-dicts, and dicts missing the key,
            are used as-is).
        key: Optional dict key to vote on.

    Returns:
        The most common (projected) value, or None for an empty list.
        Ties break by first occurrence (Counter ordering).
    """
    if not answers:
        return None
    if key:
        values = [a.get(key, a) if isinstance(a, dict) else a for a in answers]
    else:
        values = list(answers)
    return Counter(values).most_common(1)[0][0]
|
||||
|
||||
def arbitrate(proposals: list, arbitrator="coordinator"):
    """Select one proposal from *proposals*.

    Trivial policy: the first proposal wins; an empty input yields {}.
    The arbitration is logged with the arbitrator name and proposal count.
    """
    if not proposals:
        return {}
    winner = proposals[0]
    logger.info("Arbitrate", extra={"arbitrator": arbitrator, "count": len(proposals)})
    return winner
|
||||
194
fusionagi/multi_agent/consensus_engine.py
Normal file
194
fusionagi/multi_agent/consensus_engine.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Consensus engine: claim collection, deduplication, conflict detection, scoring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim
|
||||
from fusionagi.schemas.witness import AgreementMap
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class CollectedClaim:
    """Claim with source head and metadata for consensus."""

    # Natural-language claim text as produced by the head.
    claim_text: str
    # Head-reported confidence (re-weighted later during consensus).
    confidence: float
    # Which content head produced this claim.
    head_id: HeadId
    # Number of evidence items attached to the claim.
    evidence_count: int
    # Flattened dict form (text, confidence, head id value, evidence count, assumptions).
    raw: dict[str, Any]
|
||||
|
||||
|
||||
def _normalize_text(s: str) -> str:
|
||||
"""Normalize for duplicate detection."""
|
||||
return " ".join(s.lower().split())
|
||||
|
||||
|
||||
def _are_similar(a: str, b: str, threshold: float = 0.9) -> bool:
    """True when two claims look like near-duplicates.

    An exact normalized match always counts; when either normalized string
    is shorter than 10 characters only exact matches count; otherwise a
    word-level Jaccard index >= threshold decides.
    """
    norm_a = _normalize_text(a)
    norm_b = _normalize_text(b)
    if norm_a == norm_b:
        return True
    if min(len(norm_a), len(norm_b)) < 10:
        return False
    words_a, words_b = set(norm_a.split()), set(norm_b.split())
    union = words_a | words_b
    if not union:
        return False
    return len(words_a & words_b) / len(union) >= threshold
|
||||
|
||||
|
||||
def _looks_contradictory(a: str, b: str) -> bool:
    """Heuristic contradiction check: overlapping claims with opposite polarity.

    Fires when exactly one of the two claims contains a negation word and
    the seed claim shares more than 30% of its words with the other.
    """
    negations = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}
    words_a = set(_normalize_text(a).split())
    words_b = set(_normalize_text(b).split())
    a_negated = bool(words_a & negations)
    b_negated = bool(words_b & negations)
    if a_negated == b_negated:
        return False
    overlap = len(words_a & words_b) / max(len(words_a), 1)
    return overlap > 0.3
|
||||
|
||||
|
||||
def collect_claims(outputs: list[HeadOutput]) -> list[CollectedClaim]:
    """Flatten every head's claims into CollectedClaim records with provenance."""
    flattened: list[CollectedClaim] = []
    for output in outputs:
        head = output.head_id
        for claim in output.claims:
            record = CollectedClaim(
                claim_text=claim.claim_text,
                confidence=claim.confidence,
                head_id=head,
                evidence_count=len(claim.evidence),
                raw={
                    "claim_text": claim.claim_text,
                    "confidence": claim.confidence,
                    "head_id": head.value,
                    "evidence_count": len(claim.evidence),
                    "assumptions": claim.assumptions,
                },
            )
            flattened.append(record)
    return flattened
|
||||
|
||||
|
||||
def _merge_group(group: list[CollectedClaim], weights: dict[HeadId, float]) -> CollectedClaim:
    """Merge a group of near-duplicate claims into one aggregated claim.

    Confidence is the (head-)weighted average across the group, boosted by
    10% when any member carries evidence, and capped at 1.0.
    """
    if len(group) == 1:
        c = group[0]
        score = c.confidence * weights.get(c.head_id, 1.0)
        if c.evidence_count > 0:
            score *= 1.1  # boost for citations
        # Cap at 1.0, matching the multi-claim branch below (previously the
        # single-claim path could emit confidence > 1.0).
        score = min(1.0, score)
        return CollectedClaim(
            claim_text=c.claim_text,
            confidence=score,
            head_id=c.head_id,
            evidence_count=c.evidence_count,
            raw={**c.raw, "aggregated_confidence": score, "supporting_heads": [c.head_id.value]},
        )
    total_conf = sum(g.confidence * weights.get(g.head_id, 1.0) for g in group)
    avg_conf = total_conf / len(group)
    evidence_boost = 1.1 if any(g.evidence_count > 0 for g in group) else 1.0
    score = min(1.0, avg_conf * evidence_boost)
    return CollectedClaim(
        claim_text=group[0].claim_text,
        confidence=score,
        head_id=group[0].head_id,
        evidence_count=sum(g.evidence_count for g in group),
        raw={
            "claim_text": group[0].claim_text,
            "aggregated_confidence": score,
            "supporting_heads": [g.head_id.value for g in group],
        },
    )


def run_consensus(
    outputs: list[HeadOutput],
    head_weights: dict[HeadId, float] | None = None,
    confidence_threshold: float = 0.5,
) -> AgreementMap:
    """
    Run consensus: deduplicate, detect conflicts, score, produce AgreementMap.

    Args:
        outputs: HeadOutput from all content heads.
        head_weights: Optional per-head reliability weights (default 1.0).
        confidence_threshold: Minimum confidence for agreed claim.

    Returns:
        AgreementMap with agreed_claims, disputed_claims, confidence_score.
    """
    if not outputs:
        return AgreementMap(
            agreed_claims=[],
            disputed_claims=[],
            confidence_score=0.0,
        )

    weights = head_weights or {h: 1.0 for h in HeadId}
    collected = collect_claims(outputs)

    # Group by similarity (merge near-duplicates). Greedy O(n^2) pairing:
    # each unclaimed claim seeds a group and absorbs later similar,
    # non-contradictory claims.
    merged: list[CollectedClaim] = []
    used: set[int] = set()
    for i, ca in enumerate(collected):
        if i in used:
            continue
        group = [ca]
        used.add(i)
        for j, cb in enumerate(collected):
            if j in used:
                continue
            if _are_similar(ca.claim_text, cb.claim_text) and not _looks_contradictory(ca.claim_text, cb.claim_text):
                group.append(cb)
                used.add(j)
        merged.append(_merge_group(group, weights))

    # Conflict detection: a claim is disputed if it contradicts any other
    # merged claim, or if its aggregated confidence is below the threshold.
    agreed: list[dict[str, Any]] = []
    disputed: list[dict[str, Any]] = []
    for c in merged:
        in_conflict = any(
            c is not d and _looks_contradictory(c.claim_text, d.claim_text)
            for d in merged
        )
        rec = {
            "claim_text": c.claim_text,
            "confidence": c.confidence,
            "supporting_heads": c.raw.get("supporting_heads", [c.head_id.value]),
        }
        if in_conflict or c.confidence < confidence_threshold:
            disputed.append(rec)
        else:
            agreed.append(rec)

    overall_conf = (
        sum(a["confidence"] for a in agreed) / len(agreed)
        if agreed
        else 0.0
    )

    logger.info(
        "Consensus complete",
        extra={"agreed": len(agreed), "disputed": len(disputed), "confidence": overall_conf},
    )

    return AgreementMap(
        agreed_claims=agreed,
        disputed_claims=disputed,
        confidence_score=min(1.0, overall_conf),
    )
|
||||
18
fusionagi/multi_agent/coordinator.py
Normal file
18
fusionagi/multi_agent/coordinator.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.schemas.messages import AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.orchestrator import Orchestrator
|
||||
from fusionagi.core.goal_manager import GoalManager
|
||||
|
||||
class CoordinatorAgent(BaseAgent):
    """Agent that owns goals and hands work off toward the planner."""

    def __init__(self, identity="coordinator", orchestrator=None, goal_manager=None, planner_id="planner"):
        super().__init__(
            identity=identity,
            role="Coordinator",
            objective="Own goals and assign tasks",
            memory_access=True,
            tool_permissions=[],
        )
        self._orch = orchestrator
        self._goal_manager = goal_manager
        self._planner_id = planner_id

    def handle_message(self, envelope):
        """Forward goal_created envelopes to the orchestrator; otherwise no-op."""
        if envelope.message.intent != "goal_created":
            return None
        if not (self._orch and self._planner_id):
            return None
        self._orch.route_message(envelope)
        return None
|
||||
97
fusionagi/multi_agent/delegation.py
Normal file
97
fusionagi/multi_agent/delegation.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Sub-task delegation: fan-out to sub-agents, fan-in of results.
|
||||
|
||||
Enables hierarchical multi-agent: a supervisor decomposes a task into
|
||||
sub-tasks, delegates to specialized sub-agents, and aggregates results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class SubTask:
    """A sub-task to delegate.

    Attributes:
        sub_task_id: Identifier for the sub-task; delegate_sub_tasks keys
            its result index map on this, so ids should be unique per batch.
        goal: Goal description for the sub-agent.
        constraints: Optional constraints for the sub-agent.
        metadata: Arbitrary extra data carried alongside the sub-task.
    """

    sub_task_id: str
    goal: str
    constraints: list[str] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
||||
|
||||
|
||||
@dataclass
class SubTaskResult:
    """Result from a delegated sub-task.

    Attributes:
        sub_task_id: Id of the originating SubTask.
        success: Whether the delegate completed successfully.
        result: Delegate's return value on success (shape is delegate-defined).
        error: Error description on failure, else None.
        agent_id: Id of the sub-agent that handled the task, if known.
    """

    sub_task_id: str
    success: bool
    result: Any = None
    error: str | None = None
    agent_id: str | None = None
||||
|
||||
|
||||
@dataclass
class DelegationConfig:
    """Configuration for delegation behavior.

    Attributes:
        max_parallel: Maximum number of sub-tasks run concurrently.
        timeout_seconds: NOTE(review): not currently read by
            delegate_sub_tasks — confirm intended enforcement point.
        fail_fast: Stop collecting results after the first failure.
    """

    max_parallel: int = 4
    timeout_seconds: float | None = None
    fail_fast: bool = False  # Stop on first failure
||||
|
||||
|
||||
def delegate_sub_tasks(
    sub_tasks: list[SubTask],
    delegate_fn: Callable[[SubTask], SubTaskResult],
    config: DelegationConfig | None = None,
) -> list[SubTaskResult]:
    """
    Fan-out: delegate sub-tasks to sub-agents in parallel.

    Args:
        sub_tasks: List of sub-tasks to delegate.
        delegate_fn: (SubTask) -> SubTaskResult. Typically calls orchestrator
            to submit task and route to sub-agent, then wait for completion.
        config: Delegation behavior.

    Returns:
        List of SubTaskResult in same order as sub_tasks (only completed
        ones when fail_fast stops the batch early).
    """
    cfg = config or DelegationConfig()
    results: list[SubTaskResult | None] = [None] * len(sub_tasks)
    index_map = {st.sub_task_id: i for i, st in enumerate(sub_tasks)}

    def run_one(st: SubTask) -> tuple[int, SubTaskResult]:
        # Convert delegate exceptions into failed results so one bad
        # sub-task cannot abort the whole batch (SubTaskResult.error exists
        # for exactly this purpose).
        try:
            r = delegate_fn(st)
        except Exception as e:
            logger.exception("Sub-task delegation raised", extra={"sub_task_id": st.sub_task_id})
            r = SubTaskResult(sub_task_id=st.sub_task_id, success=False, error=str(e))
        return index_map[st.sub_task_id], r

    with ThreadPoolExecutor(max_workers=cfg.max_parallel) as executor:
        future_to_task = {executor.submit(run_one, st): st for st in sub_tasks}
        for future in as_completed(future_to_task):
            idx, result = future.result()
            results[idx] = result
            if cfg.fail_fast and not result.success:
                logger.warning("Delegation fail_fast on failure", extra={"sub_task_id": result.sub_task_id})
                # Cancel not-yet-started sub-tasks so fail_fast actually
                # stops new work; in-flight tasks finish but are discarded.
                for f in future_to_task:
                    f.cancel()
                break

    return [r for r in results if r is not None]
|
||||
|
||||
|
||||
def aggregate_sub_task_results(
    results: list[SubTaskResult],
    aggregator: Callable[[list[SubTaskResult]], Any],
) -> Any:
    """
    Fan-in: aggregate sub-task results into a single outcome.

    Args:
        results: Results from delegate_sub_tasks.
        aggregator: (results) -> aggregated value.

    Returns:
        Aggregated result.
    """
    aggregated = aggregator(results)
    return aggregated
|
||||
144
fusionagi/multi_agent/parallel.py
Normal file
144
fusionagi/multi_agent/parallel.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Parallel step execution: run independent plan steps concurrently.
|
||||
|
||||
Multi-agent acceleration: steps with satisfied dependencies and no mutual
|
||||
dependencies are dispatched in parallel to maximize throughput.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from fusionagi.schemas.plan import Plan
|
||||
from fusionagi.planning import ready_steps, get_step
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class ParallelStepResult:
    """Result of a single step execution in parallel batch.

    Attributes:
        step_id: Id of the plan step that was executed.
        success: True for a "step_done" envelope or a raw (non-envelope)
            response; False on failure, no response, or exception.
        result: Step result payload on success, if any.
        error: Error description on failure, else None.
        envelope: Response envelope when the executor returned one.
    """

    step_id: str
    success: bool
    result: Any = None
    error: str | None = None
    envelope: Any = None  # AgentMessageEnvelope from executor
|
||||
|
||||
|
||||
class ExecuteStepsCallback(Protocol):
    """Protocol for executing a single step (e.g. via orchestrator).

    Structural type matched by call signature, not inheritance. (The
    previous ``@dataclass`` decorator was removed: a field-less Protocol
    gains nothing from dataclass-generated ``__init__``/``__eq__`` and
    protocols are not meant to be instantiated.)
    """

    def __call__(
        self,
        task_id: str,
        step_id: str,
        plan: Plan,
        sender: str = "supervisor",
    ) -> Any:
        """Execute one step; return response envelope or result."""
        ...
|
||||
|
||||
|
||||
def execute_steps_parallel(
    execute_fn: Callable[[str, str, Plan, str], Any],
    task_id: str,
    plan: Plan,
    completed_step_ids: set[str],
    sender: str = "supervisor",
    max_workers: int | None = None,
) -> list[ParallelStepResult]:
    """
    Execute all ready steps in parallel.

    Args:
        execute_fn: Function (task_id, step_id, plan, sender) -> response.
        task_id: Task identifier.
        plan: The plan containing steps.
        completed_step_ids: Steps already completed.
        sender: Sender identity for execute messages.
        max_workers: Max parallel workers (default: unbounded for ready steps).

    Returns:
        List of ParallelStepResult, one per step attempted.
    """
    ready = ready_steps(plan, completed_step_ids)
    if not ready:
        return []

    workers = len(ready) if not max_workers or max_workers <= 0 else max_workers

    def interpret(step_id: str, response: Any) -> ParallelStepResult:
        """Translate an executor response into a ParallelStepResult."""
        if response is None:
            return ParallelStepResult(step_id=step_id, success=False, error="No response")
        if not hasattr(response, "message"):
            # Raw (non-envelope) responses count as success.
            return ParallelStepResult(step_id=step_id, success=True, result=response)
        # Envelope responses carry intent step_done/step_failed.
        msg = response.message
        if msg.intent == "step_done":
            payload = msg.payload or {}
            return ParallelStepResult(
                step_id=step_id,
                success=True,
                result=payload.get("result"),
                envelope=response,
            )
        failure_text = msg.payload.get("error", "Unknown failure") if msg.payload else "Unknown"
        return ParallelStepResult(
            step_id=step_id,
            success=False,
            error=failure_text,
            envelope=response,
        )

    def run_one(step_id: str) -> ParallelStepResult:
        try:
            return interpret(step_id, execute_fn(task_id, step_id, plan, sender))
        except Exception as e:
            logger.exception("Parallel step execution failed", extra={"step_id": step_id})
            return ParallelStepResult(step_id=step_id, success=False, error=str(e))

    batch_results: list[ParallelStepResult] = []
    with ThreadPoolExecutor(max_workers=workers) as executor:
        pending = {executor.submit(run_one, sid): sid for sid in ready}
        for future in as_completed(pending):
            batch_results.append(future.result())

    logger.info(
        "Parallel step batch completed",
        extra={"task_id": task_id, "steps": ready, "results": len(batch_results)},
    )
    return batch_results
|
||||
|
||||
|
||||
def execute_steps_parallel_wave(
    execute_fn: Callable[[str, str, Plan, str], Any],
    task_id: str,
    plan: Plan,
    sender: str = "supervisor",
    max_workers: int | None = None,
) -> list[ParallelStepResult]:
    """
    Execute plan in waves: each wave runs all ready steps in parallel,
    then advances to the next wave when deps are satisfied.

    Returns combined results from all waves.
    """
    completed: set[str] = set()
    collected: list[ParallelStepResult] = []

    batch = execute_steps_parallel(execute_fn, task_id, plan, completed, sender, max_workers)
    while batch:
        for outcome in batch:
            collected.append(outcome)
            if not outcome.success:
                # On failure, stop the wave (caller can retry or handle)
                logger.warning("Step failed in wave, stopping", extra={"step_id": outcome.step_id})
                return collected
            completed.add(outcome.step_id)
        batch = execute_steps_parallel(execute_fn, task_id, plan, completed, sender, max_workers)

    return collected
|
||||
190
fusionagi/multi_agent/pool.py
Normal file
190
fusionagi/multi_agent/pool.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Agent pool: load-balanced routing for horizontal scaling.
|
||||
|
||||
Multiple executor (or other) agents behind a single logical endpoint.
|
||||
Supports round-robin, least-busy, and random selection strategies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class PooledAgent:
    """An agent in the pool with load tracking.

    Attributes:
        agent_id: Unique id within the pool.
        agent: The wrapped agent; dispatch calls its handle_message if present.
        in_flight: Number of dispatches currently being handled.
        total_dispatched: Lifetime dispatch count (for stats/monitoring).
        last_used: Monotonic timestamp of last dispatch; tie-breaker for
            least_busy selection.
    """

    agent_id: str
    agent: Any  # AgentProtocol
    in_flight: int = 0
    total_dispatched: int = 0
    last_used: float = field(default_factory=time.monotonic)
|
||||
|
||||
class AgentPool:
    """
    Pool of agents for load-balanced dispatch.

    Strategies:
    - round_robin: Rotate through agents in order.
    - least_busy: Prefer agent with lowest in_flight count.
    - random: Random selection (useful for load spreading).
    """

    def __init__(
        self,
        strategy: str = "least_busy",
    ) -> None:
        self._strategy = strategy
        self._agents: list[PooledAgent] = []
        self._round_robin_index = 0
        self._lock = threading.Lock()

    def add(self, agent_id: str, agent: Any) -> None:
        """Add an agent to the pool (no-op if the id already exists)."""
        with self._lock:
            if any(p.agent_id == agent_id for p in self._agents):
                return
            self._agents.append(PooledAgent(agent_id=agent_id, agent=agent))
            logger.info("Agent added to pool", extra={"agent_id": agent_id, "pool_size": len(self._agents)})

    def remove(self, agent_id: str) -> bool:
        """Remove an agent from the pool; return True if it was present."""
        with self._lock:
            for i, p in enumerate(self._agents):
                if p.agent_id == agent_id:
                    self._agents.pop(i)
                    return True
            return False

    def _select(self) -> PooledAgent | None:
        """Select an agent per strategy and atomically claim it.

        BUG FIX: the load counters are now updated here, under the same
        lock as selection. Previously dispatch() incremented in_flight
        after _select() released the lock, so concurrent dispatches could
        all observe the same "least busy" agent and stampede it.
        """
        with self._lock:
            if not self._agents:
                return None

            if self._strategy == "round_robin":
                idx = self._round_robin_index % len(self._agents)
                self._round_robin_index += 1
                pooled = self._agents[idx]
            elif self._strategy == "random":
                pooled = random.choice(self._agents)
            else:
                # least_busy (default); last_used breaks ties
                pooled = min(self._agents, key=lambda p: (p.in_flight, p.last_used))

            pooled.in_flight += 1
            pooled.total_dispatched += 1
            pooled.last_used = time.monotonic()
            return pooled

    def dispatch(
        self,
        envelope: AgentMessageEnvelope,
        on_complete: Callable[[str], None] | None = None,
        rewrite_recipient: bool = True,
    ) -> Any:
        """
        Dispatch to a pooled agent and return response.

        Tracks in-flight for least_busy; calls on_complete(agent_id) when done
        if provided (for async cleanup).

        If rewrite_recipient, the envelope's recipient is set to the selected
        agent's id so the agent receives it correctly.
        """
        pooled = self._select()
        if not pooled:
            logger.error("Agent pool empty, cannot dispatch")
            return None

        # Rewrite recipient so pooled agent receives correctly
        if rewrite_recipient:
            msg = envelope.message
            envelope = AgentMessageEnvelope(
                message=AgentMessage(
                    sender=msg.sender,
                    recipient=pooled.agent_id,
                    intent=msg.intent,
                    payload=msg.payload,
                    confidence=msg.confidence,
                    uncertainty=msg.uncertainty,
                    timestamp=msg.timestamp,
                ),
                task_id=envelope.task_id,
                correlation_id=envelope.correlation_id,
            )

        try:
            agent = pooled.agent
            if hasattr(agent, "handle_message"):
                # Ensure response recipient points back to original sender
                return agent.handle_message(envelope)
            return None
        finally:
            with self._lock:
                pooled.in_flight = max(0, pooled.in_flight - 1)
            if on_complete:
                on_complete(pooled.agent_id)

    def size(self) -> int:
        """Return pool size (read under lock for a consistent snapshot)."""
        with self._lock:
            return len(self._agents)

    def stats(self) -> dict[str, Any]:
        """Return pool statistics for monitoring."""
        with self._lock:
            return {
                "strategy": self._strategy,
                "size": len(self._agents),
                "agents": [
                    {
                        "id": p.agent_id,
                        "in_flight": p.in_flight,
                        "total_dispatched": p.total_dispatched,
                    }
                    for p in self._agents
                ],
            }
||||
|
||||
|
||||
class PooledExecutorRouter:
    """
    Routes execute_step messages to a pool of executors.

    Wraps multiple ExecutorAgent instances; orchestrator or supervisor
    sends to this router's identity, and it load-balances to the pool.
    """

    def __init__(
        self,
        identity: str = "executor_pool",
        pool: AgentPool | None = None,
    ) -> None:
        self.identity = identity
        self._pool = AgentPool(strategy="least_busy") if pool is None else pool

    def add_executor(self, executor_id: str, executor: Any) -> None:
        """Add an executor to the pool."""
        self._pool.add(executor_id, executor)

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Route execute_step to pooled executor; return response."""
        if envelope.message.intent != "execute_step":
            return None
        # Dispatch rewrites the recipient so the response comes back to
        # the original sender.
        return self._pool.dispatch(envelope)

    def stats(self) -> dict[str, Any]:
        """Pool statistics."""
        return self._pool.stats()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user