Files
FusionAGI/fusionagi/core/memory_backend.py

69 lines
2.3 KiB
Python
Raw Normal View History

"""In-memory state backend for task persistence.
Useful for testing and development when no database is needed.
"""
from __future__ import annotations
from typing import Any
from fusionagi.core.persistence import StateBackend
from fusionagi.schemas.task import Task, TaskState
class InMemoryStateBackend(StateBackend):
    """In-memory implementation of StateBackend.

    Tasks and traces live in plain dicts on the instance, so all data is
    lost on process restart. Use SQLiteStateBackend or a Postgres-backed
    backend for production persistence.

    Tasks are stored by reference: ``get_task`` returns the same object that
    was passed to ``set_task``. Trace entries, by contrast, are shallow-copied
    both when appended and when read back, so callers cannot rewrite stored
    trace history by mutating a dict they passed in or received.
    """

    def __init__(self) -> None:
        # task_id -> Task; dict preserves insertion order.
        self._tasks: dict[str, Task] = {}
        # task_id -> trace entries in append order. May contain ids that have
        # no entry in ``_tasks`` (append_trace does not check for the task).
        self._traces: dict[str, list[dict[str, Any]]] = {}

    def get_task(self, task_id: str) -> Task | None:
        """Return the stored task for ``task_id``, or None if it is unknown."""
        return self._tasks.get(task_id)

    def set_task(self, task: Task) -> None:
        """Insert or replace ``task``, keyed by its ``task_id``."""
        self._tasks[task.task_id] = task

    def get_task_state(self, task_id: str) -> TaskState | None:
        """Return the current state of ``task_id``, or None if the task is unknown."""
        task = self._tasks.get(task_id)
        return task.state if task is not None else None

    def set_task_state(self, task_id: str, state: TaskState) -> None:
        """Update the state of an existing task.

        Does nothing if ``task_id`` is unknown; no task is created. The stored
        task is replaced by a copy carrying the new state rather than mutated
        in place, so Task objects previously returned by ``get_task`` keep
        their old state.
        """
        task = self._tasks.get(task_id)
        if task is not None:
            # NOTE(review): model_copy(update=...) suggests Task is a pydantic v2
            # model; if so, ``state`` is not re-validated here -- confirm.
            self._tasks[task_id] = task.model_copy(update={"state": state})

    def append_trace(self, task_id: str, entry: dict[str, Any]) -> None:
        """Append a trace entry for ``task_id``.

        A shallow copy of ``entry`` is stored, so later mutation of the
        caller's dict does not alter the recorded trace. The task itself does
        not need to exist.
        """
        self._traces.setdefault(task_id, []).append(dict(entry))

    def get_trace(self, task_id: str) -> list[dict[str, Any]]:
        """Return the trace entries for ``task_id`` in append order.

        Returns an empty list if nothing was recorded. Both the list and each
        entry dict are fresh shallow copies; nested values inside an entry are
        still shared with the stored trace.
        """
        return [dict(entry) for entry in self._traces.get(task_id, ())]

    def list_tasks(self, state: TaskState | None = None, limit: int = 100) -> list[Task]:
        """Return up to ``limit`` tasks, optionally only those in ``state``.

        Tasks come back in insertion order (re-saving an existing task id keeps
        its original position). ``limit`` is applied as a slice bound, so a
        negative value drops that many tasks from the end rather than meaning
        "no limit".
        """
        matching = [t for t in self._tasks.values() if state is None or t.state == state]
        return matching[:limit]

    def delete_task(self, task_id: str) -> bool:
        """Delete a task and its trace.

        Returns True if a task was stored under ``task_id``, False otherwise.
        Any trace recorded for the id is removed in either case.
        """
        self._traces.pop(task_id, None)
        return self._tasks.pop(task_id, None) is not None

    def count_tasks(self) -> int:
        """Return the number of stored tasks (ids with only a trace are not counted)."""
        return len(self._tasks)