Initial commit: add .gitignore and README
This commit is contained in:
57
fusionagi/core/__init__.py
Normal file
57
fusionagi/core/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Core orchestration: event bus, state manager, orchestrator, goal manager, scheduler, blockers, persistence."""
|
||||
|
||||
from fusionagi.core.event_bus import EventBus
|
||||
from fusionagi.core.state_manager import StateManager
|
||||
from fusionagi.core.orchestrator import (
|
||||
Orchestrator,
|
||||
InvalidStateTransitionError,
|
||||
VALID_STATE_TRANSITIONS,
|
||||
AgentProtocol,
|
||||
)
|
||||
from fusionagi.core.persistence import StateBackend
|
||||
from fusionagi.core.json_file_backend import JsonFileBackend
|
||||
from fusionagi.core.goal_manager import GoalManager
|
||||
from fusionagi.core.scheduler import Scheduler, SchedulerMode, FallbackMode
|
||||
from fusionagi.core.blockers import BlockersAndCheckpoints
|
||||
from fusionagi.core.head_orchestrator import (
|
||||
run_heads_parallel,
|
||||
run_witness,
|
||||
run_dvadasa,
|
||||
run_second_pass,
|
||||
select_heads_for_complexity,
|
||||
extract_sources_from_head_outputs,
|
||||
MVP_HEADS,
|
||||
ALL_CONTENT_HEADS,
|
||||
)
|
||||
from fusionagi.core.super_big_brain import (
|
||||
run_super_big_brain,
|
||||
SuperBigBrainConfig,
|
||||
SuperBigBrainReasoningProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EventBus",
|
||||
"StateManager",
|
||||
"Orchestrator",
|
||||
"StateBackend",
|
||||
"JsonFileBackend",
|
||||
"InvalidStateTransitionError",
|
||||
"VALID_STATE_TRANSITIONS",
|
||||
"AgentProtocol",
|
||||
"GoalManager",
|
||||
"Scheduler",
|
||||
"SchedulerMode",
|
||||
"FallbackMode",
|
||||
"BlockersAndCheckpoints",
|
||||
"run_heads_parallel",
|
||||
"run_witness",
|
||||
"run_dvadasa",
|
||||
"run_second_pass",
|
||||
"select_heads_for_complexity",
|
||||
"extract_sources_from_head_outputs",
|
||||
"MVP_HEADS",
|
||||
"ALL_CONTENT_HEADS",
|
||||
"run_super_big_brain",
|
||||
"SuperBigBrainConfig",
|
||||
"SuperBigBrainReasoningProvider",
|
||||
]
|
||||
35
fusionagi/core/blockers.py
Normal file
35
fusionagi/core/blockers.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Blockers and checkpoints for AGI state machine."""
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
from fusionagi.schemas.goal import Blocker, Checkpoint
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class BlockersAndCheckpoints:
    """Tracks blockers (why stuck) and checkpoints (resumable points)."""

    def __init__(self) -> None:
        # Both maps are keyed by task_id and keep records in insertion order.
        self._blockers: dict[str, list[Blocker]] = {}
        self._checkpoints: dict[str, list[Checkpoint]] = {}

    def add_blocker(self, blocker: Blocker) -> None:
        """Record a blocker under its task id."""
        bucket = self._blockers.setdefault(blocker.task_id, [])
        bucket.append(blocker)
        logger.info("Blocker added", extra={"task_id": blocker.task_id, "reason": blocker.reason[:80] if blocker.reason else ""})

    def get_blockers(self, task_id: str) -> list[Blocker]:
        """Return a copy of the blockers recorded for *task_id* (empty when none)."""
        return self._blockers.get(task_id, []).copy()

    def clear_blockers(self, task_id: str) -> None:
        """Forget every blocker recorded for *task_id*; no-op when none exist."""
        if task_id in self._blockers:
            del self._blockers[task_id]

    def add_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Record a resumable checkpoint under its task id."""
        bucket = self._checkpoints.setdefault(checkpoint.task_id, [])
        bucket.append(checkpoint)
        logger.debug("Checkpoint added", extra={"task_id": checkpoint.task_id})

    def get_latest_checkpoint(self, task_id: str) -> Checkpoint | None:
        """Return the most recently added checkpoint for *task_id*, or None."""
        history = self._checkpoints.get(task_id)
        if not history:
            return None
        return history[-1]

    def list_checkpoints(self, task_id: str) -> list[Checkpoint]:
        """Return a copy of every checkpoint recorded for *task_id*."""
        return self._checkpoints.get(task_id, []).copy()
|
||||
77
fusionagi/core/event_bus.py
Normal file
77
fusionagi/core/event_bus.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""In-process pub/sub event bus for task lifecycle and agent messages."""
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi._time import utc_now_iso
|
||||
|
||||
# Type for event handlers: (event_type, payload) -> None
|
||||
EventHandler = Callable[[str, dict[str, Any]], None]
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""Simple in-process event bus: event type -> list of handlers; optional event history."""
|
||||
|
||||
def __init__(self, history_size: int = 0) -> None:
|
||||
"""
|
||||
Initialize event bus.
|
||||
|
||||
Args:
|
||||
history_size: If > 0, keep the last N events for get_recent_events().
|
||||
"""
|
||||
self._handlers: dict[str, list[EventHandler]] = defaultdict(list)
|
||||
self._history_size = max(0, history_size)
|
||||
self._history: deque[dict[str, Any]] = deque(maxlen=self._history_size) if self._history_size else deque()
|
||||
|
||||
def subscribe(self, event_type: str, handler: EventHandler) -> None:
|
||||
"""Register a handler for an event type."""
|
||||
self._handlers[event_type].append(handler)
|
||||
|
||||
def unsubscribe(self, event_type: str, handler: EventHandler) -> None:
|
||||
"""Remove a handler for an event type."""
|
||||
if event_type in self._handlers:
|
||||
try:
|
||||
self._handlers[event_type].remove(handler)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def publish(self, event_type: str, payload: dict[str, Any] | None = None) -> None:
|
||||
"""Publish an event; all registered handlers are invoked."""
|
||||
payload = payload or {}
|
||||
if self._history_size > 0:
|
||||
self._history.append({
|
||||
"event_type": event_type,
|
||||
"payload": dict(payload),
|
||||
"timestamp": utc_now_iso(),
|
||||
})
|
||||
task_id = payload.get("task_id", "")
|
||||
logger.debug(
|
||||
"Event published",
|
||||
extra={"event_type": event_type, "task_id": task_id},
|
||||
)
|
||||
for h in self._handlers[event_type][:]:
|
||||
try:
|
||||
h(event_type, payload)
|
||||
except Exception:
|
||||
# Log and continue so one handler failure doesn't block others
|
||||
logger.exception(
|
||||
"Event handler failed",
|
||||
extra={"event_type": event_type},
|
||||
)
|
||||
|
||||
def get_recent_events(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||
"""Return the most recent events (oldest first in slice). Only available if history_size > 0."""
|
||||
if self._history_size == 0:
|
||||
return []
|
||||
events = list(self._history)
|
||||
return events[-limit:] if limit else events
|
||||
|
||||
def clear(self, event_type: str | None = None) -> None:
|
||||
"""Clear handlers for one event type or all; clear history when clearing all."""
|
||||
if event_type is None:
|
||||
self._handlers.clear()
|
||||
if self._history:
|
||||
self._history.clear()
|
||||
elif event_type in self._handlers:
|
||||
del self._handlers[event_type]
|
||||
82
fusionagi/core/goal_manager.py
Normal file
82
fusionagi/core/goal_manager.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Goal manager: objectives, priorities, constraints, time/compute budget for AGI."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.goal import Goal, GoalBudget, GoalStatus
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class GoalManager:
    """
    Manages goals with budgets. Tracks time/compute and can signal
    when a goal is over budget (abort or degrade).
    """

    def __init__(self) -> None:
        self._goals: dict[str, Goal] = {}
        # goal_id -> {"time_used": seconds, "compute_used": units}
        self._budget_used: dict[str, dict[str, float]] = {}

    def _usage(self, goal_id: str) -> dict[str, float]:
        """Return the usage counters for a goal, creating a zeroed entry if needed."""
        return self._budget_used.setdefault(goal_id, {"time_used": 0.0, "compute_used": 0.0})

    def add_goal(self, goal: Goal) -> None:
        """Register a goal and reset its usage counters."""
        self._goals[goal.goal_id] = goal
        self._budget_used[goal.goal_id] = {"time_used": 0.0, "compute_used": 0.0}
        logger.info("Goal added", extra={"goal_id": goal.goal_id, "objective": goal.objective[:80]})

    def get_goal(self, goal_id: str) -> Goal | None:
        """Return goal by id or None."""
        return self._goals.get(goal_id)

    def set_status(self, goal_id: str, status: GoalStatus) -> None:
        """Update goal status; unknown goal ids are ignored."""
        g = self._goals.get(goal_id)
        if g:
            self._goals[goal_id] = g.model_copy(update={"status": status})
            logger.debug("Goal status set", extra={"goal_id": goal_id, "status": status.value})

    def record_time(self, goal_id: str, seconds: float) -> None:
        """Record elapsed time for a goal; check budget."""
        self._usage(goal_id)["time_used"] += seconds
        self._check_budget(goal_id)

    def record_compute(self, goal_id: str, units: float) -> None:
        """Record compute units for a goal; check budget."""
        self._usage(goal_id)["compute_used"] += units
        self._check_budget(goal_id)

    def _check_budget(self, goal_id: str) -> None:
        """If over budget, set goal to BLOCKED and log.

        Delegates the comparison to is_over_budget so the check and the
        public query can never drift apart (previously duplicated logic).
        """
        if self.is_over_budget(goal_id):
            self.set_status(goal_id, GoalStatus.BLOCKED)
            logger.warning("Goal over budget", extra={"goal_id": goal_id, "used": self._budget_used.get(goal_id, {})})

    def is_over_budget(self, goal_id: str) -> bool:
        """Return True if goal has exceeded its time or compute budget.

        Goals without a budget (or unknown goal ids) are never over budget.
        """
        g = self._goals.get(goal_id)
        if not g or not g.budget:
            return False
        used = self._budget_used.get(goal_id, {})
        if g.budget.time_seconds is not None and used.get("time_used", 0) >= g.budget.time_seconds:
            return True
        if g.budget.compute_budget is not None and used.get("compute_used", 0) >= g.budget.compute_budget:
            return True
        return False

    def list_goals(self, status: GoalStatus | None = None) -> list[Goal]:
        """Return goals, optionally filtered by status."""
        goals = list(self._goals.values())
        if status is not None:
            goals = [g for g in goals if g.status == status]
        return goals
|
||||
339
fusionagi/core/head_orchestrator.py
Normal file
339
fusionagi/core/head_orchestrator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Dvādaśa head orchestrator: parallel head dispatch, Witness coordination, second-pass."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.orchestrator import Orchestrator
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput
|
||||
from fusionagi.schemas.witness import FinalResponse
|
||||
from fusionagi.schemas.commands import ParsedCommand, UserIntent
|
||||
from fusionagi._logger import logger
|
||||
|
||||
# MVP: 5 heads. Full: 11.
|
||||
MVP_HEADS: list[HeadId] = [
|
||||
HeadId.LOGIC,
|
||||
HeadId.RESEARCH,
|
||||
HeadId.STRATEGY,
|
||||
HeadId.SECURITY,
|
||||
HeadId.SAFETY,
|
||||
]
|
||||
|
||||
ALL_CONTENT_HEADS: list[HeadId] = [h for h in HeadId if h != HeadId.WITNESS]
|
||||
|
||||
# Heads for second-pass when risk/conflict/security
|
||||
SECOND_PASS_HEADS: list[HeadId] = [HeadId.SECURITY, HeadId.SAFETY, HeadId.LOGIC]
|
||||
|
||||
# Thresholds for automatic second-pass
|
||||
SECOND_PASS_CONFIG: dict[str, Any] = {
|
||||
"min_confidence": 0.5,
|
||||
"max_disputed": 3,
|
||||
"security_keywords": ("security", "risk", "threat", "vulnerability"),
|
||||
}
|
||||
|
||||
|
||||
def run_heads_parallel(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    head_ids: list[HeadId] | None = None,
    sender: str = "head_orchestrator",
    timeout_per_head: float = 60.0,
    min_heads_ratio: float | None = 0.6,
) -> list[HeadOutput]:
    """
    Dispatch head_request to multiple heads in parallel; collect HeadOutput.

    Args:
        orchestrator: Orchestrator with registered head agents.
        task_id: Task identifier.
        user_prompt: User's prompt/question.
        head_ids: Heads to run (default: MVP_HEADS). The Witness is always excluded.
        sender: Sender identity for messages.
        timeout_per_head: Max seconds per head.
        min_heads_ratio: Return early once this fraction of heads respond (0.6 = 60%).
            None = wait for all heads. Reduces latency when some heads are slow.

    Returns:
        List of HeadOutput (may be partial on timeout/failure).
    """
    heads = [h for h in (head_ids or MVP_HEADS) if h != HeadId.WITNESS]
    if not heads:
        return []

    envelopes = [
        AgentMessageEnvelope(
            message=AgentMessage(
                sender=sender,
                recipient=hid.value,
                intent="head_request",
                payload={"prompt": user_prompt},
            ),
            task_id=task_id,
        )
        for hid in heads
    ]

    results: list[HeadOutput] = []
    min_required = (
        max(1, math.ceil(len(heads) * min_heads_ratio))
        if min_heads_ratio is not None
        else len(heads)
    )

    def run_one(env: AgentMessageEnvelope) -> HeadOutput | None:
        """Route one request and parse the head's reply; None on any failure."""
        resp = orchestrator.route_message_return(env)
        if resp is None or resp.message.intent != "head_output":
            return None
        payload = resp.message.payload or {}
        ho = payload.get("head_output")
        if not isinstance(ho, dict):
            return None
        try:
            return HeadOutput.model_validate(ho)
        except Exception as e:
            logger.warning("HeadOutput parse failed", extra={"error": str(e)})
            return None

    executor = ThreadPoolExecutor(max_workers=len(heads))
    try:
        future_to_env = {executor.submit(run_one, env): env for env in envelopes}
        try:
            for future in as_completed(future_to_env, timeout=timeout_per_head * len(heads) + 5):
                try:
                    out = future.result(timeout=1)
                except FuturesTimeoutError:
                    env = future_to_env[future]
                    logger.warning("Head timeout", extra={"recipient": env.message.recipient})
                    continue
                except Exception as e:
                    logger.exception("Head execution failed", extra={"error": str(e)})
                    continue
                if out is not None:
                    results.append(out)
                    if len(results) >= min_required:
                        logger.info(
                            "Early exit: sufficient heads responded",
                            extra={"responded": len(results), "required": min_required},
                        )
                        break
        except FuturesTimeoutError:
            # BUG FIX: as_completed(..., timeout=...) raises TimeoutError at the
            # for-statement itself; previously that escaped this function and
            # discarded the partial results the docstring promises to return.
            logger.warning(
                "Head collection timed out; returning partial results",
                extra={"responded": len(results), "expected": len(heads)},
            )
    finally:
        # BUG FIX: the previous `with ThreadPoolExecutor(...)` block waited for
        # every outstanding head on exit, defeating the early-exit latency win.
        # Shut down without waiting and cancel anything not yet started; heads
        # already running finish in the background and their results are dropped.
        executor.shutdown(wait=False, cancel_futures=True)

    return results
