"""Versioned thought states: snapshots, rollback, branching."""

from __future__ import annotations

import time
import uuid
from dataclasses import dataclass, field
from typing import Any

from fusionagi.memory.scratchpad import ThoughtState
from fusionagi.reasoning.tot import ThoughtNode
from fusionagi._logger import logger


@dataclass
class ThoughtStateSnapshot:
    """Snapshot of reasoning state: tree + scratchpad.

    ``ThoughtVersioning`` fills ``tree_state`` with the nested-dict form built
    by ``_serialize_tree`` and ``scratchpad_state`` with a private copy of the
    scratchpad, so edits to the live objects after saving do not leak in.
    """

    # Random identifier; used as the lookup key inside ThoughtVersioning.
    version_id: str = field(default_factory=lambda: f"v_{uuid.uuid4().hex[:12]}")
    # Nested dict produced by _serialize_tree ({} / None means "no tree").
    tree_state: dict[str, Any] | None = None
    scratchpad_state: ThoughtState | None = None
    # time.monotonic() reading: only meaningful for ordering/intervals within
    # this process, not as a wall-clock time.
    timestamp: float = field(default_factory=time.monotonic)
    metadata: dict[str, Any] = field(default_factory=dict)


def _copy_scratchpad(state: ThoughtState) -> ThoughtState:
    """Return an independent copy of *state*.

    Rebuilds the scratchpad from the four fields this module treats as its
    contents, shallow-copying each container so that appending to or editing
    the copy's lists/dict cannot affect the original (and vice versa).
    NOTE(review): if ThoughtState gains further fields, extend this copy.
    """
    return ThoughtState(
        hypotheses=list(state.hypotheses),
        partial_conclusions=list(state.partial_conclusions),
        discarded_paths=list(state.discarded_paths),
        metadata=dict(state.metadata),
    )


def _serialize_tree(node: ThoughtNode | None) -> dict[str, Any]:
    """Serialize a ThoughtNode subtree to a nested dict.

    Containers (``trace``, ``unit_refs``, ``metadata``) are copied so the
    result does not alias the live tree. Returns ``{}`` for ``None``.
    """
    if node is None:
        return {}
    return {
        "node_id": node.node_id,
        "parent_id": node.parent_id,
        "thought": node.thought,
        "trace": list(node.trace),
        "score": node.score,
        "depth": node.depth,
        "unit_refs": list(node.unit_refs),
        "metadata": dict(node.metadata),
        "children": [_serialize_tree(child) for child in node.children],
    }


def _deserialize_tree(data: dict[str, Any]) -> ThoughtNode | None:
    """Rebuild a fresh ThoughtNode subtree from ``_serialize_tree`` output.

    Returns ``None`` for an empty dict. Every container handed to the new
    nodes is a fresh copy, so mutating the returned tree never touches *data*.
    """
    if not data:
        return None
    node = ThoughtNode(
        node_id=data.get("node_id", ""),
        parent_id=data.get("parent_id"),
        thought=data.get("thought", ""),
        trace=list(data.get("trace", [])),
        score=float(data.get("score", 0)),
        depth=int(data.get("depth", 0)),
        unit_refs=list(data.get("unit_refs", [])),
        metadata=dict(data.get("metadata", {})),
    )
    for child_data in data.get("children", []):
        child = _deserialize_tree(child_data)
        # Empty child dicts deserialize to None and are skipped.
        if child is not None:
            node.children.append(child)
    return node


class ThoughtVersioning:
    """Save, load, rollback, and branch thought states.

    Snapshots are kept in memory, up to ``max_snapshots``; once the limit is
    exceeded the oldest snapshot is evicted. Saved and loaded states are
    independent copies, so callers may freely mutate what they pass in or get
    back without altering stored history.
    """

    def __init__(self, max_snapshots: int = 50) -> None:
        # Dicts preserve insertion order, so the first key is the oldest snapshot.
        self._snapshots: dict[str, ThoughtStateSnapshot] = {}
        self._max_snapshots = max_snapshots

    def save_snapshot(
        self,
        tree: ThoughtNode | None,
        scratchpad: ThoughtState | None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Save a snapshot of *tree* and *scratchpad*; return its version_id.

        Both states (and *metadata*) are copied at save time. If storing the
        new snapshot exceeds ``max_snapshots``, the oldest one is evicted; with
        ``max_snapshots <= 0`` that is the snapshot just saved, so the returned
        id will not be loadable.
        """
        snapshot = ThoughtStateSnapshot(
            tree_state=_serialize_tree(tree),
            scratchpad_state=(
                _copy_scratchpad(scratchpad) if scratchpad is not None else None
            ),
            metadata=dict(metadata) if metadata else {},
        )
        self._snapshots[snapshot.version_id] = snapshot
        if len(self._snapshots) > self._max_snapshots:
            oldest = next(iter(self._snapshots))
            del self._snapshots[oldest]
        logger.debug("Thought snapshot saved", extra={"version_id": snapshot.version_id})
        return snapshot.version_id

    def load_snapshot(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Load snapshot *version_id*; return fresh copies of (tree, scratchpad).

        Returns ``(None, None)`` for an unknown id. The returned objects are
        independent of the stored snapshot, so it can be loaded again unchanged.
        """
        snap = self._snapshots.get(version_id)
        if snap is None:
            return None, None
        tree = _deserialize_tree(snap.tree_state) if snap.tree_state else None
        scratchpad = (
            _copy_scratchpad(snap.scratchpad_state)
            if snap.scratchpad_state is not None
            else None
        )
        return tree, scratchpad

    def list_snapshots(self) -> list[dict[str, Any]]:
        """List stored snapshots (oldest first) with id, timestamp and metadata copy."""
        return [
            {
                "version_id": snap.version_id,
                "timestamp": snap.timestamp,
                "metadata": dict(snap.metadata),
            }
            for snap in self._snapshots.values()
        ]

    def rollback_to(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Load and return snapshot (alias for load_snapshot)."""
        return self.load_snapshot(version_id)

    def branch_from(
        self,
        version_id: str,
    ) -> tuple[ThoughtNode | None, ThoughtState | None]:
        """Branch from a snapshot: return editable copies of (tree, scratchpad).

        ``load_snapshot`` already returns copies independent of stored history,
        so edits on the branch never affect the snapshot it came from.
        """
        return self.load_snapshot(version_id)