Files

87 lines
2.4 KiB
Python
Raw Permalink Normal View History

"""GPU-accelerated semantic search for memory subsystems.
Provides vector similarity search using GPU-accelerated embeddings
for SemanticGraphMemory and EpisodicMemory.
"""
from __future__ import annotations
from typing import Any
from fusionagi._logger import logger
from fusionagi.schemas.atomic import AtomicSemanticUnit
def semantic_search(
    query: str,
    units: list[AtomicSemanticUnit],
    top_k: int = 10,
) -> list[tuple[AtomicSemanticUnit, float]]:
    """Search atomic semantic units by vector similarity, preferring the GPU.

    Tries ``fusionagi.gpu.tensor_similarity.nearest_neighbors``. If an
    ``ImportError`` is raised (the GPU module, or something it imports while
    running, is unavailable), the search falls back to a CPU word-overlap
    scorer.

    Args:
        query: Query text to search for.
        units: List of atomic semantic units to search within.
        top_k: Maximum number of results to return. Values <= 0 yield an
            empty list.

    Returns:
        List of (unit, similarity_score) tuples. The CPU fallback sorts by
        score descending. On the GPU path, ordering is whatever
        ``nearest_neighbors`` returns (presumably score-descending — TODO
        confirm).
    """
    if not units or top_k <= 0:
        return []
    try:
        from fusionagi.gpu.tensor_similarity import nearest_neighbors

        corpus = [u.content for u in units]
        # The call stays inside the try so that an ImportError raised lazily
        # by the backend (e.g. a missing tensor library) still falls back.
        results = nearest_neighbors([query], corpus, top_k=top_k)
    except ImportError:
        logger.debug("GPU not available for semantic search; using CPU fallback")
        return _cpu_fallback_search(query, units, top_k)

    # A single query was submitted, so only the first result row matters.
    if not results or not results[0]:
        return []
    n_units = len(units)
    # Reject out-of-range indices. Without the lower bound, a negative index
    # would silently wrap around to the end of ``units``. float() normalizes
    # backend scalar types (the backend may return numpy/torch scalars —
    # unconfirmed) to the annotated plain float.
    return [
        (units[idx], float(score))
        for idx, score in results[0]
        if 0 <= idx < n_units
    ]
def _cpu_fallback_search(
query: str,
units: list[AtomicSemanticUnit],
top_k: int,
) -> list[tuple[AtomicSemanticUnit, float]]:
"""CPU fallback: simple word-overlap similarity."""
query_words = set(query.lower().split())
scored: list[tuple[AtomicSemanticUnit, float]] = []
for unit in units:
unit_words = set(unit.content.lower().split())
if not unit_words:
continue
overlap = len(query_words & unit_words)
score = overlap / max(len(query_words | unit_words), 1)
scored.append((unit, score))
scored.sort(key=lambda x: x[1], reverse=True)
return scored[:top_k]
def batch_embed_units(
    units: list[AtomicSemanticUnit],
) -> Any:
    """Embed the textual content of each unit with the GPU backend.

    Obtains the backend via ``fusionagi.gpu.backend.get_backend`` and passes
    it the ``content`` of every unit, in order.

    Args:
        units: Units whose ``content`` strings are embedded.

    Returns:
        Whatever ``embed_texts`` produces (a backend-specific embedding
        object), or ``None`` if an ``ImportError`` occurs anywhere in that
        process.
    """
    try:
        from fusionagi.gpu.backend import get_backend

        # Resolve the backend before touching unit contents (same order as
        # before), so an unavailable backend short-circuits to None.
        backend = get_backend()
        return backend.embed_texts([unit.content for unit in units])
    except ImportError:
        logger.debug("GPU not available for batch embedding")
        return None