Initial commit: add .gitignore and README
This commit is contained in:
194
fusionagi/multi_agent/consensus_engine.py
Normal file
194
fusionagi/multi_agent/consensus_engine.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Consensus engine: claim collection, deduplication, conflict detection, scoring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.head import HeadId, HeadOutput, HeadClaim
|
||||
from fusionagi.schemas.witness import AgreementMap
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@dataclass
class CollectedClaim:
    """Claim with source head and metadata for consensus."""

    # Verbatim claim text as produced by the head.
    claim_text: str
    # Head-reported confidence; after aggregation this holds the combined score.
    confidence: float
    # Which content head produced (or anchors) this claim.
    head_id: HeadId
    # Number of evidence items attached to the claim (len(claim.evidence)).
    evidence_count: int
    # Serializable snapshot of the claim for downstream reporting
    # (claim_text / confidence / head_id / evidence_count / assumptions,
    # plus aggregation metadata once merged).
    raw: dict[str, Any]
|
||||
|
||||
|
||||
def _normalize_text(s: str) -> str:
|
||||
"""Normalize for duplicate detection."""
|
||||
return " ".join(s.lower().split())
|
||||
|
||||
|
||||
def _are_similar(a: str, b: str, threshold: float = 0.9) -> bool:
    """Return True when *a* and *b* are near-duplicate claims.

    Both texts are normalized (lowercased, whitespace-collapsed) first.
    Exact normalized matches are always similar.  Very short texts
    (< 10 normalized chars) carry too little signal for set overlap, so
    anything short of an exact match is dissimilar.  Otherwise word-set
    Jaccard similarity is compared against *threshold*.

    Args:
        a: First claim text.
        b: Second claim text.
        threshold: Minimum Jaccard similarity (word sets) to count as similar.

    Returns:
        True if the claims are duplicates or near-duplicates.
    """
    na, nb = _normalize_text(a), _normalize_text(b)
    if na == nb:
        return True
    # Exact match was already handled above, so short strings can only be
    # dissimilar here (the original `return na == nb` was a dead comparison).
    if len(na) < 10 or len(nb) < 10:
        return False
    # Jaccard similarity over word sets.
    wa, wb = set(na.split()), set(nb.split())
    union = len(wa | wb)
    if union == 0:
        return False
    return (len(wa & wb) / union) >= threshold
|
||||
|
||||
|
||||
def _looks_contradictory(a: str, b: str) -> bool:
    """Heuristic contradiction check: exactly one claim carries a negation
    marker while both share a significant portion of their words."""
    neg_words = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}
    wa = set(_normalize_text(a).split())
    wb = set(_normalize_text(b).split())
    # Same polarity on both sides (both negated or neither) -> no conflict.
    if bool(wa & neg_words) == bool(wb & neg_words):
        return False
    # Opposite polarity only counts when the claims overlap substantially.
    shared_ratio = len(wa & wb) / max(len(wa), 1)
    return shared_ratio > 0.3
|
||||
|
||||
|
||||
def collect_claims(outputs: list[HeadOutput]) -> list[CollectedClaim]:
    """Flatten all head claims with source metadata."""
    return [
        CollectedClaim(
            claim_text=claim.claim_text,
            confidence=claim.confidence,
            head_id=output.head_id,
            evidence_count=len(claim.evidence),
            raw={
                "claim_text": claim.claim_text,
                "confidence": claim.confidence,
                "head_id": output.head_id.value,
                "evidence_count": len(claim.evidence),
                "assumptions": claim.assumptions,
            },
        )
        for output in outputs
        for claim in output.claims
    ]
