Initial commit: add .gitignore and README
This commit is contained in:
168
fusionagi/self_improvement/correction.py
Normal file
168
fusionagi/self_improvement/correction.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Self-correction: on failure, run reflection and optionally prepare retry with feedback."""
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
from fusionagi.schemas.task import TaskState
|
||||
from fusionagi.schemas.recommendation import Recommendation, RecommendationKind
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class StateManagerLike(Protocol):
|
||||
"""Protocol for state manager: get task state, trace, task."""
|
||||
|
||||
def get_task_state(self, task_id: str) -> TaskState | None: ...
|
||||
def get_trace(self, task_id: str) -> list[dict[str, Any]]: ...
|
||||
def get_task(self, task_id: str) -> Any: ...
|
||||
|
||||
|
||||
class OrchestratorLike(Protocol):
|
||||
"""Protocol for orchestrator: get plan, set state (for retry)."""
|
||||
|
||||
def get_task_plan(self, task_id: str) -> dict[str, Any] | None: ...
|
||||
def set_task_state(self, task_id: str, state: TaskState, force: bool = False) -> None: ...
|
||||
def set_task_plan(self, task_id: str, plan: dict[str, Any]) -> None: ...
|
||||
|
||||
|
||||
class CriticLike(Protocol):
|
||||
"""Protocol for critic: handle_message with evaluate_request -> evaluation_ready."""
|
||||
|
||||
identity: str
|
||||
|
||||
def handle_message(self, envelope: Any) -> Any | None: ...
|
||||
|
||||
|
||||
def run_reflection_on_failure(
|
||||
critic_agent: CriticLike,
|
||||
task_id: str,
|
||||
state_manager: StateManagerLike,
|
||||
orchestrator: OrchestratorLike,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Run reflection (Critic evaluation) for a failed task.
|
||||
Returns evaluation dict or None.
|
||||
"""
|
||||
from fusionagi.schemas.messages import AgentMessage, AgentMessageEnvelope
|
||||
|
||||
trace = state_manager.get_trace(task_id)
|
||||
plan = orchestrator.get_task_plan(task_id)
|
||||
envelope = AgentMessageEnvelope(
|
||||
message=AgentMessage(
|
||||
sender="self_correction",
|
||||
recipient=critic_agent.identity,
|
||||
intent="evaluate_request",
|
||||
payload={
|
||||
"outcome": "failed",
|
||||
"trace": trace,
|
||||
"plan": plan,
|
||||
},
|
||||
),
|
||||
task_id=task_id,
|
||||
)
|
||||
response = critic_agent.handle_message(envelope)
|
||||
if not response or response.message.intent != "evaluation_ready":
|
||||
return None
|
||||
return response.message.payload.get("evaluation", {})
|
||||
|
||||
|
||||
class SelfCorrectionLoop:
|
||||
"""
|
||||
Self-correction: on failed tasks, run Critic reflection and optionally
|
||||
prepare retry by transitioning FAILED -> PENDING and storing correction context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_manager: StateManagerLike,
|
||||
orchestrator: OrchestratorLike,
|
||||
critic_agent: CriticLike,
|
||||
max_retries_per_task: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the self-correction loop.
|
||||
|
||||
Args:
|
||||
state_manager: State manager for task state and traces.
|
||||
orchestrator: Orchestrator for plan and state transitions.
|
||||
critic_agent: Critic agent for evaluate_request -> evaluation_ready.
|
||||
max_retries_per_task: Maximum retries to suggest per task (default 2).
|
||||
"""
|
||||
self._state = state_manager
|
||||
self._orchestrator = orchestrator
|
||||
self._critic = critic_agent
|
||||
self._max_retries = max_retries_per_task
|
||||
self._retry_counts: dict[str, int] = {}
|
||||
|
||||
def suggest_retry(self, task_id: str) -> tuple[bool, dict[str, Any]]:
|
||||
"""
|
||||
For a failed task, run reflection and decide whether to suggest retry.
|
||||
Returns (should_retry, correction_context).
|
||||
"""
|
||||
state = self._state.get_task_state(task_id)
|
||||
if state != TaskState.FAILED:
|
||||
return False, {}
|
||||
retries = self._retry_counts.get(task_id, 0)
|
||||
if retries >= self._max_retries:
|
||||
logger.info(
|
||||
"Self-correction: max retries reached",
|
||||
extra={"task_id": task_id, "retries": retries},
|
||||
)
|
||||
return False, {}
|
||||
evaluation = run_reflection_on_failure(
|
||||
self._critic, task_id, self._state, self._orchestrator,
|
||||
)
|
||||
if not evaluation:
|
||||
return False, {}
|
||||
suggestions = evaluation.get("suggestions", [])
|
||||
error_analysis = evaluation.get("error_analysis", [])
|
||||
should_retry = bool(suggestions or evaluation.get("score", 0) < 0.5)
|
||||
context = {
|
||||
"evaluation": evaluation,
|
||||
"suggestions": suggestions,
|
||||
"error_analysis": error_analysis,
|
||||
"retry_count": retries + 1,
|
||||
}
|
||||
return should_retry, context
|
||||
|
||||
def prepare_retry(self, task_id: str, correction_context: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Transition task from FAILED to PENDING and store correction context in plan.
|
||||
If correction_context is None, runs suggest_retry to obtain it.
|
||||
"""
|
||||
state = self._state.get_task_state(task_id)
|
||||
if state != TaskState.FAILED:
|
||||
logger.warning("Self-correction: prepare_retry called for non-failed task", extra={"task_id": task_id})
|
||||
return
|
||||
if correction_context is None:
|
||||
ok, correction_context = self.suggest_retry(task_id)
|
||||
if not ok:
|
||||
return
|
||||
plan = self._orchestrator.get_task_plan(task_id) or {}
|
||||
plan = dict(plan)
|
||||
plan["_correction_context"] = correction_context
|
||||
self._orchestrator.set_task_plan(task_id, plan)
|
||||
self._orchestrator.set_task_state(task_id, TaskState.PENDING, force=True)
|
||||
self._retry_counts[task_id] = self._retry_counts.get(task_id, 0) + 1
|
||||
logger.info("Self-correction: prepared retry", extra={"task_id": task_id, "retry_count": self._retry_counts[task_id]})
|
||||
|
||||
def correction_recommendations(self, task_id: str) -> list[Recommendation]:
|
||||
"""For a failed task, run reflection and return structured recommendations."""
|
||||
evaluation = run_reflection_on_failure(
|
||||
self._critic, task_id, self._state, self._orchestrator,
|
||||
)
|
||||
if not evaluation:
|
||||
return []
|
||||
suggestions = evaluation.get("suggestions", [])
|
||||
error_analysis = evaluation.get("error_analysis", [])
|
||||
recs: list[Recommendation] = []
|
||||
for i, s in enumerate(suggestions[:10]):
|
||||
recs.append(
|
||||
Recommendation(
|
||||
kind=RecommendationKind.NEXT_ACTION,
|
||||
title=f"Correction suggestion {i + 1}",
|
||||
description=s if isinstance(s, str) else str(s),
|
||||
payload={"raw": s, "error_analysis": error_analysis},
|
||||
source_task_id=task_id,
|
||||
priority=7,
|
||||
)
|
||||
)
|
||||
return recs
|
||||
Reference in New Issue
Block a user