"""In-memory store for task state and execution traces; replaceable with persistent backend."""

from __future__ import annotations

from collections import defaultdict
from typing import Any, TYPE_CHECKING

from fusionagi.schemas.task import Task, TaskState
from fusionagi._logger import logger

if TYPE_CHECKING:
    from fusionagi.core.persistence import StateBackend


class StateManager:
    """
    Manages task state and execution traces.

    Supports optional persistent backend via dependency injection.
    When a backend is provided, every write is forwarded to it.
    An in-memory cache is always maintained for fast access: reads check
    the cache first and fall back to the backend on a miss, caching what
    the backend returns.
    """

    def __init__(self, backend: StateBackend | None = None) -> None:
        """
        Initialize StateManager with optional persistence backend.

        Args:
            backend: Optional StateBackend for persistence. If None, uses in-memory only.
        """
        self._backend = backend
        self._tasks: dict[str, Task] = {}
        self._traces: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def get_task(self, task_id: str) -> Task | None:
        """Return the task by id or None.

        Checks memory first, then the backend (if configured). A task found
        in the backend is cached in memory for subsequent lookups.
        """
        if task_id in self._tasks:
            return self._tasks[task_id]
        # Compare against None explicitly: a backend (or Task) object that
        # defines __len__/__bool__ must not be mistaken for "absent".
        if self._backend is not None:
            task = self._backend.get_task(task_id)
            if task is not None:
                self._tasks[task_id] = task
                return task
        return None

    def set_task(self, task: Task) -> None:
        """Store or update a task in memory and, if configured, in the backend."""
        self._tasks[task.task_id] = task
        if self._backend is not None:
            self._backend.set_task(task)
        logger.debug("Task set", extra={"task_id": task.task_id})

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return current task state or None if task unknown."""
        task = self.get_task(task_id)
        return task.state if task is not None else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update the state of an existing task; creates no task if missing.

        If the task is not cached in memory but the backend has it, it is
        loaded into the cache and then updated. If neither memory nor the
        backend knows the task, this is a silent no-op.
        """
        task = self._tasks.get(task_id)
        if task is None and self._backend is not None:
            # Task might be in backend but not in memory.
            task = self._backend.get_task(task_id)
            if task is not None:
                self._tasks[task_id] = task
        if task is None:
            return
        task.state = state
        if self._backend is not None:
            self._backend.set_task_state(task_id, state)
        logger.debug("Task state set", extra={"task_id": task_id, "state": state.value})

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append an entry to the execution trace for a task.

        When a backend is configured and no trace for ``task_id`` is cached
        yet, the persisted trace is loaded into the cache first. Without
        this, the cache would hold only the new entry, and ``get_trace``
        (which prefers a non-empty cache) would hide every entry persisted
        before this manager was created. This costs at most one extra
        backend read per task.
        """
        if self._backend is not None and not self._traces.get(task_id):
            persisted = self._backend.get_trace(task_id)
            if persisted:
                self._traces[task_id] = list(persisted)
        self._traces[task_id].append(entry)
        if self._backend is not None:
            self._backend.append_trace(task_id, entry)
        # Short label for the log line: prefer the tool name, then the step.
        entry_key = entry.get("tool") or entry.get("step") or "entry"
        logger.debug("Trace appended", extra={"task_id": task_id, "entry_key": entry_key})

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return a copy of the execution trace for a task.

        Checks memory first; if no entries are cached, falls back to the
        backend and caches what it returns. The result is always a new
        (shallow) list, so appending to or removing from it never affects
        the cache or the backend's own list; the entry dicts are shared.
        """
        cached = self._traces.get(task_id)
        if cached:
            return list(cached)
        if self._backend is not None:
            persisted = self._backend.get_trace(task_id)
            if persisted:
                self._traces[task_id] = list(persisted)
                # Return a copy, never the backend-owned list itself.
                return list(persisted)
        return []

    def clear_task(self, task_id: str) -> None:
        """Remove task and its trace (for tests or cleanup). Does not clear backend."""
        self._tasks.pop(task_id, None)
        self._traces.pop(task_id, None)

    def list_tasks(self, state: TaskState | None = None) -> list[Task]:
        """Return all tasks, optionally filtered by state.

        When a persistence backend is configured, only tasks currently loaded
        in memory are returned; tasks that exist only in the backend are not
        included.
        """
        if state is None:
            return list(self._tasks.values())
        return [t for t in self._tasks.values() if t.state == state]

    def task_count(self) -> int:
        """Return total number of tasks in memory."""
        return len(self._tasks)