|
||||
|
||||
|
||||
def run_witness(
    orchestrator: Orchestrator,
    task_id: str,
    head_outputs: list[HeadOutput],
    user_prompt: str,
    sender: str = "head_orchestrator",
) -> FinalResponse | None:
    """Send collected head outputs to the Witness and parse its FinalResponse.

    Returns None when the Witness is unreachable, replies with an unexpected
    intent, or its payload fails validation.
    """
    request = AgentMessageEnvelope(
        message=AgentMessage(
            sender=sender,
            recipient=HeadId.WITNESS.value,
            intent="witness_request",
            payload={
                "head_outputs": [h.model_dump() for h in head_outputs],
                "prompt": user_prompt,
            },
        ),
        task_id=task_id,
    )
    reply = orchestrator.route_message_return(request)
    if reply is None:
        return None
    if reply.message.intent != "witness_output":
        return None
    raw = (reply.message.payload or {}).get("final_response")
    if not isinstance(raw, dict):
        return None
    try:
        return FinalResponse.model_validate(raw)
    except Exception as e:
        logger.warning("FinalResponse parse failed", extra={"error": str(e)})
        return None
|
||||
|
||||
|
||||
def run_second_pass(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    initial_outputs: list[HeadOutput],
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
) -> list[HeadOutput]:
    """Re-run the scrutiny heads (Security, Safety, Logic by default) and merge.

    Outputs from the re-run heads replace their first-pass counterparts; every
    other first-pass output is kept unchanged, preserving original order.
    """
    rerun_heads = [h for h in (head_ids or SECOND_PASS_HEADS) if h != HeadId.WITNESS]
    if not rerun_heads:
        return initial_outputs

    fresh_outputs = run_heads_parallel(
        orchestrator,
        task_id,
        user_prompt,
        head_ids=rerun_heads,
        timeout_per_head=timeout_per_head,
    )

    merged: dict[HeadId, HeadOutput] = {out.head_id: out for out in initial_outputs}
    merged.update((out.head_id, out) for out in fresh_outputs)
    return list(merged.values())