|
||||
|
||||
|
||||
def _aggregate_group(group: list[CollectedClaim], weights: dict[HeadId, float]) -> CollectedClaim:
    """Collapse one group of near-duplicate claims into a single scored claim."""
    if len(group) == 1:
        c = group[0]
        score = c.confidence * weights.get(c.head_id, 1.0)
        if c.evidence_count > 0:
            score *= 1.1  # boost for citations
        # Fix: clamp like the multi-claim branch below, so a weighted +
        # boosted confidence can never exceed 1.0.
        score = min(1.0, score)
        return CollectedClaim(
            claim_text=c.claim_text,
            confidence=score,
            head_id=c.head_id,
            evidence_count=c.evidence_count,
            raw={**c.raw, "aggregated_confidence": score, "supporting_heads": [c.head_id.value]},
        )
    # Weighted average confidence across supporting heads, boosted if any
    # member cites evidence, then clamped to 1.0.
    total_conf = sum(g.confidence * weights.get(g.head_id, 1.0) for g in group)
    avg_conf = total_conf / len(group)
    evidence_boost = 1.1 if any(g.evidence_count > 0 for g in group) else 1.0
    score = min(1.0, avg_conf * evidence_boost)
    return CollectedClaim(
        claim_text=group[0].claim_text,
        confidence=score,
        head_id=group[0].head_id,
        evidence_count=sum(g.evidence_count for g in group),
        raw={
            "claim_text": group[0].claim_text,
            "aggregated_confidence": score,
            "supporting_heads": [g.head_id.value for g in group],
        },
    )


def _merge_similar(collected: list[CollectedClaim], weights: dict[HeadId, float]) -> list[CollectedClaim]:
    """Greedily group near-duplicate (non-contradictory) claims and aggregate each group."""
    merged: list[CollectedClaim] = []
    used: set[int] = set()
    for i, anchor in enumerate(collected):
        if i in used:
            continue
        used.add(i)
        group = [anchor]
        for j, candidate in enumerate(collected):
            if j in used:
                continue
            if _are_similar(anchor.claim_text, candidate.claim_text) and not _looks_contradictory(
                anchor.claim_text, candidate.claim_text
            ):
                group.append(candidate)
                used.add(j)
        merged.append(_aggregate_group(group, weights))
    return merged


def run_consensus(
    outputs: list[HeadOutput],
    head_weights: dict[HeadId, float] | None = None,
    confidence_threshold: float = 0.5,
) -> AgreementMap:
    """
    Run consensus: deduplicate, detect conflicts, score, produce AgreementMap.

    Args:
        outputs: HeadOutput from all content heads.
        head_weights: Optional per-head reliability weights (default 1.0).
        confidence_threshold: Minimum confidence for agreed claim.

    Returns:
        AgreementMap with agreed_claims, disputed_claims, confidence_score.
    """
    if not outputs:
        return AgreementMap(
            agreed_claims=[],
            disputed_claims=[],
            confidence_score=0.0,
        )

    weights = head_weights or {h: 1.0 for h in HeadId}
    merged = _merge_similar(collect_claims(outputs), weights)

    # Conflict detection: a claim is disputed if it contradicts any other
    # merged claim, or falls below the confidence threshold.
    agreed: list[dict[str, Any]] = []
    disputed: list[dict[str, Any]] = []
    for c in merged:
        in_conflict = any(
            c is not d and _looks_contradictory(c.claim_text, d.claim_text) for d in merged
        )
        rec = {
            "claim_text": c.claim_text,
            "confidence": c.confidence,
            "supporting_heads": c.raw.get("supporting_heads", [c.head_id.value]),
        }
        if in_conflict or c.confidence < confidence_threshold:
            disputed.append(rec)
        else:
            agreed.append(rec)

    overall_conf = sum(a["confidence"] for a in agreed) / len(agreed) if agreed else 0.0

    logger.info(
        "Consensus complete",
        extra={"agreed": len(agreed), "disputed": len(disputed), "confidence": overall_conf},
    )

    return AgreementMap(
        agreed_claims=agreed,
        disputed_claims=disputed,
        confidence_score=min(1.0, overall_conf),
    )
|
||||
Reference in New Issue
Block a user