"""Session and prompt routes.""" import json import uuid from typing import Any from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect from fusionagi.api.dependencies import get_orchestrator, get_session_store, get_event_bus, get_safety_pipeline from fusionagi.api.websocket import handle_stream from fusionagi.core import run_dvadasa, select_heads_for_complexity, extract_sources_from_head_outputs from fusionagi.schemas.commands import parse_user_input, UserIntent router = APIRouter() def _ensure_init(): from fusionagi.api.dependencies import ensure_initialized ensure_initialized() @router.post("") def create_session(user_id: str | None = None) -> dict[str, Any]: """Create a new session.""" _ensure_init() store = get_session_store() if not store: raise HTTPException(status_code=503, detail="Session store not initialized") session_id = str(uuid.uuid4()) store.create(session_id, user_id) return {"session_id": session_id, "user_id": user_id} @router.post("/{session_id}/prompt") def submit_prompt(session_id: str, body: dict[str, Any]) -> dict[str, Any]: """Submit a prompt and receive FinalResponse (sync).""" _ensure_init() store = get_session_store() orch = get_orchestrator() bus = get_event_bus() if not store or not orch: raise HTTPException(status_code=503, detail="Service not initialized") sess = store.get(session_id) if not sess: raise HTTPException(status_code=404, detail="Session not found") prompt = body.get("prompt", "") parsed = parse_user_input(prompt) if not prompt or not parsed.cleaned_prompt.strip(): if parsed.intent in (UserIntent.SHOW_DISSENT, UserIntent.RERUN_RISK, UserIntent.EXPLAIN_REASONING, UserIntent.SOURCES): hist = sess.get("history", []) if hist: prompt = hist[-1].get("prompt", "") if not prompt: raise HTTPException(status_code=400, detail="No previous prompt; provide a prompt for this command") else: raise HTTPException(status_code=400, detail="prompt is required") effective_prompt = parsed.cleaned_prompt.strip() or prompt pipeline = get_safety_pipeline() if pipeline: pre_result = pipeline.pre_check(effective_prompt) if not pre_result.allowed: raise HTTPException(status_code=400, detail=pre_result.reason or "Input moderation failed") task_id = orch.submit_task(goal=effective_prompt[:200]) # Dynamic head selection head_ids = select_heads_for_complexity(effective_prompt) if parsed.intent.value == "head_strategy" and parsed.head_id: head_ids = [parsed.head_id] force_second = parsed.intent == UserIntent.RERUN_RISK return_heads = parsed.intent == UserIntent.SOURCES result = run_dvadasa( orchestrator=orch, task_id=task_id, user_prompt=effective_prompt, parsed=parsed, head_ids=head_ids if parsed.intent.value != "normal" or body.get("use_all_heads") else None, event_bus=bus, force_second_pass=force_second, return_head_outputs=return_heads, ) if return_heads and isinstance(result, tuple): final, head_outputs = result else: final = result head_outputs = [] if not final: raise HTTPException(status_code=500, detail="Failed to produce response") if pipeline: post_result = pipeline.post_check(final.final_answer) if not post_result.passed: raise HTTPException( status_code=400, detail=f"Output scan failed: {', '.join(post_result.flags)}", ) entry = { "prompt": effective_prompt, "final_answer": final.final_answer, "confidence_score": final.confidence_score, "head_contributions": final.head_contributions, } store.append_history(session_id, entry) response: dict[str, Any] = { "task_id": task_id, "final_answer": final.final_answer, "transparency_report": final.transparency_report.model_dump(), "head_contributions": final.head_contributions, "confidence_score": final.confidence_score, } if parsed.intent == UserIntent.SHOW_DISSENT: response["response_mode"] = "show_dissent" response["disputed_claims"] = final.transparency_report.agreement_map.disputed_claims elif parsed.intent == UserIntent.EXPLAIN_REASONING: response["response_mode"] = "explain" elif parsed.intent == UserIntent.SOURCES and head_outputs: response["sources"] = extract_sources_from_head_outputs(head_outputs) return response @router.websocket("/{session_id}/stream") async def stream_websocket(websocket: WebSocket, session_id: str) -> None: """WebSocket for streaming Dvādaśa response. Send {\"prompt\": \"...\"} to start.""" await websocket.accept() try: data = await websocket.receive_json() prompt = data.get("prompt", "") async def send_evt(evt: dict) -> None: await websocket.send_json(evt) await handle_stream(session_id, prompt, send_evt) except WebSocketDisconnect: pass except Exception as e: try: await websocket.send_json({"type": "error", "message": str(e)}) except Exception: pass