|
||||
|
||||
|
||||
def _should_run_second_pass(
    final: FinalResponse,
    force: bool = False,
    second_pass_config: dict[str, Any] | None = None,
) -> bool:
    """Decide whether the second-pass heads should run.

    Triggers on: explicit force, low Witness confidence, too many disputed
    claims, or security-related keywords in the safety report.
    """
    if force:
        return True
    cfg = dict(SECOND_PASS_CONFIG)
    if second_pass_config:
        cfg.update(second_pass_config)
    agreement = final.transparency_report.agreement_map
    if agreement.confidence_score < cfg.get("min_confidence", 0.5):
        return True
    if len(agreement.disputed_claims) > cfg.get("max_disputed", 3):
        return True
    safety_text = (final.transparency_report.safety_report or "").lower()
    return any(kw in safety_text for kw in cfg.get("security_keywords", ()))
|
||||
|
||||
|
||||
def run_dvadasa(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    parsed: ParsedCommand | None = None,
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
    event_bus: Any | None = None,
    force_second_pass: bool = False,
    return_head_outputs: bool = False,
    second_pass_config: dict[str, Any] | None = None,
    min_heads_ratio: float | None = 0.6,
) -> FinalResponse | tuple[FinalResponse, list[HeadOutput]] | tuple[None, list[HeadOutput]] | None:
    """
    Full Dvādaśa flow: run heads in parallel, then Witness, with optional second pass.

    Args:
        orchestrator: Orchestrator with heads and witness registered.
        task_id: Task identifier.
        user_prompt: User's prompt (or use parsed.cleaned_prompt when HEAD_STRATEGY).
        parsed: Optional ParsedCommand from parse_user_input.
        head_ids: Override heads to run (e.g. single head for HEAD_STRATEGY).
        timeout_per_head: Max seconds per head.
        event_bus: Optional EventBus to publish dvadasa_complete.
        force_second_pass: Always run the second pass.
        return_head_outputs: Also return the (possibly merged) head outputs.
        second_pass_config: Override SECOND_PASS_CONFIG (min_confidence, max_disputed, etc).
        min_heads_ratio: Early exit once this fraction of heads respond; None = wait all.

    Returns:
        FinalResponse, or (FinalResponse, head_outputs) when return_head_outputs;
        None / (None, []) on failure.
    """
    prompt = user_prompt
    heads = head_ids

    # HEAD_STRATEGY targets a single head; prefer the command's cleaned prompt.
    if parsed and parsed.intent == UserIntent.HEAD_STRATEGY and parsed.head_id:
        heads = [parsed.head_id]
        if parsed.cleaned_prompt:
            prompt = parsed.cleaned_prompt

    if heads is None:
        heads = select_heads_for_complexity(prompt)

    head_outputs = run_heads_parallel(
        orchestrator,
        task_id,
        prompt,
        head_ids=heads,
        timeout_per_head=timeout_per_head,
        # BUG FIX: min_heads_ratio was accepted and documented but never
        # forwarded, so the early-exit behavior could not be configured here.
        min_heads_ratio=min_heads_ratio,
    )
    if not head_outputs:
        logger.warning("No head outputs; cannot run Witness")
        return (None, []) if return_head_outputs else None

    final = run_witness(orchestrator, task_id, head_outputs, prompt)

    # _should_run_second_pass already honors force, so the previous
    # `force_second_pass or ...` double-check was redundant.
    if final and _should_run_second_pass(
        final, force=force_second_pass, second_pass_config=second_pass_config
    ):
        head_outputs = run_second_pass(
            orchestrator,
            task_id,
            prompt,
            head_outputs,
            timeout_per_head=timeout_per_head,
        )
        # BUG FIX: keep the first-pass response when the second-pass Witness
        # fails, instead of overwriting a good result with None.
        revised = run_witness(orchestrator, task_id, head_outputs, prompt)
        if revised is not None:
            final = revised
        else:
            logger.warning("Second-pass Witness failed; keeping first-pass response")

    if final and event_bus:
        try:
            event_bus.publish(
                "dvadasa_complete",
                {
                    "task_id": task_id,
                    "final_response": final.model_dump(),
                    "head_count": len(head_outputs),
                },
            )
        except Exception as e:
            logger.warning("Failed to publish dvadasa_complete", extra={"error": str(e)})

    if return_head_outputs:
        return (final, head_outputs)
    return final
