Initial commit: add .gitignore and README
This commit is contained in:
227
tests/test_multi_agent.py
Normal file
227
tests/test_multi_agent.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Tests for multi-agent accelerations: parallel execution, pool, delegation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from fusionagi.planning import ready_steps
|
||||
from fusionagi.schemas.plan import Plan, PlanStep
|
||||
from fusionagi.multi_agent import (
|
||||
execute_steps_parallel,
|
||||
ParallelStepResult,
|
||||
AgentPool,
|
||||
PooledExecutorRouter,
|
||||
delegate_sub_tasks,
|
||||
DelegationConfig,
|
||||
SubTask,
|
||||
)
|
||||
from fusionagi.core import EventBus, StateManager, Orchestrator
|
||||
from fusionagi.agents import ExecutorAgent, PlannerAgent
|
||||
from fusionagi.tools import ToolRegistry
|
||||
from fusionagi.adapters import StubAdapter
|
||||
|
||||
|
||||
class TestReadySteps:
|
||||
"""Test ready_steps for parallel dispatch."""
|
||||
|
||||
def test_parallel_ready_steps(self):
|
||||
"""Steps with same deps are ready together."""
|
||||
plan = Plan(
|
||||
steps=[
|
||||
PlanStep(id="s1", description="First"),
|
||||
PlanStep(id="s2", description="Second", dependencies=["s1"]),
|
||||
PlanStep(id="s3", description="Third", dependencies=["s1"]),
|
||||
PlanStep(id="s4", description="Fourth", dependencies=["s2", "s3"]),
|
||||
]
|
||||
)
|
||||
assert ready_steps(plan, set()) == ["s1"]
|
||||
assert set(ready_steps(plan, {"s1"})) == {"s2", "s3"}
|
||||
assert ready_steps(plan, {"s1", "s2", "s3"}) == ["s4"]
|
||||
assert ready_steps(plan, {"s1", "s2", "s3", "s4"}) == []
|
||||
|
||||
|
||||
class TestAgentPool:
|
||||
"""Test AgentPool and PooledExecutorRouter."""
|
||||
|
||||
def test_pool_round_robin(self):
|
||||
"""Round-robin selection rotates through agents."""
|
||||
pool = AgentPool(strategy="round_robin")
|
||||
calls = []
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, aid):
|
||||
self.identity = aid
|
||||
|
||||
def handle_message(self, env):
|
||||
calls.append(self.identity)
|
||||
return None
|
||||
|
||||
pool.add("a1", FakeAgent("a1"))
|
||||
pool.add("a2", FakeAgent("a2"))
|
||||
pool.add("a3", FakeAgent("a3"))
|
||||
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
|
||||
env = AgentMessageEnvelope(
|
||||
message=AgentMessage(sender="x", recipient="pool", intent="execute_step", payload={}),
|
||||
task_id="t1",
|
||||
)
|
||||
for _ in range(6):
|
||||
pool.dispatch(env)
|
||||
|
||||
assert calls == ["a1", "a2", "a3", "a1", "a2", "a3"]
|
||||
|
||||
def test_pool_least_busy(self):
|
||||
"""Least-busy prefers agent with fewest in-flight."""
|
||||
pool = AgentPool(strategy="least_busy")
|
||||
|
||||
class SlowAgent:
|
||||
def __init__(self, aid):
|
||||
self.identity = aid
|
||||
|
||||
def handle_message(self, env):
|
||||
import time
|
||||
time.sleep(0.05)
|
||||
return None
|
||||
|
||||
pool.add("slow1", SlowAgent("slow1"))
|
||||
pool.add("slow2", SlowAgent("slow2"))
|
||||
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
|
||||
env = AgentMessageEnvelope(
|
||||
message=AgentMessage(sender="x", recipient="pool", intent="test", payload={}),
|
||||
)
|
||||
# Sequential dispatch - both should get used
|
||||
pool.dispatch(env)
|
||||
pool.dispatch(env)
|
||||
stats = pool.stats()
|
||||
assert stats["size"] == 2
|
||||
assert sum(a["total_dispatched"] for a in stats["agents"]) == 2
|
||||
|
||||
def test_pooled_executor_router(self):
|
||||
"""PooledExecutorRouter routes to pool."""
|
||||
registry = ToolRegistry()
|
||||
state = StateManager()
|
||||
exec1 = ExecutorAgent(identity="exec1", registry=registry, state_manager=state)
|
||||
exec2 = ExecutorAgent(identity="exec2", registry=registry, state_manager=state)
|
||||
|
||||
router = PooledExecutorRouter(identity="executor_pool")
|
||||
router.add_executor("exec1", exec1)
|
||||
router.add_executor("exec2", exec2)
|
||||
|
||||
assert router.stats()["size"] == 2
|
||||
|
||||
|
||||
class TestDelegation:
|
||||
"""Test sub-task delegation."""
|
||||
|
||||
def test_delegate_sub_tasks_parallel(self):
|
||||
"""Delegation runs sub-tasks in parallel."""
|
||||
results_received = []
|
||||
|
||||
def delegate_fn(st: SubTask) -> dict:
|
||||
results_received.append(st.sub_task_id)
|
||||
return dict(
|
||||
sub_task_id=st.sub_task_id,
|
||||
success=True,
|
||||
result=f"done-{st.sub_task_id}",
|
||||
agent_id="agent1",
|
||||
)
|
||||
|
||||
# Wrap to return SubTaskResult
|
||||
def wrapped(st):
|
||||
r = delegate_fn(st)
|
||||
from fusionagi.multi_agent.delegation import SubTaskResult
|
||||
return SubTaskResult(
|
||||
sub_task_id=r["sub_task_id"],
|
||||
success=r["success"],
|
||||
result=r["result"],
|
||||
agent_id=r.get("agent_id"),
|
||||
)
|
||||
|
||||
tasks = [
|
||||
SubTask("t1", "Goal 1"),
|
||||
SubTask("t2", "Goal 2"),
|
||||
SubTask("t3", "Goal 3"),
|
||||
]
|
||||
config = DelegationConfig(max_parallel=3)
|
||||
results = delegate_sub_tasks(tasks, wrapped, config)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(r.success for r in results)
|
||||
assert set(r.sub_task_id for r in results) == {"t1", "t2", "t3"}
|
||||
|
||||
|
||||
class TestParallelExecution:
|
||||
"""Test parallel step execution."""
|
||||
|
||||
def test_execute_steps_parallel(self):
|
||||
"""Parallel execution runs ready steps concurrently."""
|
||||
completed = []
|
||||
|
||||
def execute_fn(task_id, step_id, plan, sender):
|
||||
completed.append(step_id)
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
return AgentMessageEnvelope(
|
||||
message=AgentMessage(
|
||||
sender="executor",
|
||||
recipient=sender,
|
||||
intent="step_done",
|
||||
payload={"step_id": step_id, "result": f"ok-{step_id}"},
|
||||
),
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
plan = Plan(
|
||||
steps=[
|
||||
PlanStep(id="s1", description="First"),
|
||||
PlanStep(id="s2", description="Second", dependencies=["s1"]),
|
||||
PlanStep(id="s3", description="Third", dependencies=["s1"]),
|
||||
]
|
||||
)
|
||||
# s2 and s3 are ready when s1 is done
|
||||
results = execute_steps_parallel(
|
||||
execute_fn, "task1", plan, completed_step_ids={"s1"}, max_workers=4
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert set(r.step_id for r in results) == {"s2", "s3"}
|
||||
assert all(r.success for r in results)
|
||||
|
||||
|
||||
class TestOrchestratorBatchRouting:
|
||||
"""Test orchestrator batch routing."""
|
||||
|
||||
def test_route_messages_batch(self):
|
||||
"""Batch routing returns responses in order."""
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
|
||||
class EchoAgent:
|
||||
identity = "echo"
|
||||
|
||||
def handle_message(self, env):
|
||||
return AgentMessageEnvelope(
|
||||
message=AgentMessage(
|
||||
sender="echo",
|
||||
recipient=env.message.sender,
|
||||
intent="echo_reply",
|
||||
payload={"orig": env.message.payload.get("n")},
|
||||
),
|
||||
)
|
||||
|
||||
bus = EventBus()
|
||||
state = StateManager()
|
||||
orch = Orchestrator(bus, state)
|
||||
orch.register_agent("echo", EchoAgent())
|
||||
|
||||
envelopes = [
|
||||
AgentMessageEnvelope(
|
||||
message=AgentMessage(sender="c", recipient="echo", intent="test", payload={"n": i}),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
responses = orch.route_messages_batch(envelopes)
|
||||
|
||||
assert len(responses) == 5
|
||||
for i, r in enumerate(responses):
|
||||
assert r is not None
|
||||
assert r.message.payload["orig"] == i
|
||||
Reference in New Issue
Block a user