"""Self-correction: on failure, run reflection and optionally prepare retry with feedback."""

from typing import Any, Protocol

from fusionagi.schemas.task import TaskState
from fusionagi.schemas.recommendation import Recommendation, RecommendationKind
from fusionagi._logger import logger


class StateManagerLike(Protocol):
    """Protocol for state manager: get task state, trace, task."""

    def get_task_state(self, task_id: str) -> TaskState | None: ...

    def get_trace(self, task_id: str) -> list[dict[str, Any]]: ...

    def get_task(self, task_id: str) -> Any: ...


class OrchestratorLike(Protocol):
    """Protocol for orchestrator: get plan, set state (for retry)."""

    def get_task_plan(self, task_id: str) -> dict[str, Any] | None: ...

    def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None: ...

    def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None: ...


class CriticLike(Protocol):
    """Protocol for critic: handle_message with evaluate_request -> evaluation_ready."""

    identity: str

    def handle_message(self, envelope: Any) -> Any | None: ...


def _score_of(evaluation: dict[str, Any]) -> float:
    """Return the evaluation's ``score`` as a float.

    A missing, ``None``, or non-numeric score is treated as ``0.0``. A missing
    score has always defaulted to 0; the other cases previously raised
    ``TypeError`` when compared against the threshold.
    """
    raw = evaluation.get("score")
    if raw is None:
        return 0.0
    try:
        return float(raw)
    except (TypeError, ValueError, OverflowError):
        return 0.0


def run_reflection_on_failure(
    critic_agent: CriticLike,
    task_id: str,
    state_manager: StateManagerLike,
    orchestrator: OrchestratorLike,
) -> dict[str, Any] | None:
    """
    Run reflection (Critic evaluation) for a failed task.

    Sends an ``evaluate_request`` message to ``critic_agent``. The payload
    contains ``outcome="failed"``, the task's trace, and its current plan.

    Args:
        critic_agent: Critic that handles the ``evaluate_request`` envelope.
        task_id: Task being reflected on.
        state_manager: Source of the task's execution trace.
        orchestrator: Source of the task's plan.

    Returns:
        The ``evaluation`` entry of the critic's ``evaluation_ready`` payload.
        Returns ``{}`` if the payload is empty or lacks that key. Returns
        ``None`` if the critic gives no reply or replies with another intent.
    """
    # Function-scope import kept from the original; presumably avoids an
    # import cycle with the messages schema — TODO confirm.
    from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope

    trace = state_manager.get_trace(task_id)
    plan = orchestrator.get_task_plan(task_id)
    envelope = AgentMessageEnvelope(
        message=AgentMessage(
            sender="self_correction",
            recipient=critic_agent.identity,
            intent="evaluate_request",
            payload={
                "outcome": "failed",
                "trace": trace,
                "plan": plan,
            },
        ),
        task_id=task_id,
    )
    response = critic_agent.handle_message(envelope)
    if not response or response.message.intent != "evaluation_ready":
        return None
    # Guard against a critic that replies with a None payload.
    payload = response.message.payload or {}
    return payload.get("evaluation", {})