|
||||
|
||||
|
||||
def extract_sources_from_head_outputs(head_outputs: list[HeadOutput]) -> list[dict[str, Any]]:
    """Collect unique, citable evidence entries from head outputs (SOURCES command).

    Entries without a source_id are skipped; duplicates are removed per
    (head, source_id) pair while preserving first-seen order.
    """
    collected: list[dict[str, Any]] = []
    seen_keys: set[tuple[str, str]] = set()
    for output in head_outputs:
        head_name = output.head_id.value
        for claim in output.claims:
            for evidence in claim.evidence:
                if not evidence.source_id:
                    continue
                dedupe_key = (head_name, evidence.source_id)
                if dedupe_key in seen_keys:
                    continue
                seen_keys.add(dedupe_key)
                collected.append({
                    "head_id": head_name,
                    "source_id": evidence.source_id,
                    "excerpt": evidence.excerpt or "",
                    "confidence": evidence.confidence,
                })
    return collected
|
||||
|
||||
|
||||
def select_heads_for_complexity(
    prompt: str,
    mvp_heads: list[HeadId] = MVP_HEADS,
    all_heads: list[HeadId] | None = None,
) -> list[HeadId]:
    """
    Dynamic routing: simple prompts use fewer heads.

    Heuristic: long prompt (> 50 words) or any complexity keyword => all heads.

    Returns:
        A new list each call. BUG FIX: previously the module-level constant
        lists (MVP_HEADS / ALL_CONTENT_HEADS) were returned by reference, so a
        caller mutating the result corrupted head selection globally.
    """
    candidate_all = all_heads or ALL_CONTENT_HEADS
    prompt_lower = prompt.lower()
    complex_keywords = (
        "security", "risk", "architecture", "scalability", "compliance",
        "critical", "production", "audit", "privacy", "sensitive",
    )
    if len(prompt.split()) > 50 or any(kw in prompt_lower for kw in complex_keywords):
        return list(candidate_all)
    return list(mvp_heads)
|
||||
69
fusionagi/core/json_file_backend.py
Normal file
69
fusionagi/core/json_file_backend.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""JSON file persistence backend for StateManager."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState
|
||||
from fusionagi.core.persistence import StateBackend
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class JsonFileBackend(StateBackend):
    """
    StateBackend that persists tasks and traces to a JSON file.

    Use with StateManager(backend=JsonFileBackend(path="state.json")).
    File is created on first write; directory must exist or be creatable.
    Saves are atomic (temp file + rename), so a crash mid-write cannot leave
    a truncated state file behind.
    """

    def __init__(self, path: str | Path) -> None:
        """Load any existing state from *path*; a missing file starts empty."""
        self._path = Path(path)
        self._tasks: dict[str, dict[str, Any]] = {}
        self._traces: dict[str, list[dict[str, Any]]] = {}
        self._load()

    def _load(self) -> None:
        """Read the JSON file into memory; unreadable/corrupt files are logged and skipped."""
        if not self._path.exists():
            return
        try:
            data = json.loads(self._path.read_text(encoding="utf-8"))
            self._tasks = data.get("tasks", {})
            self._traces = data.get("traces", {})
        except Exception as e:
            logger.warning("JsonFileBackend load failed", extra={"path": str(self._path), "error": str(e)})

    def _save(self) -> None:
        """Persist the in-memory state atomically."""
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = {"tasks": self._tasks, "traces": self._traces}
        # BUG FIX: writing the target file directly could leave a truncated
        # JSON file if the process died mid-write. Write a sibling temp file
        # and rename it over the target (atomic on POSIX via Path.replace).
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        tmp_path.replace(self._path)

    def get_task(self, task_id: str) -> Task | None:
        """Return the stored Task, or None if absent or no longer parseable."""
        raw = self._tasks.get(task_id)
        if raw is None:
            return None
        try:
            return Task.model_validate(raw)
        except Exception:
            return None

    def set_task(self, task: Task) -> None:
        """Store (or overwrite) a task and flush to disk."""
        self._tasks[task.task_id] = task.model_dump(mode="json")
        self._save()

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return just the state of a task, or None if the task is unknown."""
        task = self.get_task(task_id)
        return task.state if task else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update the state of an existing task; unknown task ids are ignored."""
        task = self.get_task(task_id)
        if task:
            updated = task.model_copy(update={"state": state})
            self.set_task(updated)

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append a trace entry for a task and flush to disk."""
        self._traces.setdefault(task_id, []).append(entry)
        self._save()

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return a copy of the trace list for a task (empty when unknown)."""
        return list(self._traces.get(task_id, []))
