"""Context sharding: cluster atomic units by semantic similarity or domain."""

from __future__ import annotations

import re
import uuid
from dataclasses import dataclass, field
from typing import Any

from fusionagi.schemas.atomic import AtomicSemanticUnit


@dataclass
class Shard:
    """A cluster of atomic units with optional summary and embedding."""

    shard_id: str = field(default_factory=lambda: f"shard_{uuid.uuid4().hex[:12]}")
    unit_ids: list[str] = field(default_factory=list)
    summary: str = ""
    embedding: list[float] | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


def _extract_keywords(text: str) -> set[str]:
    """Extract lowercase alphanumeric tokens of 3+ characters for clustering."""
    # Lowercase and collapse runs of whitespace before tokenizing.
    content = " ".join(text.lower().split())
    return set(re.findall(r"\b[a-z0-9]{3,}\b", content))


def _keyword_similarity(a: set[str], b: set[str]) -> float:
    """Jaccard similarity between keyword sets (1.0 when both are empty)."""
    if not a and not b:
        return 1.0
    inter = len(a & b)
    union = len(a | b)
    return inter / union if union else 0.0


def _cluster_by_keywords(
    units: list[AtomicSemanticUnit],
    max_cluster_size: int,
) -> list[list[AtomicSemanticUnit]]:
    """Greedily cluster units by keyword overlap.

    Each unassigned unit seeds a new cluster. Later unassigned units join it
    when their Jaccard similarity to the seed exceeds 0.1, until the cluster
    holds ``max_cluster_size`` units.
    """
    if not units:
        return []
    # If everything fits in one cluster, return it as-is without checking similarity.
    if len(units) <= max_cluster_size:
        return [units]

    unit_keywords: list[set[str]] = [_extract_keywords(u.content) for u in units]
    clusters: list[list[int]] = []
    assigned: set[int] = set()

    for i in range(len(units)):
        if i in assigned:
            continue
        cluster = [i]
        assigned.add(i)
        for j in range(i + 1, len(units)):
            # A full cluster never shrinks, so no later unit can join it.
            if len(cluster) >= max_cluster_size:
                break
            if j in assigned:
                continue
            # Similarity is measured against the seed unit only, not the whole cluster.
            sim = _keyword_similarity(unit_keywords[i], unit_keywords[j])
            if sim > 0.1:
                cluster.append(j)
                assigned.add(j)
        clusters.append(cluster)

    return [[units[idx] for idx in c] for c in clusters]


def shard_context(
    units: list[AtomicSemanticUnit],
    max_cluster_size: int = 20,
) -> list[Shard]:
    """Shard atomic units into clusters by keyword (lexical) similarity.

    Each shard's summary joins the first 80 characters of up to three of its
    units, followed by "..." when the cluster holds more than three units.
    """
    clusters = _cluster_by_keywords(units, max_cluster_size)
    shards: list[Shard] = []
    for cluster in clusters:
        unit_ids = [u.unit_id for u in cluster]
        # Short preview built from the first three units of the cluster.
        summary_parts = [u.content[:80] for u in cluster[:3]]
        summary = "; ".join(summary_parts) + ("..." if len(cluster) > 3 else "")
        shards.append(
            Shard(unit_ids=unit_ids, summary=summary, metadata={"unit_count": len(cluster)})
        )
    return shards