Initial commit: add .gitignore and README
This commit is contained in:
144
fusionagi/multi_agent/parallel.py
Normal file
144
fusionagi/multi_agent/parallel.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Parallel step execution: run independent plan steps concurrently.
|
||||
|
||||
Multi-agent acceleration: steps with satisfied dependencies and no mutual
|
||||
dependencies are dispatched in parallel to maximize throughput.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
from fusionagi.schemas.plan import Plan
|
||||
from fusionagi.planning import ready_steps, get_step
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelStepResult:
|
||||
"""Result of a single step execution in parallel batch."""
|
||||
|
||||
step_id: str
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
envelope: Any = None # AgentMessageEnvelope from executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteStepsCallback(Protocol):
|
||||
"""Protocol for executing a single step (e.g. via orchestrator)."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
task_id: str,
|
||||
step_id: str,
|
||||
plan: Plan,
|
||||
sender: str = "supervisor",
|
||||
) -> Any:
|
||||
"""Execute one step; return response envelope or result."""
|
||||
...
|
||||
|
||||
|
||||
def execute_steps_parallel(
|
||||
execute_fn: Callable[[str, str, Plan, str], Any],
|
||||
task_id: str,
|
||||
plan: Plan,
|
||||
completed_step_ids: set[str],
|
||||
sender: str = "supervisor",
|
||||
max_workers: int | None = None,
|
||||
) -> list[ParallelStepResult]:
|
||||
"""
|
||||
Execute all ready steps in parallel.
|
||||
|
||||
Args:
|
||||
execute_fn: Function (task_id, step_id, plan, sender) -> response.
|
||||
task_id: Task identifier.
|
||||
plan: The plan containing steps.
|
||||
completed_step_ids: Steps already completed.
|
||||
sender: Sender identity for execute messages.
|
||||
max_workers: Max parallel workers (default: unbounded for ready steps).
|
||||
|
||||
Returns:
|
||||
List of ParallelStepResult, one per step attempted.
|
||||
"""
|
||||
ready = ready_steps(plan, completed_step_ids)
|
||||
if not ready:
|
||||
return []
|
||||
|
||||
results: list[ParallelStepResult] = []
|
||||
workers = max_workers if max_workers and max_workers > 0 else len(ready)
|
||||
|
||||
def run_one(step_id: str) -> ParallelStepResult:
|
||||
try:
|
||||
response = execute_fn(task_id, step_id, plan, sender)
|
||||
if response is None:
|
||||
return ParallelStepResult(step_id=step_id, success=False, error="No response")
|
||||
# Response may be AgentMessageEnvelope with intent step_done/step_failed
|
||||
if hasattr(response, "message"):
|
||||
msg = response.message
|
||||
if msg.intent == "step_done":
|
||||
payload = msg.payload or {}
|
||||
return ParallelStepResult(
|
||||
step_id=step_id,
|
||||
success=True,
|
||||
result=payload.get("result"),
|
||||
envelope=response,
|
||||
)
|
||||
return ParallelStepResult(
|
||||
step_id=step_id,
|
||||
success=False,
|
||||
error=msg.payload.get("error", "Unknown failure") if msg.payload else "Unknown",
|
||||
envelope=response,
|
||||
)
|
||||
return ParallelStepResult(step_id=step_id, success=True, result=response)
|
||||
except Exception as e:
|
||||
logger.exception("Parallel step execution failed", extra={"step_id": step_id})
|
||||
return ParallelStepResult(step_id=step_id, success=False, error=str(e))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_step = {executor.submit(run_one, sid): sid for sid in ready}
|
||||
for future in as_completed(future_to_step):
|
||||
results.append(future.result())
|
||||
|
||||
logger.info(
|
||||
"Parallel step batch completed",
|
||||
extra={"task_id": task_id, "steps": ready, "results": len(results)},
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def execute_steps_parallel_wave(
|
||||
execute_fn: Callable[[str, str, Plan, str], Any],
|
||||
task_id: str,
|
||||
plan: Plan,
|
||||
sender: str = "supervisor",
|
||||
max_workers: int | None = None,
|
||||
) -> list[ParallelStepResult]:
|
||||
"""
|
||||
Execute plan in waves: each wave runs all ready steps in parallel,
|
||||
then advances to the next wave when deps are satisfied.
|
||||
|
||||
Returns combined results from all waves.
|
||||
"""
|
||||
completed: set[str] = set()
|
||||
all_results: list[ParallelStepResult] = []
|
||||
|
||||
while True:
|
||||
batch = execute_steps_parallel(
|
||||
execute_fn, task_id, plan, completed, sender, max_workers
|
||||
)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for r in batch:
|
||||
all_results.append(r)
|
||||
if r.success:
|
||||
completed.add(r.step_id)
|
||||
else:
|
||||
# On failure, stop the wave (caller can retry or handle)
|
||||
logger.warning("Step failed in wave, stopping", extra={"step_id": r.step_id})
|
||||
return all_results
|
||||
|
||||
return all_results
|
||||
Reference in New Issue
Block a user