Initial commit: add .gitignore and README
This commit is contained in:
226
fusionagi/memory/episodic.py
Normal file
226
fusionagi/memory/episodic.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Episodic memory: append-only log of task/step outcomes; query by task_id or time range.
|
||||
|
||||
Episodic memory stores historical records of agent actions and outcomes:
|
||||
- Task execution traces
|
||||
- Step outcomes (success/failure)
|
||||
- Tool invocation results
|
||||
- Decision points and their outcomes
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi._time import utc_now_iso
|
||||
|
||||
|
||||
class EpisodicMemory:
|
||||
"""
|
||||
Append-only log of task and step outcomes.
|
||||
|
||||
Features:
|
||||
- Time-stamped event logging
|
||||
- Query by task ID
|
||||
- Query by time range
|
||||
- Query by event type
|
||||
- Statistical summaries
|
||||
- Memory size limits with optional archival
|
||||
"""
|
||||
|
||||
def __init__(self, max_entries: int = 10000) -> None:
|
||||
"""
|
||||
Initialize episodic memory.
|
||||
|
||||
Args:
|
||||
max_entries: Maximum entries before oldest are archived/removed.
|
||||
"""
|
||||
self._entries: list[dict[str, Any]] = []
|
||||
self._by_task: dict[str, list[int]] = {} # task_id -> indices into _entries
|
||||
self._by_type: dict[str, list[int]] = {} # event_type -> indices
|
||||
self._max_entries = max_entries
|
||||
self._archived_count = 0
|
||||
|
||||
def append(
|
||||
self,
|
||||
task_id: str,
|
||||
event: dict[str, Any],
|
||||
event_type: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Append an episodic entry.
|
||||
|
||||
Args:
|
||||
task_id: Task identifier this event belongs to.
|
||||
event: Event data dictionary.
|
||||
event_type: Optional event type for categorization (e.g., "step_done", "tool_call").
|
||||
|
||||
Returns:
|
||||
Index of the appended entry.
|
||||
"""
|
||||
# Enforce size limits
|
||||
if len(self._entries) >= self._max_entries:
|
||||
self._archive_oldest(self._max_entries // 10)
|
||||
|
||||
# Add metadata
|
||||
entry = {
|
||||
**event,
|
||||
"task_id": task_id,
|
||||
"timestamp": event.get("timestamp", time.monotonic()),
|
||||
"datetime": event.get("datetime", utc_now_iso()),
|
||||
}
|
||||
|
||||
if event_type:
|
||||
entry["event_type"] = event_type
|
||||
|
||||
idx = len(self._entries)
|
||||
self._entries.append(entry)
|
||||
|
||||
# Index by task
|
||||
self._by_task.setdefault(task_id, []).append(idx)
|
||||
|
||||
# Index by type if provided
|
||||
etype = event_type or event.get("type") or event.get("event_type")
|
||||
if etype:
|
||||
self._by_type.setdefault(etype, []).append(idx)
|
||||
|
||||
return idx
|
||||
|
||||
def get_by_task(self, task_id: str, limit: int | None = None) -> list[dict[str, Any]]:
|
||||
"""Return all entries for a task (copy), optionally limited."""
|
||||
indices = self._by_task.get(task_id, [])
|
||||
if limit:
|
||||
indices = indices[-limit:]
|
||||
return [self._entries[i].copy() for i in indices]
|
||||
|
||||
def get_by_type(self, event_type: str, limit: int | None = None) -> list[dict[str, Any]]:
|
||||
"""Return entries of a specific type."""
|
||||
indices = self._by_type.get(event_type, [])
|
||||
if limit:
|
||||
indices = indices[-limit:]
|
||||
return [self._entries[i].copy() for i in indices]
|
||||
|
||||
def get_recent(self, limit: int = 100) -> list[dict[str, Any]]:
|
||||
"""Return most recent entries (copy)."""
|
||||
return [e.copy() for e in self._entries[-limit:]]
|
||||
|
||||
def get_by_time_range(
|
||||
self,
|
||||
start_timestamp: float | None = None,
|
||||
end_timestamp: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Return entries within a time range (using monotonic timestamps).
|
||||
|
||||
Args:
|
||||
start_timestamp: Start of range (inclusive).
|
||||
end_timestamp: End of range (inclusive).
|
||||
limit: Maximum entries to return.
|
||||
"""
|
||||
results = []
|
||||
for entry in self._entries:
|
||||
ts = entry.get("timestamp", 0)
|
||||
if start_timestamp and ts < start_timestamp:
|
||||
continue
|
||||
if end_timestamp and ts > end_timestamp:
|
||||
continue
|
||||
results.append(entry.copy())
|
||||
if limit and len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
def query(
|
||||
self,
|
||||
filter_fn: Callable[[dict[str, Any]], bool],
|
||||
limit: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Query entries using a custom filter function.
|
||||
|
||||
Args:
|
||||
filter_fn: Function that returns True for entries to include.
|
||||
limit: Maximum entries to return.
|
||||
"""
|
||||
results = []
|
||||
for entry in self._entries:
|
||||
if filter_fn(entry):
|
||||
results.append(entry.copy())
|
||||
if limit and len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
def get_task_summary(self, task_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get a summary of episodes for a task.
|
||||
|
||||
Returns statistics like count, first/last timestamps, event types.
|
||||
"""
|
||||
entries = self.get_by_task(task_id)
|
||||
if not entries:
|
||||
return {"task_id": task_id, "count": 0}
|
||||
|
||||
event_types: dict[str, int] = {}
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for entry in entries:
|
||||
etype = entry.get("event_type") or entry.get("type") or "unknown"
|
||||
event_types[etype] = event_types.get(etype, 0) + 1
|
||||
|
||||
if entry.get("success"):
|
||||
success_count += 1
|
||||
elif entry.get("error") or entry.get("success") is False:
|
||||
failure_count += 1
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"count": len(entries),
|
||||
"first_timestamp": entries[0].get("datetime"),
|
||||
"last_timestamp": entries[-1].get("datetime"),
|
||||
"event_types": event_types,
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
}
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""Get overall memory statistics."""
|
||||
return {
|
||||
"total_entries": len(self._entries),
|
||||
"archived_entries": self._archived_count,
|
||||
"task_count": len(self._by_task),
|
||||
"event_type_count": len(self._by_type),
|
||||
"event_types": list(self._by_type.keys()),
|
||||
}
|
||||
|
||||
def _archive_oldest(self, count: int) -> None:
|
||||
"""Archive/remove oldest entries to enforce size limits."""
|
||||
if count <= 0 or count >= len(self._entries):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Archiving episodic memory entries",
|
||||
extra={"count": count, "total": len(self._entries)},
|
||||
)
|
||||
|
||||
# Remove oldest entries
|
||||
self._entries = self._entries[count:]
|
||||
self._archived_count += count
|
||||
|
||||
# Rebuild indices (entries shifted)
|
||||
self._by_task = {}
|
||||
self._by_type = {}
|
||||
for idx, entry in enumerate(self._entries):
|
||||
task_id = entry.get("task_id")
|
||||
if task_id:
|
||||
self._by_task.setdefault(task_id, []).append(idx)
|
||||
|
||||
etype = entry.get("event_type") or entry.get("type")
|
||||
if etype:
|
||||
self._by_type.setdefault(etype, []).append(idx)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all entries (for tests)."""
|
||||
self._entries.clear()
|
||||
self._by_task.clear()
|
||||
self._by_type.clear()
|
||||
self._archived_count = 0
|
||||
Reference in New Issue
Block a user