"""Dvādaśa head orchestrator: parallel head dispatch, Witness coordination, second-pass."""

from __future__ import annotations

import math
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
from typing import TYPE_CHECKING, Any

from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

if TYPE_CHECKING:
    # Only needed for annotations; every other project import below is used at runtime.
    from fusionagi.core.orchestrator import Orchestrator

from fusionagi.schemas.head import HeadId, HeadOutput
from fusionagi.schemas.witness import FinalResponse
from fusionagi.schemas.commands import ParsedCommand, UserIntent
from fusionagi._logger import logger

# MVP: 5 heads. Full: 11.
MVP_HEADS: list[HeadId] = [
    HeadId.LOGIC,
    HeadId.RESEARCH,
    HeadId.STRATEGY,
    HeadId.SECURITY,
    HeadId.SAFETY,
]

# Every head except the Witness, which consumes head outputs instead of producing one.
ALL_CONTENT_HEADS: list[HeadId] = [h for h in HeadId if h != HeadId.WITNESS]

# Heads for second-pass when risk/conflict/security
SECOND_PASS_HEADS: list[HeadId] = [HeadId.SECURITY, HeadId.SAFETY, HeadId.LOGIC]

# Thresholds for automatic second-pass
SECOND_PASS_CONFIG: dict[str, Any] = {
    "min_confidence": 0.5,
    "max_disputed": 3,
    "security_keywords": ("security", "risk", "threat", "vulnerability"),
}


def run_heads_parallel(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    head_ids: list[HeadId] | None = None,
    sender: str = "head_orchestrator",
    timeout_per_head: float = 60.0,
    min_heads_ratio: float | None = 0.6,
) -> list[HeadOutput]:
    """
    Dispatch head_request to multiple heads in parallel; collect HeadOutput.

    Args:
        orchestrator: Orchestrator with registered head agents.
        task_id: Task identifier.
        user_prompt: User's prompt/question.
        head_ids: Heads to run (default: MVP_HEADS). The Witness is always filtered out.
        sender: Sender identity for messages.
        timeout_per_head: Per-head time budget in seconds. The overall wait for the
            whole batch is capped at ``timeout_per_head * len(heads) + 5`` seconds.
        min_heads_ratio: Return early once this fraction of heads respond (0.6 = 60%).
            None = wait for all heads. Reduces latency when some heads are slow.

    Returns:
        List of HeadOutput (may be partial on timeout/failure). Heads whose response
        is missing, has a different intent, or fails validation contribute nothing.
    """
    heads = [h for h in (head_ids or MVP_HEADS) if h != HeadId.WITNESS]
    if not heads:
        return []

    envelopes = [
        AgentMessageEnvelope(
            message=AgentMessage(
                sender=sender,
                recipient=hid.value,
                intent="head_request",
                payload={"prompt": user_prompt},
            ),
            task_id=task_id,
        )
        for hid in heads
    ]

    if min_heads_ratio is None:
        min_required = len(heads)
    else:
        # At least one head, never more than were actually dispatched.
        min_required = min(len(heads), max(1, math.ceil(len(heads) * min_heads_ratio)))

    def run_one(env: AgentMessageEnvelope) -> HeadOutput | None:
        """Route one envelope and parse the reply; None when there is no usable output."""
        resp = orchestrator.route_message_return(env)
        if resp is None or resp.message.intent != "head_output":
            return None
        payload = resp.message.payload or {}
        ho = payload.get("head_output")
        if not isinstance(ho, dict):
            return None
        try:
            return HeadOutput.model_validate(ho)
        except Exception as e:
            logger.warning("HeadOutput parse failed", extra={"error": str(e)})
            return None

    results: list[HeadOutput] = []
    future_to_env: dict[Any, AgentMessageEnvelope] = {}
    overall_timeout = timeout_per_head * len(heads) + 5

    # The executor is managed by hand instead of with a ``with`` block: the context
    # manager's exit calls shutdown(wait=True), which would block on the slowest head
    # and defeat both the early exit and the timeout.
    executor = ThreadPoolExecutor(max_workers=len(heads))
    try:
        future_to_env = {executor.submit(run_one, env): env for env in envelopes}
        for future in as_completed(future_to_env, timeout=overall_timeout):
            try:
                out = future.result()  # already finished; re-raises the head's exception
            except Exception as e:
                logger.exception("Head execution failed", extra={"error": str(e)})
                continue
            if out is None:
                continue
            results.append(out)
            if len(results) >= min_required:
                logger.info(
                    "Early exit: sufficient heads responded",
                    extra={"responded": len(results), "required": min_required},
                )
                break
    except FuturesTimeoutError:
        # Raised by the as_completed iterator when the overall deadline passes;
        # report the stragglers and fall through to return what was collected.
        for future, env in future_to_env.items():
            if not future.done():
                logger.warning("Head timeout", extra={"recipient": env.message.recipient})
    finally:
        # Do not wait for stragglers; their threads finish in the background.
        executor.shutdown(wait=False)

    return results


