Files

148 lines
5.3 KiB
Python
Raw Permalink Normal View History

"""Session and prompt routes."""
import json
import logging
import uuid
from typing import Any

from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect

from fusionagi.api.dependencies import get_orchestrator, get_session_store, get_event_bus, get_safety_pipeline
from fusionagi.api.websocket import handle_stream
from fusionagi.core import run_dvadasa, select_heads_for_complexity, extract_sources_from_head_outputs
from fusionagi.schemas.commands import parse_user_input, UserIntent
# Router for the session endpoints defined in this module (create, prompt,
# stream). Paths are relative: create_session is registered at the empty path
# "", so the router is presumably mounted under a prefix elsewhere — confirm.
router = APIRouter()
def _ensure_init() -> None:
    """Run the shared API dependency initialisation before handling a request.

    The dependency module is imported at call time rather than at module load
    (NOTE(review): presumably to avoid an import cycle — confirm).
    """
    from fusionagi.api.dependencies import ensure_initialized as _initialize_dependencies

    _initialize_dependencies()
@router.post("")
def create_session(user_id: str | None = None) -> dict[str, Any]:
    """Create a new session.

    Args:
        user_id: Optional identifier passed through to the session store.

    Returns:
        ``{"session_id": <fresh UUID4 string>, "user_id": user_id}``.

    Raises:
        HTTPException: 503 when no session store is available.
    """
    _ensure_init()
    session_store = get_session_store()
    if session_store:
        new_id = f"{uuid.uuid4()}"
        session_store.create(new_id, user_id)
        return dict(session_id=new_id, user_id=user_id)
    raise HTTPException(status_code=503, detail="Session store not initialized")
# Command intents that act on the previous turn and may therefore be sent
# without any new prompt text.
_FOLLOW_UP_INTENTS = (
    UserIntent.SHOW_DISSENT,
    UserIntent.RERUN_RISK,
    UserIntent.EXPLAIN_REASONING,
    UserIntent.SOURCES,
)


def _resolve_effective_prompt(sess: Any, raw_prompt: str, parsed: Any) -> str:
    """Return the prompt text that should actually be run for this request.

    Normally this is ``parsed.cleaned_prompt`` (the input with any command
    stripped). When that is empty and the intent is a follow-up command, the
    prompt of the most recent session history entry is reused.

    Raises:
        HTTPException: 400 if no prompt text was given and the intent is not a
            follow-up command, or if a follow-up command has no previous prompt
            to reuse (including when the session history is empty).
    """
    cleaned = parsed.cleaned_prompt.strip()
    if raw_prompt and cleaned:
        return cleaned
    if parsed.intent not in _FOLLOW_UP_INTENTS:
        raise HTTPException(status_code=400, detail="prompt is required")
    history = sess.get("history") or []
    previous = history[-1].get("prompt", "") if history else ""
    if not previous:
        # An empty history must be rejected explicitly; otherwise the raw
        # command text would be run as if it were the user's question.
        raise HTTPException(
            status_code=400,
            detail="No previous prompt; provide a prompt for this command",
        )
    return cleaned or previous