|
||||
310
fusionagi/core/orchestrator.py
Normal file
310
fusionagi/core/orchestrator.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Orchestrator: task lifecycle, agent registry, wiring to event bus and state."""
|
||||
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Callable, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState, TaskPriority, VALID_TASK_TRANSITIONS
|
||||
from fusionagi.schemas.messages import AgentMessageEnvelope
|
||||
|
||||
from fusionagi.core.event_bus import EventBus
|
||||
from fusionagi.core.state_manager import StateManager
|
||||
from fusionagi._logger import logger
|
||||
|
||||
# Single source of truth: re-export from schemas for backward compatibility,
# so existing callers can keep importing VALID_STATE_TRANSITIONS from here.
VALID_STATE_TRANSITIONS = VALID_TASK_TRANSITIONS
|
||||
|
||||
|
||||
class InvalidStateTransitionError(Exception):
    """Raised when an invalid state transition is attempted.

    Carries the offending task_id, from_state and to_state as attributes so
    callers can recover the details programmatically.
    """

    def __init__(self, task_id: str, from_state: TaskState, to_state: TaskState) -> None:
        message = (
            f"Invalid state transition for task {task_id}: "
            f"{from_state.value} -> {to_state.value}"
        )
        super().__init__(message)
        self.task_id = task_id
        self.from_state = from_state
        self.to_state = to_state
|
||||
|
||||
|
||||
@runtime_checkable
class AgentProtocol(Protocol):
    """Protocol for agents that can handle messages.

    Any object with an ``identity`` attribute and a ``handle_message`` method
    satisfies this protocol. ``@runtime_checkable`` permits isinstance()
    checks, which verify structure only — not signatures.
    """

    # The agent's own identifier; presumably matched against message
    # sender/recipient fields — confirm against the router before relying on it.
    identity: str

    def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
        """Handle an incoming message and optionally return a response."""
        ...
|
||||
|
||||
|
||||
class TaskGraphEntry(BaseModel):
    """Per-task plan/metadata storage (plan cache).

    One entry is created for each submitted task (see Orchestrator.submit_task);
    the plan is filled in later via Orchestrator.set_task_plan.
    """

    plan: dict[str, Any] | None = Field(default=None, description="Stored plan for the task")
|
||||
|
||||
|
||||
class Orchestrator:
|
||||
"""
|
||||
Global task lifecycle and agent coordination; holds task plans, event bus, state, agent registry.
|
||||
|
||||
Task state lifecycle: submit_task creates PENDING. Callers/supervisors must call set_task_state
|
||||
to transition to ACTIVE, COMPLETED, FAILED, or CANCELLED. The orchestrator validates state
|
||||
transitions according to VALID_STATE_TRANSITIONS.
|
||||
|
||||
Valid transitions:
|
||||
PENDING -> ACTIVE, CANCELLED
|
||||
ACTIVE -> COMPLETED, FAILED, CANCELLED
|
||||
FAILED -> PENDING (retry), CANCELLED
|
||||
COMPLETED -> (terminal)
|
||||
CANCELLED -> (terminal)
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        event_bus: EventBus,
        state_manager: StateManager,
        validate_transitions: bool = True,
    ) -> None:
        """
        Initialize the orchestrator.

        Args:
            event_bus: Event bus for publishing events.
            state_manager: State manager for task state.
            validate_transitions: If True, validate state transitions (default True).
        """
        self._event_bus = event_bus
        self._state = state_manager
        self._validate_transitions = validate_transitions
        self._agents: dict[str, AgentProtocol | Any] = {}  # agent_id -> agent instance
        self._sub_agents: dict[str, list[str]] = {}  # parent_id -> [child_id]
        self._task_plans: dict[str, TaskGraphEntry] = {}  # task_id -> plan/metadata per task
        # Worker pool for async message routing; released via shutdown().
        self._async_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="orch_async")
|
||||
|
||||
    def register_agent(self, agent_id: str, agent: Any) -> None:
        """Register an agent by id for routing and assignment.

        Registering an existing id silently replaces the previous agent.
        """
        self._agents[agent_id] = agent
        logger.info("Agent registered", extra={"agent_id": agent_id})
|
||||
|
||||
def unregister_agent(self, agent_id: str) -> None:
|
||||
"""Remove an agent from the registry and from any parent's sub-agent list."""
|
||||
self._agents.pop(agent_id, None)
|
||||
self._sub_agents.pop(agent_id, None)
|
||||
for parent_id, children in list(self._sub_agents.items()):
|
||||
if agent_id in children:
|
||||
self._sub_agents[parent_id] = [c for c in children if c != agent_id]
|
||||
logger.info("Agent unregistered", extra={"agent_id": agent_id})
|
||||
|
||||
def register_sub_agent(self, parent_id: str, child_id: str, agent: Any) -> None:
|
||||
"""Register a sub-agent under a parent; child can be delegated sub-tasks."""
|
||||
self._agents[child_id] = agent
|
||||
self._sub_agents.setdefault(parent_id, []).append(child_id)
|
||||
logger.info("Sub-agent registered", extra={"parent_id": parent_id, "child_id": child_id})
|
||||
|
||||
def get_sub_agents(self, parent_id: str) -> list[str]:
|
||||
"""Return list of child agent ids for a parent."""
|
||||
return list(self._sub_agents.get(parent_id, []))
|
||||
|
||||
def get_agent(self, agent_id: str) -> Any | None:
|
||||
"""Return registered agent by id or None."""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
    def shutdown(self, wait: bool = True) -> None:
        """Shut down the async executor used for route_message_async. Call when the orchestrator is no longer needed.

        Args:
            wait: If True, block until queued async work finishes before returning.
        """
        self._async_executor.shutdown(wait=wait)
        logger.debug("Orchestrator async executor shut down", extra={"wait": wait})