def run_witness(
    orchestrator: Orchestrator,
    task_id: str,
    head_outputs: list[HeadOutput],
    user_prompt: str,
    sender: str = "head_orchestrator",
) -> FinalResponse | None:
    """
    Route head outputs to Witness; return FinalResponse.

    Args:
        orchestrator: Orchestrator used to route the witness_request.
        task_id: Task identifier.
        head_outputs: Head outputs to send; each is serialized with ``model_dump()``.
        user_prompt: Prompt forwarded alongside the head outputs.
        sender: Sender identity for the message.

    Returns:
        The validated FinalResponse, or None when there is no response, the response
        intent is not "witness_output", the payload has no dict under
        "final_response", or validation fails.
    """
    envelope = AgentMessageEnvelope(
        message=AgentMessage(
            sender=sender,
            recipient=HeadId.WITNESS.value,
            intent="witness_request",
            payload={
                "head_outputs": [h.model_dump() for h in head_outputs],
                "prompt": user_prompt,
            },
        ),
        task_id=task_id,
    )
    resp = orchestrator.route_message_return(envelope)
    if resp is None or resp.message.intent != "witness_output":
        return None
    payload = resp.message.payload or {}
    fr = payload.get("final_response")
    if not isinstance(fr, dict):
        return None
    try:
        return FinalResponse.model_validate(fr)
    except Exception as e:
        logger.warning("FinalResponse parse failed", extra={"error": str(e)})
        return None


def run_second_pass(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    initial_outputs: list[HeadOutput],
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
) -> list[HeadOutput]:
    """
    Run second-pass heads (Security, Safety, Logic) and merge with initial outputs.

    Replaces outputs from second-pass heads with new ones. Heads that produce no
    second-pass output keep their initial output.

    Args:
        orchestrator: Orchestrator with registered head agents.
        task_id: Task identifier.
        user_prompt: Prompt to re-send to the second-pass heads.
        initial_outputs: Outputs from the first pass.
        head_ids: Heads to re-run (default: SECOND_PASS_HEADS). Witness is filtered out.
        timeout_per_head: Forwarded to run_heads_parallel.

    Returns:
        One HeadOutput per head id, with second-pass results taking precedence.
    """
    heads = [h for h in (head_ids or SECOND_PASS_HEADS) if h != HeadId.WITNESS]
    if not heads:
        return initial_outputs

    second_outputs = run_heads_parallel(
        orchestrator,
        task_id,
        user_prompt,
        head_ids=heads,
        timeout_per_head=timeout_per_head,
    )
    # Later assignments win, so second-pass outputs overwrite first-pass ones.
    by_head: dict[HeadId, HeadOutput] = {o.head_id: o for o in initial_outputs}
    for o in second_outputs:
        by_head[o.head_id] = o
    return list(by_head.values())


def _should_run_second_pass(
    final: FinalResponse,
    force: bool = False,
    second_pass_config: dict[str, Any] | None = None,
) -> bool:
    """
    Check if second-pass should run based on transparency report.

    Returns True when ``force`` is set, when the agreement map's confidence score is
    below "min_confidence", when it has more than "max_disputed" disputed claims, or
    when the lower-cased safety report contains any of "security_keywords".
    ``second_pass_config`` entries override SECOND_PASS_CONFIG key by key.
    """
    if force:
        return True

    cfg = {**SECOND_PASS_CONFIG, **(second_pass_config or {})}
    report = final.transparency_report
    am = report.agreement_map

    if am.confidence_score < cfg.get("min_confidence", 0.5):
        return True
    if len(am.disputed_claims) > cfg.get("max_disputed", 3):
        return True

    sr = (report.safety_report or "").lower()
    return any(kw in sr for kw in cfg.get("security_keywords", ()))