@router.post("/{session_id}/prompt")
def submit_prompt(session_id: str, body: dict[str, Any]) -> dict[str, Any]:
    """Submit a prompt and receive FinalResponse (sync).

    Request body keys read here:
        prompt: User text, possibly carrying a command understood by
            ``parse_user_input``. Missing or null is treated as empty.
        use_all_heads: Optional flag; see the note at the ``run_dvadasa`` call.

    Returns:
        A dict with ``task_id``, ``final_answer``, ``transparency_report``,
        ``head_contributions`` and ``confidence_score``. Depending on intent,
        it also carries ``response_mode`` (plus ``disputed_claims`` for
        dissent) or ``sources``.

    Raises:
        HTTPException: 503 if the session store or orchestrator is missing,
            404 for an unknown session, 400 for a missing/invalid prompt or a
            failed safety check, 500 if no response was produced.
    """
    _ensure_init()
    store = get_session_store()
    orch = get_orchestrator()
    bus = get_event_bus()
    if not store or not orch:
        raise HTTPException(status_code=503, detail="Service not initialized")
    sess = store.get(session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Session not found")

    # An explicit JSON null is treated like a missing prompt. Any other
    # non-string would otherwise reach parse_user_input and surface as a 500.
    raw_prompt = body.get("prompt")
    if raw_prompt is None:
        raw_prompt = ""
    if not isinstance(raw_prompt, str):
        raise HTTPException(status_code=400, detail="prompt must be a string")
    parsed = parse_user_input(raw_prompt)
    effective_prompt = _resolve_effective_prompt(sess, raw_prompt, parsed)

    pipeline = get_safety_pipeline()
    if pipeline:
        pre_result = pipeline.pre_check(effective_prompt)
        if not pre_result.allowed:
            raise HTTPException(status_code=400, detail=pre_result.reason or "Input moderation failed")

    # Only the first 200 characters become the task goal; the full prompt is
    # passed to run_dvadasa below.
    task_id = orch.submit_task(goal=effective_prompt[:200])

    # Dynamic head selection; an explicit head-strategy command overrides it.
    head_ids = select_heads_for_complexity(effective_prompt)
    if parsed.intent.value == "head_strategy" and parsed.head_id:
        head_ids = [parsed.head_id]
    force_second = parsed.intent == UserIntent.RERUN_RISK
    return_heads = parsed.intent == UserIntent.SOURCES
    # NOTE(review): the selected head_ids are forwarded only for non-"normal"
    # intents or when use_all_heads is truthy; otherwise None is passed, and
    # its meaning is defined by run_dvadasa (not visible here). The flag name
    # reads inverted relative to this — confirm the intended semantics.
    use_selected_heads = parsed.intent.value != "normal" or body.get("use_all_heads")
    result = run_dvadasa(
        orchestrator=orch,
        task_id=task_id,
        user_prompt=effective_prompt,
        parsed=parsed,
        head_ids=head_ids if use_selected_heads else None,
        event_bus=bus,
        force_second_pass=force_second,
        return_head_outputs=return_heads,
    )
    # With return_head_outputs=True the result may be (final, head_outputs).
    if return_heads and isinstance(result, tuple):
        final, head_outputs = result
    else:
        final, head_outputs = result, []
    if not final:
        raise HTTPException(status_code=500, detail="Failed to produce response")

    if pipeline:
        post_result = pipeline.post_check(final.final_answer)
        if not post_result.passed:
            raise HTTPException(
                status_code=400,
                detail=f"Output scan failed: {', '.join(post_result.flags)}",
            )

    store.append_history(
        session_id,
        {
            "prompt": effective_prompt,
            "final_answer": final.final_answer,
            "confidence_score": final.confidence_score,
            "head_contributions": final.head_contributions,
        },
    )

    response: dict[str, Any] = {
        "task_id": task_id,
        "final_answer": final.final_answer,
        "transparency_report": final.transparency_report.model_dump(),
        "head_contributions": final.head_contributions,
        "confidence_score": final.confidence_score,
    }
    if parsed.intent == UserIntent.SHOW_DISSENT:
        response["response_mode"] = "show_dissent"
        response["disputed_claims"] = final.transparency_report.agreement_map.disputed_claims
    elif parsed.intent == UserIntent.EXPLAIN_REASONING:
        response["response_mode"] = "explain"
    elif parsed.intent == UserIntent.SOURCES and head_outputs:
        response["sources"] = extract_sources_from_head_outputs(head_outputs)
    return response
@router.websocket("/{session_id}/stream")
async def stream_websocket(websocket: WebSocket, session_id: str) -> None:
    """WebSocket for streaming Dvādaśa response. Send {\"prompt\": \"...\"} to start.

    The first client message must be a JSON object whose optional ``prompt``
    is a string. Events produced by ``handle_stream`` are forwarded to the
    client as JSON. On failure an ``{"type": "error", "message": ...}`` event is
    sent (best effort) and the socket is closed.
    """
    await websocket.accept()

    async def send_evt(evt: dict) -> None:
        await websocket.send_json(evt)

    try:
        data = await websocket.receive_json()
        prompt = data.get("prompt", "") if isinstance(data, dict) else None
        if not isinstance(prompt, str):
            # Reject malformed payloads with a clear message instead of an
            # AttributeError text, and avoid passing a non-string prompt on.
            await send_evt({"type": "error", "message": 'Expected a JSON object like {"prompt": "..."}'})
            await websocket.close(code=1003)  # 1003: unsupported data
            return
        await handle_stream(session_id, prompt, send_evt)
    except WebSocketDisconnect:
        # Client went away; there is nobody left to report to.
        return
    except Exception as e:
        # Keep a server-side trace; the original swallowed these silently.
        logging.getLogger(__name__).exception("Streaming failed for session %s", session_id)
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
            await websocket.close(code=1011)  # 1011: internal server error
        except Exception:
            # Best effort only: the socket may already be closed or broken.
            pass