|
||||
|
||||
def submit_task(
|
||||
self,
|
||||
goal: str,
|
||||
constraints: list[str] | None = None,
|
||||
priority: TaskPriority = TaskPriority.NORMAL,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Create a task and publish task_created; returns task_id."""
|
||||
task_id = str(uuid.uuid4())
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
goal=goal,
|
||||
constraints=constraints or [],
|
||||
priority=priority,
|
||||
state=TaskState.PENDING,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._state.set_task(task)
|
||||
self._task_plans[task_id] = TaskGraphEntry()
|
||||
logger.info(
|
||||
"Task created",
|
||||
extra={"task_id": task_id, "goal": goal[:200] if goal else ""},
|
||||
)
|
||||
self._event_bus.publish(
|
||||
"task_created",
|
||||
{"task_id": task_id, "goal": goal, "constraints": task.constraints},
|
||||
)
|
||||
return task_id
|
||||
|
||||
    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current state of a task or None if unknown.

        Thin delegation to the StateManager, which owns task storage.
        """
        return self._state.get_task_state(task_id)
|
||||
|
||||
    def get_task(self, task_id: str) -> Task | None:
        """Return full task or None.

        Thin delegation to the StateManager, which owns task storage.
        """
        return self._state.get_task(task_id)
|
||||
|
||||
def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None:
|
||||
"""Store plan in task plans for a task."""
|
||||
if task_id in self._task_plans:
|
||||
self._task_plans[task_id].plan = plan
|
||||
|
||||
def get_task_plan(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Return stored plan for a task or None."""
|
||||
entry = self._task_plans.get(task_id)
|
||||
return entry.plan if entry else None
|
||||
|
||||
def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None:
    """
    Transition a task to a new state, enforcing the transition table.

    Args:
        task_id: Identifier of the task to update.
        state: Target state.
        force: Skip transition validation entirely (use with caution).

    Raises:
        InvalidStateTransitionError: If the move is disallowed and force is False.
        ValueError: If task_id is unknown.
    """
    previous = self._state.get_task_state(task_id)
    if previous is None:
        raise ValueError(f"Unknown task: {task_id}")

    # Same-state writes are always allowed; otherwise consult the table
    # unless validation is disabled or explicitly forced off.
    if not force and self._validate_transitions and state != previous:
        permitted = VALID_TASK_TRANSITIONS.get(previous, set())
        if state not in permitted:
            raise InvalidStateTransitionError(task_id, previous, state)

    self._state.set_task_state(task_id, state)
    logger.debug(
        "Task state set",
        extra={
            "task_id": task_id,
            "from_state": previous.value,
            "to_state": state.value,
        },
    )
    self._event_bus.publish(
        "task_state_changed",
        {"task_id": task_id, "from_state": previous.value, "to_state": state.value},
    )
def can_transition(self, task_id: str, state: TaskState) -> bool:
    """Report whether moving task_id to state would be legal, without doing it.

    Unknown tasks always return False; a no-op (same-state) move is True.
    """
    current = self._state.get_task_state(task_id)
    if current is None:
        return False
    if current == state:
        return True
    return state in VALID_TASK_TRANSITIONS.get(current, set())
def route_message(self, envelope: AgentMessageEnvelope) -> None:
    """
    Deliver an envelope to the recipient agent and publish message_received.

    Delegates to route_message_return so the routing/logging/event logic is
    defined in exactly one place; the agent's response envelope (if any) is
    intentionally discarded. Use route_message_return when the caller needs
    to handle or re-route the response.
    """
    # Fix: this body previously duplicated route_message_return line-for-line
    # (minus the return); delegating removes the drift-prone copy.
    self.route_message_return(envelope)
def route_message_return(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
    """
    Deliver an envelope to its recipient agent and hand back the response.

    Publishes message_received whether or not the recipient is registered.
    Returns None when the agent is unknown or lacks a handle_message method;
    otherwise returns whatever the agent's handler produces.
    """
    target = envelope.message.recipient
    purpose = envelope.message.intent
    logger.info(
        "Message routed",
        extra={"task_id": envelope.task_id or "", "recipient": target, "intent": purpose},
    )
    handler = self._agents.get(target)
    self._event_bus.publish(
        "message_received",
        {
            "task_id": envelope.task_id,
            "recipient": target,
            "intent": purpose,
        },
    )
    if handler is None or not hasattr(handler, "handle_message"):
        return None
    return handler.handle_message(envelope)
def route_messages_batch(
    self,
    envelopes: list[AgentMessageEnvelope],
) -> list[AgentMessageEnvelope | None]:
    """
    Route several messages concurrently; responses come back in input order.

    Each envelope is dispatched via route_message_return on a worker thread
    (at most 32 workers).

    Args:
        envelopes: Messages to deliver; may be empty.

    Returns:
        One response envelope (or None) per input, in the input order.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Fix: guard the empty batch — ThreadPoolExecutor raises ValueError for
    # max_workers=0, which min(len([]), 32) produced before.
    if not envelopes:
        return []

    results: list[AgentMessageEnvelope | None] = [None] * len(envelopes)

    def route_one(i: int, env: AgentMessageEnvelope) -> tuple[int, AgentMessageEnvelope | None]:
        # Carry the index so completion order doesn't scramble results.
        return i, self.route_message_return(env)

    with ThreadPoolExecutor(max_workers=min(len(envelopes), 32)) as ex:
        futures = [ex.submit(route_one, i, env) for i, env in enumerate(envelopes)]
        for fut in as_completed(futures):
            idx, resp = fut.result()
            results[idx] = resp

    return results
def route_message_async(
    self,
    envelope: AgentMessageEnvelope,
    callback: Callable[[AgentMessageEnvelope | None], None] | None = None,
) -> Any:
    """
    Dispatch an envelope on the background executor without blocking.

    Args:
        envelope: The message to route.
        callback: Optional function invoked with the response (or None) once
            routing completes; exceptions raised by it are logged, not raised.

    Returns:
        A Future resolving to the response envelope (or None).
    """
    from concurrent import futures

    future = self._async_executor.submit(self.route_message_return, envelope)

    if callback:

        def _notify(f: futures.Future) -> None:
            try:
                callback(f.result())
            except Exception:
                logger.exception("Async route callback failed")

        future.add_done_callback(_notify)
    return future
44
fusionagi/core/persistence.py
Normal file
44
fusionagi/core/persistence.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Optional persistence interface for state manager; in-memory is default."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState
|
||||
|
||||
|
||||
class StateBackend(ABC):
    """
    Pluggable storage contract for tasks and their execution traces.

    An implementation can be injected into StateManager to persist state
    beyond the process lifetime. Besides the task and trace methods, every
    backend must also support direct state reads and writes through
    get_task_state and set_task_state.
    """

    @abstractmethod
    def get_task(self, task_id: str) -> Task | None:
        """Fetch the task with the given id, or None if absent."""
        ...

    @abstractmethod
    def set_task(self, task: Task) -> None:
        """Persist (create or overwrite) the given task."""
        ...

    @abstractmethod
    def get_task_state(self, task_id: str) -> TaskState | None:
        """Read the stored state for a task; None when the task is absent."""
        ...

    @abstractmethod
    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Write a task's state; must not create a task that does not exist."""
        ...

    @abstractmethod
    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Add one entry to the task's execution trace."""
        ...

    @abstractmethod
    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Read the full execution trace recorded for a task."""
        ...
89
fusionagi/core/scheduler.py
Normal file
89
fusionagi/core/scheduler.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Scheduler: think vs act, tool selection, retry logic, fallback modes for AGI."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class SchedulerMode(str, Enum):
    """Next-step decision: keep reasoning (THINK) or execute a step (ACT)."""

    THINK = "think"
    ACT = "act"
class FallbackMode(str, Enum):
    """Recovery strategy applied when the primary execution path fails."""

    RETRY = "retry"
    SIMPLIFY_PLAN = "simplify_plan"
    HUMAN_HANDOFF = "human_handoff"
    ABORT = "abort"
class Scheduler:
    """
    Policy object deciding think-vs-act, per-step retries, and fallbacks.

    Callers (e.g. Supervisor) query next_mode()/should_retry()/next_fallback()
    and report outcomes via record_retry()/reset_fallback().
    """

    def __init__(
        self,
        default_mode: SchedulerMode = SchedulerMode.ACT,
        max_retries_per_step: int = 2,
        fallback_sequence: list[FallbackMode] | None = None,
    ) -> None:
        self._default_mode = default_mode
        self._max_retries = max_retries_per_step
        if fallback_sequence:
            self._fallback_sequence = fallback_sequence
        else:
            # Escalating default: cheap recovery first, abort last.
            self._fallback_sequence = [
                FallbackMode.RETRY,
                FallbackMode.SIMPLIFY_PLAN,
                FallbackMode.HUMAN_HANDOFF,
                FallbackMode.ABORT,
            ]
        # "task:step" key -> number of retries recorded so far
        self._retry_counts: dict[str, int] = {}
        # task_id -> next index to hand out from _fallback_sequence
        self._fallback_index: dict[str, int] = {}

    @staticmethod
    def _step_key(task_id: str, step_id: str) -> str:
        """Build the composite key used for per-step retry bookkeeping."""
        return f"{task_id}:{step_id}"

    def next_mode(
        self,
        task_id: str,
        step_id: str,
        context: dict[str, Any] | None = None,
    ) -> SchedulerMode:
        """
        Choose THINK (reason more) or ACT (execute the step).

        context["force_think"] / context["force_act"] override the default,
        with force_think taking precedence.
        """
        overrides = context or {}
        if overrides.get("force_think"):
            return SchedulerMode.THINK
        if overrides.get("force_act"):
            return SchedulerMode.ACT
        return self._default_mode

    def should_retry(self, task_id: str, step_id: str) -> bool:
        """True while the step's recorded retries are below max_retries."""
        recorded = self._retry_counts.get(self._step_key(task_id, step_id), 0)
        return recorded < self._max_retries

    def record_retry(self, task_id: str, step_id: str) -> None:
        """Bump the retry counter for one step."""
        key = self._step_key(task_id, step_id)
        self._retry_counts[key] = self._retry_counts.get(key, 0) + 1
        logger.debug("Scheduler recorded retry", extra={"task_id": task_id, "step_id": step_id})

    def next_fallback(self, task_id: str) -> FallbackMode | None:
        """Hand out the next fallback mode for a task, or None once exhausted."""
        position = self._fallback_index.get(task_id, 0)
        if position >= len(self._fallback_sequence):
            return None
        chosen = self._fallback_sequence[position]
        self._fallback_index[task_id] = position + 1
        logger.info("Scheduler fallback", extra={"task_id": task_id, "fallback": chosen.value})
        return chosen

    def reset_fallback(self, task_id: str) -> None:
        """Forget the fallback position for a task (e.g. after a success)."""
        self._fallback_index.pop(task_id, None)
111
fusionagi/core/state_manager.py
Normal file
111
fusionagi/core/state_manager.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""In-memory store for task state and execution traces; replaceable with persistent backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from fusionagi.schemas.task import Task, TaskState
|
||||
from fusionagi._logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.persistence import StateBackend
|
||||
|
||||
|
||||
class StateManager:
    """
    Manages task state and execution traces.

    Supports an optional persistent backend via dependency injection: every
    write is mirrored to the backend and reads fall back to it on a memory
    miss. The in-memory dictionaries always act as a cache for fast access.
    """

    def __init__(self, backend: StateBackend | None = None) -> None:
        """
        Initialize StateManager with optional persistence backend.

        Args:
            backend: Optional StateBackend for persistence. If None, uses
                in-memory storage only.
        """
        self._backend = backend
        self._tasks: dict[str, Task] = {}
        self._traces: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def get_task(self, task_id: str) -> Task | None:
        """Return the task by id or None. Checks memory first, then backend."""
        if task_id in self._tasks:
            return self._tasks[task_id]
        if self._backend:
            task = self._backend.get_task(task_id)
            if task:
                # Warm the cache so subsequent reads skip the backend.
                self._tasks[task_id] = task
                return task
        return None

    def set_task(self, task: Task) -> None:
        """Store or update a task in memory and (if configured) the backend."""
        self._tasks[task.task_id] = task
        if self._backend:
            self._backend.set_task(task)
        logger.debug("Task set", extra={"task_id": task.task_id})

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current task state or None if task unknown."""
        task = self.get_task(task_id)
        return task.state if task else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update task state; creates no task if missing (silently no-ops)."""
        if task_id in self._tasks:
            self._tasks[task_id].state = state
            if self._backend:
                self._backend.set_task_state(task_id, state)
            logger.debug("Task state set", extra={"task_id": task_id, "state": state.value})
        elif self._backend:
            # Task might be in backend but not in memory: load, update, cache.
            task = self._backend.get_task(task_id)
            if task:
                task.state = state
                self._tasks[task_id] = task
                self._backend.set_task_state(task_id, state)
                logger.debug("Task state set", extra={"task_id": task_id, "state": state.value})

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append an entry to the execution trace for a task."""
        self._traces[task_id].append(entry)
        if self._backend:
            self._backend.append_trace(task_id, entry)
        tool = entry.get("tool") or entry.get("step") or "entry"
        logger.debug("Trace appended", extra={"task_id": task_id, "entry_key": tool})

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return the execution trace for a task (copy). Checks backend if not in memory."""
        if task_id in self._traces and self._traces[task_id]:
            return list(self._traces[task_id])
        if self._backend:
            trace = self._backend.get_trace(task_id)
            if trace:
                self._traces[task_id] = list(trace)
                # Fix: return a copy here too — previously the backend's own
                # list was handed to the caller, breaking the documented
                # "copy" contract and letting callers mutate backend state.
                return list(trace)
        return list(self._traces.get(task_id, []))

    def clear_task(self, task_id: str) -> None:
        """Remove task and its trace (for tests or cleanup). Does not clear backend."""
        self._tasks.pop(task_id, None)
        self._traces.pop(task_id, None)

    def list_tasks(self, state: TaskState | None = None) -> list[Task]:
        """Return all tasks, optionally filtered by state.

        When a persistence backend is configured, only tasks currently loaded
        in memory are returned; tasks that exist only in the backend are not
        included.
        """
        tasks = list(self._tasks.values())
        if state is not None:
            tasks = [t for t in tasks if t.state == state]
        return tasks

    def task_count(self) -> int:
        """Return total number of tasks in memory."""
        return len(self._tasks)
136
fusionagi/core/super_big_brain.py
Normal file
136
fusionagi/core/super_big_brain.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Super Big Brain orchestrator: tokenless, recursive, graph-backed reasoning."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit, DecompositionResult
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim, HeadRisk
|
||||
from fusionagi.schemas.grounding import Citation
|
||||
from fusionagi.reasoning.decomposition import decompose_recursive
|
||||
from fusionagi.reasoning.context_loader import load_context_for_reasoning, build_compact_prompt
|
||||
from fusionagi.reasoning.tot import ThoughtNode, expand_node, prune_subtree, merge_subtrees
|
||||
from fusionagi.reasoning.multi_path import generate_and_score_parallel
|
||||
from fusionagi.reasoning.recomposition import recompose, RecomposedResponse
|
||||
from fusionagi.reasoning.meta_reasoning import challenge_assumptions, detect_contradictions
|
||||
from fusionagi.memory.semantic_graph import SemanticGraphMemory
|
||||
from fusionagi.memory.sharding import shard_context
|
||||
from fusionagi.memory.scratchpad import LatentScratchpad
|
||||
from fusionagi.memory.thought_versioning import ThoughtVersioning
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class SuperBigBrainConfig:
    """Tunable parameters for the Super Big Brain pipeline."""

    # Recursion limit passed to decompose_recursive as max_depth.
    max_decomposition_depth: int = 3
    # Thought depth required before a conclusion may be emitted.
    min_depth_before_conclusion: int = 1
    # How many leading units are taken as parallel hypotheses.
    parallel_hypotheses: int = 3
    # Threshold handed to prune_subtree for discarding weak branches.
    prune_threshold: float = 0.3
    # Character budget given to build_compact_prompt.
    max_context_chars: int = 4000
def run_super_big_brain(
    prompt: str,
    semantic_graph: SemanticGraphMemory,
    config: SuperBigBrainConfig | None = None,
    adapter: Any | None = None,
) -> RecomposedResponse:
    """
    End-to-end Super Big Brain pipeline:

    1. Decompose prompt -> atomic units
    2. Shard and load context
    3. Run hierarchical ToT with multi-path inference
    4. Recompose with traceability
    5. Persist units/relations to semantic graph

    Args:
        prompt: Raw input to reason over.
        semantic_graph: Graph memory that receives the decomposed units/relations.
        config: Optional pipeline tuning; defaults to SuperBigBrainConfig().
        adapter: Accepted for interface compatibility but not used in this
            function body.

    Returns:
        RecomposedResponse carrying summary, confidence, and reasoning metadata.
    """
    cfg = config or SuperBigBrainConfig()
    decomp = decompose_recursive(prompt, max_depth=cfg.max_decomposition_depth)
    if not decomp.units:
        # Nothing decomposable: short-circuit with a zero-confidence answer.
        return RecomposedResponse(summary="No content to reason over.", confidence=0.0)

    semantic_graph.ingest_decomposition(decomp.units, decomp.relations)
    # Fix: the return value was previously bound to an unused local (ctx);
    # the call is kept for any context-loading side effects.
    # NOTE(review): if load_context_for_reasoning proves side-effect free,
    # this call can be dropped entirely — confirm before removing.
    load_context_for_reasoning(decomp.units, semantic_graph=semantic_graph, sharder=shard_context)
    compact = build_compact_prompt(decomp.units, max_chars=cfg.max_context_chars)

    hypotheses = [u.content for u in decomp.units[:cfg.parallel_hypotheses] if u.content]
    if not hypotheses:
        # No unit carries content: fall back to a slice of the compact prompt.
        hypotheses = [compact[:500]]

    scored = generate_and_score_parallel(hypotheses, decomp.units)
    nodes = [n for n, _ in sorted(scored, key=lambda x: x[1], reverse=True)]
    best = nodes[0] if nodes else ThoughtNode(thought=compact[:300], unit_refs=[u.unit_id for u in decomp.units[:5]])

    # Enforce a minimum reasoning depth before concluding.
    if cfg.min_depth_before_conclusion > 0 and best.depth < cfg.min_depth_before_conclusion:
        child = expand_node(best, compact[:200], unit_refs=best.unit_refs)
        child.score = best.score
        best = child

    prune_subtree(best, cfg.prune_threshold)
    assumptions = challenge_assumptions(decomp.units, best.thought)
    contradictions = detect_contradictions(decomp.units)

    recomp = recompose([best], decomp.units)
    recomp.metadata["assumptions_flagged"] = len(assumptions)
    recomp.metadata["contradictions"] = len(contradictions)
    recomp.metadata["depth"] = best.depth

    logger.info(
        "Super Big Brain complete",
        extra={"units": len(decomp.units), "confidence": recomp.confidence},
    )
    return recomp
def _recomposed_to_head_output(
    recomp: RecomposedResponse,
    head_id: HeadId,
) -> HeadOutput:
    """Convert RecomposedResponse to HeadOutput for Dvādaśa integration.

    At most five key claims are carried over, each citing up to three of the
    response's unit refs; metadata counters become risk entries.
    """
    claims = []
    for text in recomp.key_claims[:5]:
        evidence = [
            Citation(source_id=uid, excerpt="", confidence=recomp.confidence)
            for uid in recomp.unit_refs[:3]
        ]
        claims.append(
            HeadClaim(
                claim_text=text,
                confidence=recomp.confidence,
                evidence=evidence,
                assumptions=[],
            )
        )
    if not claims:
        # Guarantee at least one claim by falling back to the summary.
        claims = [
            HeadClaim(claim_text=recomp.summary, confidence=recomp.confidence, evidence=[], assumptions=[]),
        ]

    risks = []
    if recomp.metadata.get("assumptions_flagged", 0) > 0:
        risks.append(HeadRisk(description="Assumptions flagged; verify before acting", severity="medium"))
    if recomp.metadata.get("contradictions", 0) > 0:
        risks.append(HeadRisk(description="Contradictions detected in context", severity="high"))

    return HeadOutput(
        head_id=head_id,
        summary=recomp.summary,
        claims=claims,
        risks=risks,
        questions=[],
        recommended_actions=["Consider flagged assumptions", "Resolve contradictions if any"],
        tone_guidance="",
    )
class SuperBigBrainReasoningProvider:
    """ReasoningProvider for HeadAgent, backed by the Super Big Brain pipeline.

    Holds one semantic graph and config; every produce_head_output call runs
    the full pipeline against that shared graph.
    """

    def __init__(
        self,
        semantic_graph: SemanticGraphMemory | None = None,
        config: SuperBigBrainConfig | None = None,
    ) -> None:
        # Fresh graph/config are created when the caller supplies none.
        self._graph = semantic_graph or SemanticGraphMemory()
        self._config = config or SuperBigBrainConfig()

    def produce_head_output(self, head_id: HeadId, prompt: str) -> HeadOutput:
        """Run the pipeline over the prompt and adapt the result for head_id."""
        result = run_super_big_brain(prompt, self._graph, self._config)
        return _recomposed_to_head_output(result, head_id)
Reference in New Issue
Block a user