"""Episodic memory: append-only log of task/step outcomes; query by task_id or time range.

Episodic memory stores historical records of agent actions and outcomes:

- Task execution traces
- Step outcomes (success/failure)
- Tool invocation results
- Decision points and their outcomes
"""

import time
from typing import Any, Callable, Iterator

from fusionagi._logger import logger
from fusionagi._time import utc_now_iso


class EpisodicMemory:
    """
    Append-only log of task and step outcomes.

    Features:
    - Time-stamped event logging
    - Query by task ID
    - Query by time range
    - Query by event type
    - Statistical summaries
    - Memory size limits with optional archival
    """

    def __init__(self, max_entries: int = 10000) -> None:
        """
        Initialize episodic memory.

        Args:
            max_entries: Maximum entries before oldest are archived/removed.
        """
        self._entries: list[dict[str, Any]] = []
        # Secondary indices map keys to positions in _entries (append order).
        # They are rebuilt wholesale whenever entries are archived, since
        # archiving shifts every surviving position.
        self._by_task: dict[str, list[int]] = {}  # task_id -> indices into _entries
        self._by_type: dict[str, list[int]] = {}  # event_type -> indices
        self._max_entries = max_entries
        self._archived_count = 0

    @staticmethod
    def _resolve_event_type(entry: dict[str, Any]) -> str | None:
        """
        Resolve an entry's event type: explicit "event_type" wins over legacy "type".

        Shared by append-time indexing, the index rebuild in _archive_oldest,
        and get_task_summary, so all three agree. (Previously append and the
        rebuild checked the two keys in opposite order, which could leave
        _by_type inconsistent after an archive pass.)
        """
        return entry.get("event_type") or entry.get("type")

    def append(
        self,
        task_id: str,
        event: dict[str, Any],
        event_type: str | None = None,
    ) -> int:
        """
        Append an episodic entry.

        Args:
            task_id: Task identifier this event belongs to.
            event: Event data dictionary.
            event_type: Optional event type for categorization
                        (e.g., "step_done", "tool_call").

        Returns:
            Index of the appended entry.
        """
        # Enforce size limits. max(1, ...) guarantees at least one eviction
        # even when max_entries < 10 (otherwise //10 yields 0 and the cap
        # would never be enforced).
        if len(self._entries) >= self._max_entries:
            self._archive_oldest(max(1, self._max_entries // 10))

        # Add metadata: monotonic timestamp for range queries, wall-clock
        # ISO string for display. Caller-supplied values take precedence.
        entry = {
            **event,
            "task_id": task_id,
            "timestamp": event.get("timestamp", time.monotonic()),
            "datetime": event.get("datetime", utc_now_iso()),
        }
        if event_type:
            entry["event_type"] = event_type

        idx = len(self._entries)
        self._entries.append(entry)

        # Index by task
        self._by_task.setdefault(task_id, []).append(idx)

        # Index by type if one can be resolved
        etype = self._resolve_event_type(entry)
        if etype:
            self._by_type.setdefault(etype, []).append(idx)

        return idx

    def get_by_task(self, task_id: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return all entries for a task (copies), optionally limited to the most recent."""
        indices = self._by_task.get(task_id, [])
        if limit is not None:
            if limit <= 0:
                # A non-positive limit means "nothing" — the naive [-limit:]
                # slice would return everything for limit == 0.
                return []
            indices = indices[-limit:]
        return [self._entries[i].copy() for i in indices]

    def get_by_type(self, event_type: str, limit: int | None = None) -> list[dict[str, Any]]:
        """Return entries of a specific type (copies), optionally limited to the most recent."""
        indices = self._by_type.get(event_type, [])
        if limit is not None:
            if limit <= 0:
                return []
            indices = indices[-limit:]
        return [self._entries[i].copy() for i in indices]

    def get_recent(self, limit: int = 100) -> list[dict[str, Any]]:
        """Return the most recent entries (copies), oldest-first within the window."""
        if limit <= 0:
            # Guard: [-0:] would return *all* entries rather than none.
            return []
        return [e.copy() for e in self._entries[-limit:]]

    def get_by_time_range(
        self,
        start_timestamp: float | None = None,
        end_timestamp: float | None = None,
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """
        Return entries within a time range (using monotonic timestamps).

        Args:
            start_timestamp: Start of range (inclusive).
            end_timestamp: End of range (inclusive).
            limit: Maximum entries to return.
        """
        if limit is not None and limit <= 0:
            return []
        results: list[dict[str, Any]] = []
        for entry in self._entries:
            ts = entry.get("timestamp", 0)
            # "is not None" so a bound of 0 / 0.0 is still honored
            # (a bare truthiness test would silently ignore it).
            if start_timestamp is not None and ts < start_timestamp:
                continue
            if end_timestamp is not None and ts > end_timestamp:
                continue
            results.append(entry.copy())
            if limit is not None and len(results) >= limit:
                break
        return results

    def query(
        self,
        filter_fn: Callable[[dict[str, Any]], bool],
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """
        Query entries using a custom filter function.

        Args:
            filter_fn: Function that returns True for entries to include.
            limit: Maximum entries to return.
        """
        if limit is not None and limit <= 0:
            return []
        results: list[dict[str, Any]] = []
        for entry in self._entries:
            if filter_fn(entry):
                results.append(entry.copy())
                if limit is not None and len(results) >= limit:
                    break
        return results

    def get_task_summary(self, task_id: str) -> dict[str, Any]:
        """
        Get a summary of episodes for a task.

        Returns statistics like count, first/last timestamps, event types.
        """
        entries = self.get_by_task(task_id)
        if not entries:
            return {"task_id": task_id, "count": 0}

        event_types: dict[str, int] = {}
        success_count = 0
        failure_count = 0
        for entry in entries:
            etype = self._resolve_event_type(entry) or "unknown"
            event_types[etype] = event_types.get(etype, 0) + 1
            # Truthy "success" counts as success; an explicit False or any
            # "error" payload counts as failure; anything else is neutral.
            if entry.get("success"):
                success_count += 1
            elif entry.get("error") or entry.get("success") is False:
                failure_count += 1

        return {
            "task_id": task_id,
            "count": len(entries),
            "first_timestamp": entries[0].get("datetime"),
            "last_timestamp": entries[-1].get("datetime"),
            "event_types": event_types,
            "success_count": success_count,
            "failure_count": failure_count,
        }

    def get_statistics(self) -> dict[str, Any]:
        """Get overall memory statistics."""
        return {
            "total_entries": len(self._entries),
            "archived_entries": self._archived_count,
            "task_count": len(self._by_task),
            "event_type_count": len(self._by_type),
            "event_types": list(self._by_type.keys()),
        }

    def _archive_oldest(self, count: int) -> None:
        """
        Archive/remove the oldest ``count`` entries to enforce size limits.

        ``count`` is clamped to the current size (instead of bailing out when
        it meets or exceeds it), so the cap is enforced even for very small
        ``max_entries`` values where the eviction batch covers everything.
        """
        count = min(count, len(self._entries))
        if count <= 0:
            return

        logger.info(
            "Archiving episodic memory entries",
            extra={"count": count, "total": len(self._entries)},
        )

        # Remove oldest entries
        self._entries = self._entries[count:]
        self._archived_count += count

        # Rebuild indices — every surviving entry shifted left by ``count``.
        self._by_task = {}
        self._by_type = {}
        for idx, entry in enumerate(self._entries):
            task_id = entry.get("task_id")
            if task_id:
                self._by_task.setdefault(task_id, []).append(idx)
            etype = self._resolve_event_type(entry)
            if etype:
                self._by_type.setdefault(etype, []).append(idx)

    def clear(self) -> None:
        """Clear all entries (for tests)."""
        self._entries.clear()
        self._by_task.clear()
        self._by_type.clear()
        self._archived_count = 0