Initial commit: add .gitignore and README
This commit is contained in:
41
fusionagi/multi_agent/__init__.py
Normal file
41
fusionagi/multi_agent/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Multi-agent: parallel, delegation, pooling, coordinator, adversarial reviewer, consensus."""
|
||||
|
||||
from fusionagi.multi_agent.parallel import (
|
||||
execute_steps_parallel,
|
||||
execute_steps_parallel_wave,
|
||||
ParallelStepResult,
|
||||
)
|
||||
from fusionagi.multi_agent.pool import AgentPool, PooledExecutorRouter
|
||||
from fusionagi.multi_agent.supervisor import SupervisorAgent
|
||||
from fusionagi.multi_agent.delegation import (
|
||||
delegate_sub_tasks,
|
||||
DelegationConfig,
|
||||
SubTask,
|
||||
SubTaskResult,
|
||||
)
|
||||
from fusionagi.multi_agent.coordinator import CoordinatorAgent
|
||||
from fusionagi.multi_agent.consensus import consensus_vote, arbitrate
|
||||
from fusionagi.multi_agent.consensus_engine import (
|
||||
run_consensus,
|
||||
collect_claims,
|
||||
CollectedClaim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"execute_steps_parallel",
|
||||
"execute_steps_parallel_wave",
|
||||
"ParallelStepResult",
|
||||
"AgentPool",
|
||||
"PooledExecutorRouter",
|
||||
"SupervisorAgent",
|
||||
"delegate_sub_tasks",
|
||||
"DelegationConfig",
|
||||
"SubTask",
|
||||
"SubTaskResult",
|
||||
"CoordinatorAgent",
|
||||
"consensus_vote",
|
||||
"arbitrate",
|
||||
"run_consensus",
|
||||
"collect_claims",
|
||||
"CollectedClaim",
|
||||
]
|
||||
15
fusionagi/multi_agent/consensus.py
Normal file
15
fusionagi/multi_agent/consensus.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Any
|
||||
from collections import Counter
|
||||
from fusionagi._logger import logger
|
||||
|
||||
def consensus_vote(answers: list, key=None):
|
||||
if not answers:
|
||||
return None
|
||||
values = [a.get(key, a) if isinstance(a, dict) else a for a in answers] if key else list(answers)
|
||||
return Counter(values).most_common(1)[0][0]
|
||||
|
||||
def arbitrate(proposals: list, arbitrator="coordinator"):
|
||||
if not proposals:
|
||||
return {}
|
||||
logger.info("Arbitrate", extra={"arbitrator": arbitrator, "count": len(proposals)})
|
||||
return proposals[0]
|
||||
194
fusionagi/multi_agent/consensus_engine.py
Normal file
194
fusionagi/multi_agent/consensus_engine.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Consensus engine: claim collection, deduplication, conflict detection, scoring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim
|
||||
from fusionagi.schemas.witness import AgreementMap
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class CollectedClaim:
|
||||
"""Claim with source head and metadata for consensus."""
|
||||
|
||||
claim_text: str
|
||||
confidence: float
|
||||
head_id: HeadId
|
||||
evidence_count: int
|
||||
raw: dict[str, Any]
|
||||
|
||||
|
||||
def _normalize_text(s: str) -> str:
|
||||
"""Normalize for duplicate detection."""
|
||||
return " ".join(s.lower().split())
|
||||
|
||||
|
||||
def _are_similar(a: str, b: str, threshold: float = 0.9) -> bool:
|
||||
"""Simple similarity: exact match or one contains the other (normalized)."""
|
||||
na, nb = _normalize_text(a), _normalize_text(b)
|
||||
if na == nb:
|
||||
return True
|
||||
if len(na) < 10 or len(nb) < 10:
|
||||
return na == nb
|
||||
# Jaccard-like: word overlap
|
||||
wa, wb = set(na.split()), set(nb.split())
|
||||
inter = len(wa & wb)
|
||||
union = len(wa | wb)
|
||||
if union == 0:
|
||||
return False
|
||||
return (inter / union) >= threshold
|
||||
|
||||
|
||||
def _looks_contradictory(a: str, b: str) -> bool:
|
||||
"""Heuristic: same subject with opposite polarity indicators."""
|
||||
neg_words = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}
|
||||
na, nb = _normalize_text(a), _normalize_text(b)
|
||||
wa, wb = set(na.split()), set(nb.split())
|
||||
# If one has neg and the other doesn't, and they share significant overlap
|
||||
a_neg = bool(wa & neg_words)
|
||||
b_neg = bool(wb & neg_words)
|
||||
if a_neg != b_neg:
|
||||
overlap = len(wa & wb) / max(len(wa), 1)
|
||||
if overlap > 0.3:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def collect_claims(outputs: list[HeadOutput]) -> list[CollectedClaim]:
|
||||
"""Flatten all head claims with source metadata."""
|
||||
collected: list[CollectedClaim] = []
|
||||
for out in outputs:
|
||||
for c in out.claims:
|
||||
collected.append(
|
||||
CollectedClaim(
|
||||
claim_text=c.claim_text,
|
||||
confidence=c.confidence,
|
||||
head_id=out.head_id,
|
||||
evidence_count=len(c.evidence),
|
||||
raw={
|
||||
"claim_text": c.claim_text,
|
||||
"confidence": c.confidence,
|
||||
"head_id": out.head_id.value,
|
||||
"evidence_count": len(c.evidence),
|
||||
"assumptions": c.assumptions,
|
||||
},
|
||||
)
|
||||
)
|
||||
return collected
|
||||
|
||||
|
||||
def run_consensus(
|
||||
outputs: list[HeadOutput],
|
||||
head_weights: dict[HeadId, float] | None = None,
|
||||
confidence_threshold: float = 0.5,
|
||||
) -> AgreementMap:
|
||||
"""
|
||||
Run consensus: deduplicate, detect conflicts, score, produce AgreementMap.
|
||||
|
||||
Args:
|
||||
outputs: HeadOutput from all content heads.
|
||||
head_weights: Optional per-head reliability weights (default 1.0).
|
||||
confidence_threshold: Minimum confidence for agreed claim.
|
||||
|
||||
Returns:
|
||||
AgreementMap with agreed_claims, disputed_claims, confidence_score.
|
||||
"""
|
||||
if not outputs:
|
||||
return AgreementMap(
|
||||
agreed_claims=[],
|
||||
disputed_claims=[],
|
||||
confidence_score=0.0,
|
||||
)
|
||||
|
||||
weights = head_weights or {h: 1.0 for h in HeadId}
|
||||
collected = collect_claims(outputs)
|
||||
|
||||
# Group by similarity (merge near-duplicates)
|
||||
merged: list[CollectedClaim] = []
|
||||
used: set[int] = set()
|
||||
for i, ca in enumerate(collected):
|
||||
if i in used:
|
||||
continue
|
||||
group = [ca]
|
||||
used.add(i)
|
||||
for j, cb in enumerate(collected):
|
||||
if j in used:
|
||||
continue
|
||||
if _are_similar(ca.claim_text, cb.claim_text) and not _looks_contradictory(ca.claim_text, cb.claim_text):
|
||||
group.append(cb)
|
||||
used.add(j)
|
||||
# Aggregate: weighted avg confidence, combine heads
|
||||
if len(group) == 1:
|
||||
c = group[0]
|
||||
score = c.confidence * weights.get(c.head_id, 1.0)
|
||||
if c.evidence_count > 0:
|
||||
score *= 1.1 # boost for citations
|
||||
merged.append(
|
||||
CollectedClaim(
|
||||
claim_text=c.claim_text,
|
||||
confidence=score,
|
||||
head_id=c.head_id,
|
||||
evidence_count=c.evidence_count,
|
||||
raw={**c.raw, "aggregated_confidence": score, "supporting_heads": [c.head_id.value]},
|
||||
)
|
||||
)
|
||||
else:
|
||||
total_conf = sum(g.confidence * weights.get(g.head_id, 1.0) for g in group)
|
||||
avg_conf = total_conf / len(group)
|
||||
evidence_boost = 1.1 if any(g.evidence_count > 0 for g in group) else 1.0
|
||||
score = min(1.0, avg_conf * evidence_boost)
|
||||
merged.append(
|
||||
CollectedClaim(
|
||||
claim_text=group[0].claim_text,
|
||||
confidence=score,
|
||||
head_id=group[0].head_id,
|
||||
evidence_count=sum(g.evidence_count for g in group),
|
||||
raw={
|
||||
"claim_text": group[0].claim_text,
|
||||
"aggregated_confidence": score,
|
||||
"supporting_heads": [g.head_id.value for g in group],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Conflict detection
|
||||
agreed: list[dict[str, Any]] = []
|
||||
disputed: list[dict[str, Any]] = []
|
||||
|
||||
for c in merged:
|
||||
in_conflict = False
|
||||
for d in merged:
|
||||
if c is d:
|
||||
continue
|
||||
if _looks_contradictory(c.claim_text, d.claim_text):
|
||||
in_conflict = True
|
||||
break
|
||||
rec = {
|
||||
"claim_text": c.claim_text,
|
||||
"confidence": c.confidence,
|
||||
"supporting_heads": c.raw.get("supporting_heads", [c.head_id.value]),
|
||||
}
|
||||
if in_conflict or c.confidence < confidence_threshold:
|
||||
disputed.append(rec)
|
||||
else:
|
||||
agreed.append(rec)
|
||||
|
||||
overall_conf = (
|
||||
sum(a["confidence"] for a in agreed) / len(agreed)
|
||||
if agreed
|
||||
else 0.0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Consensus complete",
|
||||
extra={"agreed": len(agreed), "disputed": len(disputed), "confidence": overall_conf},
|
||||
)
|
||||
|
||||
return AgreementMap(
|
||||
agreed_claims=agreed,
|
||||
disputed_claims=disputed,
|
||||
confidence_score=min(1.0, overall_conf),
|
||||
)
|
||||
18
fusionagi/multi_agent/coordinator.py
Normal file
18
fusionagi/multi_agent/coordinator.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.schemas.messages import AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.orchestrator import Orchestrator
|
||||
from fusionagi.core.goal_manager import GoalManager
|
||||
|
||||
class CoordinatorAgent(BaseAgent):
|
||||
def __init__(self, identity="coordinator", orchestrator=None, goal_manager=None, planner_id="planner"):
|
||||
super().__init__(identity=identity, role="Coordinator", objective="Own goals and assign tasks", memory_access=True, tool_permissions=[])
|
||||
self._orch = orchestrator
|
||||
self._goal_manager = goal_manager
|
||||
self._planner_id = planner_id
|
||||
def handle_message(self, envelope):
|
||||
if envelope.message.intent == "goal_created" and self._orch and self._planner_id:
|
||||
self._orch.route_message(envelope)
|
||||
return None
|
||||
97
fusionagi/multi_agent/delegation.py
Normal file
97
fusionagi/multi_agent/delegation.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Sub-task delegation: fan-out to sub-agents, fan-in of results.
|
||||
|
||||
Enables hierarchical multi-agent: a supervisor decomposes a task into
|
||||
sub-tasks, delegates to specialized sub-agents, and aggregates results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubTask:
|
||||
"""A sub-task to delegate."""
|
||||
|
||||
sub_task_id: str
|
||||
goal: str
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubTaskResult:
|
||||
"""Result from a delegated sub-task."""
|
||||
|
||||
sub_task_id: str
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DelegationConfig:
|
||||
"""Configuration for delegation behavior."""
|
||||
|
||||
max_parallel: int = 4
|
||||
timeout_seconds: float | None = None
|
||||
fail_fast: bool = False # Stop on first failure
|
||||
|
||||
|
||||
def delegate_sub_tasks(
|
||||
sub_tasks: list[SubTask],
|
||||
delegate_fn: Callable[[SubTask], SubTaskResult],
|
||||
config: DelegationConfig | None = None,
|
||||
) -> list[SubTaskResult]:
|
||||
"""
|
||||
Fan-out: delegate sub-tasks to sub-agents in parallel.
|
||||
|
||||
Args:
|
||||
sub_tasks: List of sub-tasks to delegate.
|
||||
delegate_fn: (SubTask) -> SubTaskResult. Typically calls orchestrator
|
||||
to submit task and route to sub-agent, then wait for completion.
|
||||
config: Delegation behavior.
|
||||
|
||||
Returns:
|
||||
List of SubTaskResult in same order as sub_tasks.
|
||||
"""
|
||||
cfg = config or DelegationConfig()
|
||||
results: list[SubTaskResult] = [None] * len(sub_tasks) # type: ignore
|
||||
index_map = {st.sub_task_id: i for i, st in enumerate(sub_tasks)}
|
||||
|
||||
def run_one(st: SubTask) -> tuple[int, SubTaskResult]:
|
||||
r = delegate_fn(st)
|
||||
return index_map[st.sub_task_id], r
|
||||
|
||||
with ThreadPoolExecutor(max_workers=cfg.max_parallel) as executor:
|
||||
future_to_task = {executor.submit(run_one, st): st for st in sub_tasks}
|
||||
for future in as_completed(future_to_task):
|
||||
idx, result = future.result()
|
||||
results[idx] = result
|
||||
if cfg.fail_fast and not result.success:
|
||||
logger.warning("Delegation fail_fast on failure", extra={"sub_task_id": result.sub_task_id})
|
||||
break
|
||||
|
||||
return [r for r in results if r is not None]
|
||||
|
||||
|
||||
def aggregate_sub_task_results(
|
||||
results: list[SubTaskResult],
|
||||
aggregator: Callable[[list[SubTaskResult]], Any],
|
||||
) -> Any:
|
||||
"""
|
||||
Fan-in: aggregate sub-task results into a single outcome.
|
||||
|
||||
Args:
|
||||
results: Results from delegate_sub_tasks.
|
||||
aggregator: (results) -> aggregated value.
|
||||
|
||||
Returns:
|
||||
Aggregated result.
|
||||
"""
|
||||
return aggregator(results)
|
||||
144
fusionagi/multi_agent/parallel.py
Normal file
144
fusionagi/multi_agent/parallel.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Parallel step execution: run independent plan steps concurrently.
|
||||
|
||||
Multi-agent acceleration: steps with satisfied dependencies and no mutual
|
||||
dependencies are dispatched in parallel to maximize throughput.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from fusionagi.schemas.plan import Plan
|
||||
from fusionagi.planning import ready_steps, get_step
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelStepResult:
|
||||
"""Result of a single step execution in parallel batch."""
|
||||
|
||||
step_id: str
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
envelope: Any = None # AgentMessageEnvelope from executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteStepsCallback(Protocol):
|
||||
"""Protocol for executing a single step (e.g. via orchestrator)."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
plan: Plan,
|
||||
sender: str = "supervisor",
|
||||
) -> Any:
|
||||
"""Execute one step; return response envelope or result."""
|
||||
...
|
||||
|
||||
|
||||
def execute_steps_parallel(
|
||||
execute_fn: Callable[[str, str, Plan, str], Any],
|
||||
task_id: str,
|
||||
plan: Plan,
|
||||
completed_step_ids: set[str],
|
||||
sender: str = "supervisor",
|
||||
max_workers: int | None = None,
|
||||
) -> list[ParallelStepResult]:
|
||||
"""
|
||||
Execute all ready steps in parallel.
|
||||
|
||||
Args:
|
||||
execute_fn: Function (task_id, step_id, plan, sender) -> response.
|
||||
task_id: Task identifier.
|
||||
plan: The plan containing steps.
|
||||
completed_step_ids: Steps already completed.
|
||||
sender: Sender identity for execute messages.
|
||||
max_workers: Max parallel workers (default: unbounded for ready steps).
|
||||
|
||||
Returns:
|
||||
List of ParallelStepResult, one per step attempted.
|
||||
"""
|
||||
ready = ready_steps(plan, completed_step_ids)
|
||||
if not ready:
|
||||
return []
|
||||
|
||||
results: list[ParallelStepResult] = []
|
||||
workers = max_workers if max_workers and max_workers > 0 else len(ready)
|
||||
|
||||
def run_one(step_id: str) -> ParallelStepResult:
|
||||
try:
|
||||
response = execute_fn(task_id, step_id, plan, sender)
|
||||
if response is None:
|
||||
return ParallelStepResult(step_id=step_id, success=False, error="No response")
|
||||
# Response may be AgentMessageEnvelope with intent step_done/step_failed
|
||||
if hasattr(response, "message"):
|
||||
msg = response.message
|
||||
if msg.intent == "step_done":
|
||||
payload = msg.payload or {}
|
||||
return ParallelStepResult(
|
||||
step_id=step_id,
|
||||
success=True,
|
||||
result=payload.get("result"),
|
||||
envelope=response,
|
||||
)
|
||||
return ParallelStepResult(
|
||||
step_id=step_id,
|
||||
success=False,
|
||||
error=msg.payload.get("error", "Unknown failure") if msg.payload else "Unknown",
|
||||
envelope=response,
|
||||
)
|
||||
return ParallelStepResult(step_id=step_id, success=True, result=response)
|
||||
except Exception as e:
|
||||
logger.exception("Parallel step execution failed", extra={"step_id": step_id})
|
||||
return ParallelStepResult(step_id=step_id, success=False, error=str(e))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_step = {executor.submit(run_one, sid): sid for sid in ready}
|
||||
for future in as_completed(future_to_step):
|
||||
results.append(future.result())
|
||||
|
||||
logger.info(
|
||||
"Parallel step batch completed",
|
||||
extra={"task_id": task_id, "steps": ready, "results": len(results)},
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def execute_steps_parallel_wave(
|
||||
execute_fn: Callable[[str, str, Plan, str], Any],
|
||||
task_id: str,
|
||||
plan: Plan,
|
||||
sender: str = "supervisor",
|
||||
max_workers: int | None = None,
|
||||
) -> list[ParallelStepResult]:
|
||||
"""
|
||||
Execute plan in waves: each wave runs all ready steps in parallel,
|
||||
then advances to the next wave when deps are satisfied.
|
||||
|
||||
Returns combined results from all waves.
|
||||
"""
|
||||
completed: set[str] = set()
|
||||
all_results: list[ParallelStepResult] = []
|
||||
|
||||
while True:
|
||||
batch = execute_steps_parallel(
|
||||
execute_fn, task_id, plan, completed, sender, max_workers
|
||||
)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for r in batch:
|
||||
all_results.append(r)
|
||||
if r.success:
|
||||
completed.add(r.step_id)
|
||||
else:
|
||||
# On failure, stop the wave (caller can retry or handle)
|
||||
logger.warning("Step failed in wave, stopping", extra={"step_id": r.step_id})
|
||||
return all_results
|
||||
|
||||
return all_results
|
||||
190
fusionagi/multi_agent/pool.py
Normal file
190
fusionagi/multi_agent/pool.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Agent pool: load-balanced routing for horizontal scaling.
|
||||
|
||||
Multiple executor (or other) agents behind a single logical endpoint.
|
||||
Supports round-robin, least-busy, and random selection strategies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PooledAgent:
|
||||
"""An agent in the pool with load tracking."""
|
||||
|
||||
agent_id: str
|
||||
agent: Any # AgentProtocol
|
||||
in_flight: int = 0
|
||||
total_dispatched: int = 0
|
||||
last_used: float = field(default_factory=time.monotonic)
|
||||
|
||||
|
||||
class AgentPool:
|
||||
"""
|
||||
Pool of agents for load-balanced dispatch.
|
||||
|
||||
Strategies:
|
||||
- round_robin: Rotate through agents in order.
|
||||
- least_busy: Prefer agent with lowest in_flight count.
|
||||
- random: Random selection (useful for load spreading).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: str = "least_busy",
|
||||
) -> None:
|
||||
self._strategy = strategy
|
||||
self._agents: list[PooledAgent] = []
|
||||
self._round_robin_index = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def add(self, agent_id: str, agent: Any) -> None:
|
||||
"""Add an agent to the pool."""
|
||||
with self._lock:
|
||||
if any(p.agent_id == agent_id for p in self._agents):
|
||||
return
|
||||
self._agents.append(PooledAgent(agent_id=agent_id, agent=agent))
|
||||
logger.info("Agent added to pool", extra={"agent_id": agent_id, "pool_size": len(self._agents)})
|
||||
|
||||
def remove(self, agent_id: str) -> bool:
|
||||
"""Remove an agent from the pool."""
|
||||
with self._lock:
|
||||
for i, p in enumerate(self._agents):
|
||||
if p.agent_id == agent_id:
|
||||
self._agents.pop(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _select(self) -> PooledAgent | None:
|
||||
"""Select an agent based on strategy."""
|
||||
with self._lock:
|
||||
if not self._agents:
|
||||
return None
|
||||
|
||||
if self._strategy == "round_robin":
|
||||
idx = self._round_robin_index % len(self._agents)
|
||||
self._round_robin_index += 1
|
||||
return self._agents[idx]
|
||||
|
||||
if self._strategy == "random":
|
||||
return random.choice(self._agents)
|
||||
|
||||
# least_busy (default)
|
||||
return min(self._agents, key=lambda p: (p.in_flight, p.last_used))
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
envelope: AgentMessageEnvelope,
|
||||
on_complete: Callable[[str], None] | None = None,
|
||||
rewrite_recipient: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Dispatch to a pooled agent and return response.
|
||||
|
||||
Tracks in-flight for least_busy; calls on_complete(agent_id) when done
|
||||
if provided (for async cleanup).
|
||||
|
||||
If rewrite_recipient, the envelope's recipient is set to the selected
|
||||
agent's id so the agent receives it correctly.
|
||||
"""
|
||||
pooled = self._select()
|
||||
if not pooled:
|
||||
logger.error("Agent pool empty, cannot dispatch")
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
pooled.in_flight += 1
|
||||
pooled.total_dispatched += 1
|
||||
pooled.last_used = time.monotonic()
|
||||
|
||||
# Rewrite recipient so pooled agent receives correctly
|
||||
if rewrite_recipient:
|
||||
msg = envelope.message
|
||||
envelope = AgentMessageEnvelope(
|
||||
message=AgentMessage(
|
||||
sender=msg.sender,
|
||||
recipient=pooled.agent_id,
|
||||
intent=msg.intent,
|
||||
payload=msg.payload,
|
||||
confidence=msg.confidence,
|
||||
uncertainty=msg.uncertainty,
|
||||
timestamp=msg.timestamp,
|
||||
),
|
||||
task_id=envelope.task_id,
|
||||
correlation_id=envelope.correlation_id,
|
||||
)
|
||||
|
||||
try:
|
||||
agent = pooled.agent
|
||||
if hasattr(agent, "handle_message"):
|
||||
response = agent.handle_message(envelope)
|
||||
# Ensure response recipient points back to original sender
|
||||
return response
|
||||
return None
|
||||
finally:
|
||||
with self._lock:
|
||||
pooled.in_flight = max(0, pooled.in_flight - 1)
|
||||
if on_complete:
|
||||
on_complete(pooled.agent_id)
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return pool size."""
|
||||
return len(self._agents)
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""Return pool statistics for monitoring."""
|
||||
with self._lock:
|
||||
return {
|
||||
"strategy": self._strategy,
|
||||
"size": len(self._agents),
|
||||
"agents": [
|
||||
{
|
||||
"id": p.agent_id,
|
||||
"in_flight": p.in_flight,
|
||||
"total_dispatched": p.total_dispatched,
|
||||
}
|
||||
for p in self._agents
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class PooledExecutorRouter:
|
||||
"""
|
||||
Routes execute_step messages to a pool of executors.
|
||||
|
||||
Wraps multiple ExecutorAgent instances; orchestrator or supervisor
|
||||
sends to this router's identity, and it load-balances to the pool.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
identity: str = "executor_pool",
|
||||
pool: AgentPool | None = None,
|
||||
) -> None:
|
||||
self.identity = identity
|
||||
self._pool = pool or AgentPool(strategy="least_busy")
|
||||
|
||||
def add_executor(self, executor_id: str, executor: Any) -> None:
|
||||
"""Add an executor to the pool."""
|
||||
self._pool.add(executor_id, executor)
|
||||
|
||||
def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
|
||||
"""Route execute_step to pooled executor; return response."""
|
||||
if envelope.message.intent != "execute_step":
|
||||
return None
|
||||
|
||||
# Rewrite recipient so response comes back to original sender
|
||||
response = self._pool.dispatch(envelope)
|
||||
return response
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""Pool statistics."""
|
||||
return self._pool.stats()
|
||||
191
fusionagi/multi_agent/supervisor.py
Normal file
191
fusionagi/multi_agent/supervisor.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Supervisor agent: drives the orchestration loop with parallel dispatch.
|
||||
|
||||
Coordinates Planner -> Reasoner -> Executor flow. Supports:
|
||||
- Parallel step execution (independent steps run concurrently)
|
||||
- Pooled executor routing (load-balanced across N executors)
|
||||
- Batch task processing (multiple tasks in flight)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
from fusionagi.agents.base_agent import BaseAgent
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
from fusionagi.schemas.plan import Plan
|
||||
from fusionagi.planning import ready_steps, get_step
|
||||
from fusionagi.multi_agent.parallel import execute_steps_parallel, execute_steps_parallel_wave
|
||||
from fusionagi._logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.core.orchestrator import Orchestrator
|
||||
|
||||
|
||||
class SupervisorAgent(BaseAgent):
|
||||
"""
|
||||
Supervisor: drives plan-execute loop with multi-agent accelerations.
|
||||
|
||||
Features:
|
||||
- Parallel step execution (ready_steps dispatched concurrently)
|
||||
- Configurable execution mode: sequential, parallel, or wave
|
||||
- Integration with Orchestrator for message routing
|
||||
- Optional pooled executor for horizontal scaling
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
identity: str = "supervisor",
|
||||
orchestrator: Orchestrator | None = None,
|
||||
planner_id: str = "planner",
|
||||
reasoner_id: str = "reasoner",
|
||||
executor_id: str = "executor",
|
||||
parallel_mode: bool = True,
|
||||
max_parallel_workers: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
identity: Supervisor agent id.
|
||||
orchestrator: Orchestrator for routing (required for full loop).
|
||||
planner_id: Registered planner agent id.
|
||||
reasoner_id: Registered reasoner agent id.
|
||||
executor_id: Registered executor (or pool) agent id.
|
||||
parallel_mode: If True, use parallel step execution.
|
||||
max_parallel_workers: Cap on concurrent step executions.
|
||||
"""
|
||||
super().__init__(
|
||||
identity=identity,
|
||||
role="Supervisor",
|
||||
objective="Coordinate plan-execute loop with parallel dispatch",
|
||||
memory_access=True,
|
||||
tool_permissions=[],
|
||||
)
|
||||
self._orch = orchestrator
|
||||
self._planner_id = planner_id
|
||||
self._reasoner_id = reasoner_id
|
||||
self._executor_id = executor_id
|
||||
self._parallel_mode = parallel_mode
|
||||
self._max_workers = max_parallel_workers
|
||||
|
||||
def _route(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
|
||||
"""Route message via orchestrator and return response."""
|
||||
if not self._orch:
|
||||
return None
|
||||
return self._orch.route_message_return(envelope)
|
||||
|
||||
def _execute_step(self, task_id: str, step_id: str, plan: Plan, sender: str) -> Any:
|
||||
"""Execute a single step by routing to executor."""
|
||||
envelope = AgentMessageEnvelope(
|
||||
message=AgentMessage(
|
||||
sender=sender,
|
||||
recipient=self._executor_id,
|
||||
intent="execute_step",
|
||||
payload={"step_id": step_id, "plan": plan.to_dict()},
|
||||
),
|
||||
task_id=task_id,
|
||||
)
|
||||
return self._route(envelope)
|
||||
|
||||
def _request_plan(self, task_id: str, goal: str, constraints: list[str]) -> Plan | None:
|
||||
"""Request plan from planner."""
|
||||
envelope = AgentMessageEnvelope(
|
||||
message=AgentMessage(
|
||||
sender=self.identity,
|
||||
recipient=self._planner_id,
|
||||
intent="plan_request",
|
||||
payload={"goal": goal, "constraints": constraints},
|
||||
),
|
||||
task_id=task_id,
|
||||
)
|
||||
resp = self._route(envelope)
|
||||
if not resp or not resp.message.payload:
|
||||
return None
|
||||
plan_dict = resp.message.payload.get("plan")
|
||||
if not plan_dict:
|
||||
return None
|
||||
return Plan.from_dict(plan_dict)
|
||||
|
||||
def handle_message(self, envelope: AgentMessageEnvelope) -> AgentMessageEnvelope | None:
|
||||
"""
|
||||
On run_task or similar: get plan, execute steps (parallel or sequential).
|
||||
"""
|
||||
if envelope.message.intent not in ("run_task", "execute_plan"):
|
||||
return None
|
||||
|
||||
payload = envelope.message.payload or {}
|
||||
task_id = envelope.task_id or ""
|
||||
goal = payload.get("goal", "")
|
||||
constraints = payload.get("constraints", [])
|
||||
plan_dict = payload.get("plan")
|
||||
|
||||
logger.info(
|
||||
"Supervisor handling run_task",
|
||||
extra={"task_id": task_id, "parallel": self._parallel_mode},
|
||||
)
|
||||
|
||||
if not self._orch:
|
||||
return envelope.create_response(
|
||||
"run_failed",
|
||||
payload={"error": "No orchestrator configured"},
|
||||
)
|
||||
|
||||
# Get plan
|
||||
if plan_dict:
|
||||
plan = Plan.from_dict(plan_dict)
|
||||
else:
|
||||
plan = self._request_plan(task_id, goal, constraints)
|
||||
if not plan:
|
||||
return envelope.create_response(
|
||||
"run_failed",
|
||||
payload={"error": "Failed to get plan"},
|
||||
)
|
||||
|
||||
# Execute steps
|
||||
if self._parallel_mode:
|
||||
results = execute_steps_parallel_wave(
|
||||
self._execute_step,
|
||||
task_id,
|
||||
plan,
|
||||
sender=self.identity,
|
||||
max_workers=self._max_workers,
|
||||
)
|
||||
successes = sum(1 for r in results if r.success)
|
||||
failures = [r for r in results if not r.success]
|
||||
if failures:
|
||||
return envelope.create_response(
|
||||
"run_failed",
|
||||
payload={
|
||||
"error": f"Step(s) failed: {[f.step_id for f in failures]}",
|
||||
"results": [
|
||||
{"step_id": r.step_id, "success": r.success, "error": r.error}
|
||||
for r in results
|
||||
],
|
||||
},
|
||||
)
|
||||
return envelope.create_response(
|
||||
"run_completed",
|
||||
payload={
|
||||
"steps_completed": successes,
|
||||
"results": [{"step_id": r.step_id, "result": r.result} for r in results],
|
||||
},
|
||||
)
|
||||
|
||||
# Sequential fallback
|
||||
completed: set[str] = set()
|
||||
while True:
|
||||
ready = ready_steps(plan, completed)
|
||||
if not ready:
|
||||
break
|
||||
step_id = ready[0]
|
||||
resp = self._execute_step(task_id, step_id, plan, self.identity)
|
||||
if resp and hasattr(resp, "message") and resp.message.intent == "step_done":
|
||||
completed.add(step_id)
|
||||
else:
|
||||
return envelope.create_response(
|
||||
"run_failed",
|
||||
payload={"error": f"Step {step_id} failed", "step_id": step_id},
|
||||
)
|
||||
|
||||
return envelope.create_response(
|
||||
"run_completed",
|
||||
payload={"steps_completed": len(completed)},
|
||||
)
|
||||
Reference in New Issue
Block a user