class SelfCorrectionLoop:
    """
    Self-correction: on failed tasks, run Critic reflection and optionally prepare
    retry by transitioning FAILED -> PENDING and storing correction context.
    """

    def __init__(
        self,
        state_manager: StateManagerLike,
        orchestrator: OrchestratorLike,
        critic_agent: CriticLike,
        max_retries_per_task: int = 2,
        retry_score_threshold: float = 0.5,
    ) -> None:
        """
        Initialize the self-correction loop.

        Args:
            state_manager: State manager for task state and traces.
            orchestrator: Orchestrator for plan and state transitions.
            critic_agent: Critic agent for evaluate_request -> evaluation_ready.
            max_retries_per_task: Maximum retries to suggest per task (default 2).
            retry_score_threshold: A critic score strictly below this value
                suggests a retry even without suggestions (default 0.5).
        """
        self._state = state_manager
        self._orchestrator = orchestrator
        self._critic = critic_agent
        self._max_retries = max_retries_per_task
        self._retry_score_threshold = retry_score_threshold
        # Per-task count of retries prepared by this instance. It is kept in
        # memory only and never pruned.
        self._retry_counts: dict[str, int] = {}

    def suggest_retry(self, task_id: str) -> tuple[bool, dict[str, Any]]:
        """
        For a failed task, run reflection and decide whether to suggest retry.

        Returns ``(False, {})`` in three cases: the task is not FAILED, the
        per-task retry budget is used up, or the critic returns no evaluation.

        Otherwise ``should_retry`` is True when either holds:
          * the critic offered suggestions, or
          * its score is below the retry threshold (a missing or non-numeric
            score counts as 0.0).

        Returns:
            ``(should_retry, correction_context)``. The context holds the raw
            ``evaluation``, its ``suggestions`` and ``error_analysis``, and the
            prospective ``retry_count``.
        """
        state = self._state.get_task_state(task_id)
        if state != TaskState.FAILED:
            return False, {}
        retries = self._retry_counts.get(task_id, 0)
        if retries >= self._max_retries:
            logger.info(
                "Self-correction: max retries reached",
                extra={"task_id": task_id, "retries": retries},
            )
            return False, {}
        evaluation = run_reflection_on_failure(
            self._critic,
            task_id,
            self._state,
            self._orchestrator,
        )
        if not evaluation:
            return False, {}
        # `or []` also covers keys that are present but set to None.
        suggestions = evaluation.get("suggestions") or []
        error_analysis = evaluation.get("error_analysis") or []
        should_retry = bool(suggestions) or _score_of(evaluation) < self._retry_score_threshold
        context = {
            "evaluation": evaluation,
            "suggestions": suggestions,
            "error_analysis": error_analysis,
            "retry_count": retries + 1,
        }
        return should_retry, context

    def prepare_retry(self, task_id: str, correction_context: dict[str, Any] | None = None) -> None:
        """
        Transition task from FAILED to PENDING and store correction context in plan.

        The context is written into a copy of the plan under the
        ``_correction_context`` key. The state change is forced, and the
        per-task retry counter is incremented.

        If correction_context is None, runs suggest_retry to obtain it. Nothing
        changes when suggest_retry declines. An explicitly supplied context
        bypasses suggest_retry, including its max-retries check.
        """
        state = self._state.get_task_state(task_id)
        if state != TaskState.FAILED:
            logger.warning("Self-correction: prepare_retry called for non-failed task", extra={"task_id": task_id})
            return
        if correction_context is None:
            ok, correction_context = self.suggest_retry(task_id)
            if not ok:
                return
        # Copy so the orchestrator's stored plan is not mutated in place.
        plan = dict(self._orchestrator.get_task_plan(task_id) or {})
        plan["_correction_context"] = correction_context
        self._orchestrator.set_task_plan(task_id, plan)
        # force=True: FAILED -> PENDING is presumably not a normal transition — TODO confirm.
        self._orchestrator.set_task_state(task_id, TaskState.PENDING, force=True)
        self._retry_counts[task_id] = self._retry_counts.get(task_id, 0) + 1
        logger.info("Self-correction: prepared retry", extra={"task_id": task_id, "retry_count": self._retry_counts[task_id]})

    def correction_recommendations(
        self, task_id: str, max_recommendations: int = 10
    ) -> list[Recommendation]:
        """For a failed task, run reflection and return structured recommendations.

        Does not check the task's state; the critic is always told the outcome
        was ``"failed"``.

        Args:
            task_id: Task to reflect on.
            max_recommendations: Upper bound on returned recommendations
                (default 10). Negative values are treated as 0.

        Returns:
            One NEXT_ACTION recommendation per critic suggestion, in order.
            Each carries the raw suggestion and the full error analysis.
            Returns an empty list when the critic provides no evaluation.
        """
        evaluation = run_reflection_on_failure(
            self._critic,
            task_id,
            self._state,
            self._orchestrator,
        )
        if not evaluation:
            return []
        suggestions = evaluation.get("suggestions") or []
        error_analysis = evaluation.get("error_analysis") or []
        limit = max(0, max_recommendations)
        return [
            Recommendation(
                kind=RecommendationKind.NEXT_ACTION,
                title=f"Correction suggestion {number}",
                description=suggestion if isinstance(suggestion, str) else str(suggestion),
                payload={"raw": suggestion, "error_analysis": error_analysis},
                source_task_id=task_id,
                priority=7,
            )
            for number, suggestion in enumerate(suggestions[:limit], start=1)
        ]