def run_dvadasa(
    orchestrator: Orchestrator,
    task_id: str,
    user_prompt: str,
    parsed: ParsedCommand | None = None,
    head_ids: list[HeadId] | None = None,
    timeout_per_head: float = 60.0,
    event_bus: Any | None = None,
    force_second_pass: bool = False,
    return_head_outputs: bool = False,
    second_pass_config: dict[str, Any] | None = None,
    min_heads_ratio: float | None = 0.6,
) -> FinalResponse | tuple[FinalResponse, list[HeadOutput]] | tuple[None, list[HeadOutput]] | None:
    """
    Full Dvādaśa flow: run heads in parallel, then Witness.

    Args:
        orchestrator: Orchestrator with heads and witness registered.
        task_id: Task identifier.
        user_prompt: User's prompt (or use parsed.cleaned_prompt when HEAD_STRATEGY).
        parsed: Optional ParsedCommand from parse_user_input. A HEAD_STRATEGY intent
            with a head_id restricts the run to that single head.
        head_ids: Override heads to run (e.g. single head for HEAD_STRATEGY). When
            neither this nor ``parsed`` selects heads, select_heads_for_complexity
            picks them from the prompt.
        timeout_per_head: Max seconds per head.
        event_bus: Optional EventBus to publish dvadasa_complete.
        force_second_pass: Run the second pass regardless of the transparency report.
        return_head_outputs: Return ``(final, head_outputs)`` instead of just ``final``.
        second_pass_config: Override SECOND_PASS_CONFIG (min_confidence, max_disputed, etc).
        min_heads_ratio: Early exit once this fraction of heads respond; None = wait all.

    Returns:
        FinalResponse or None on failure. With ``return_head_outputs`` set, a tuple of
        that value and the list of head outputs (``(None, [])`` when no head responded).
    """
    prompt = user_prompt
    heads = head_ids
    if parsed and parsed.intent == UserIntent.HEAD_STRATEGY and parsed.head_id:
        heads = [parsed.head_id]
        if parsed.cleaned_prompt:
            prompt = parsed.cleaned_prompt
    if heads is None:
        heads = select_heads_for_complexity(prompt)

    head_outputs = run_heads_parallel(
        orchestrator,
        task_id,
        prompt,
        head_ids=heads,
        timeout_per_head=timeout_per_head,
        # Previously dropped here, which made the caller's override a no-op.
        min_heads_ratio=min_heads_ratio,
    )
    if not head_outputs:
        logger.warning("No head outputs; cannot run Witness")
        return (None, []) if return_head_outputs else None

    final = run_witness(orchestrator, task_id, head_outputs, prompt)

    if final is not None and _should_run_second_pass(
        final, force=force_second_pass, second_pass_config=second_pass_config
    ):
        head_outputs = run_second_pass(
            orchestrator,
            task_id,
            prompt,
            head_outputs,
            timeout_per_head=timeout_per_head,
        )
        final = run_witness(orchestrator, task_id, head_outputs, prompt)
        if final is None:
            logger.warning("Witness produced no response after second pass")

    if final is not None and event_bus is not None:
        # Publishing is best-effort: a broken bus must not lose the response.
        try:
            event_bus.publish(
                "dvadasa_complete",
                {
                    "task_id": task_id,
                    "final_response": final.model_dump(),
                    "head_count": len(head_outputs),
                },
            )
        except Exception as e:
            logger.warning("Failed to publish dvadasa_complete", extra={"error": str(e)})

    if return_head_outputs:
        return (final, head_outputs)
    return final


def extract_sources_from_head_outputs(head_outputs: list[HeadOutput]) -> list[dict[str, Any]]:
    """
    Extract citations from head outputs for SOURCES command.

    Evidence without a source_id is skipped, and each (head, source_id) pair is
    reported once, using the first evidence entry encountered for it.

    Returns:
        Dicts with keys "head_id", "source_id", "excerpt" and "confidence".
    """
    sources: list[dict[str, Any]] = []
    seen: set[tuple[str, str]] = set()
    for ho in head_outputs:
        head_name = ho.head_id.value
        for claim in ho.claims:
            for ev in claim.evidence:
                if not ev.source_id:
                    continue
                key = (head_name, ev.source_id)
                if key in seen:
                    continue
                seen.add(key)
                sources.append({
                    "head_id": head_name,
                    "source_id": ev.source_id,
                    "excerpt": ev.excerpt or "",
                    "confidence": ev.confidence,
                })
    return sources


def select_heads_for_complexity(
    prompt: str,
    mvp_heads: list[HeadId] = MVP_HEADS,
    all_heads: list[HeadId] | None = None,
) -> list[HeadId]:
    """
    Dynamic routing: simple prompts use fewer heads.

    Heuristic: long prompt (more than 50 whitespace-separated words) or any complexity
    keyword appearing as a substring of the lower-cased prompt => all heads.

    Args:
        prompt: Prompt to classify.
        mvp_heads: Heads returned for simple prompts.
        all_heads: Heads returned for complex prompts (default: ALL_CONTENT_HEADS).

    Returns:
        A new list, so callers can modify it without touching the module-level
        head lists.
    """
    all_heads = all_heads or ALL_CONTENT_HEADS
    prompt_lower = prompt.lower()
    complex_keywords = (
        "security",
        "risk",
        "architecture",
        "scalability",
        "compliance",
        "critical",
        "production",
        "audit",
        "privacy",
        "sensitive",
    )
    if len(prompt.split()) > 50 or any(kw in prompt_lower for kw in complex_keywords):
        return list(all_heads)
    return list(mvp_heads)