Files
virtual-banker/backend/orchestrator/orchestrator.go

285 lines
7.0 KiB
Go

package orchestrator
import (
"context"
"fmt"
"sync"
"time"
"github.com/explorer/virtual-banker/backend/asr"
"github.com/explorer/virtual-banker/backend/llm"
"github.com/explorer/virtual-banker/backend/rag"
"github.com/explorer/virtual-banker/backend/tools"
"github.com/explorer/virtual-banker/backend/tts"
)
// State identifies the current phase of a session's conversation.
// The values used in this file are the State* constants declared below.
type State string
// Conversation states, as set by the SessionOrchestrator methods in this file.
const (
// StateIdle is the state a new session starts in, and the state restored
// once speaking finishes or is interrupted.
StateIdle State = "IDLE"
// StateListening is set while incoming audio is being transcribed.
StateListening State = "LISTENING"
// StateThinking is set while a response is being generated for a
// transcript or text message.
StateThinking State = "THINKING"
// StateSpeaking is set while the response is being synthesized and spoken.
StateSpeaking State = "SPEAKING"
)
// Orchestrator holds the per-session orchestrators together with the service
// dependencies (ASR, TTS, LLM, RAG and tool executor) that every session it
// creates is wired with.
type Orchestrator struct {
// sessions maps a session ID to its SessionOrchestrator; guarded by mu.
sessions map[string]*SessionOrchestrator
mu sync.RWMutex
// Dependencies copied into each session by GetOrCreateSession.
asr asr.Service
tts tts.Service
llm llm.Gateway
rag rag.Service
tools *tools.Executor
}
// NewOrchestrator returns an Orchestrator with an empty session table that
// will hand the given services to every session it creates.
func NewOrchestrator(asrService asr.Service, ttsService tts.Service, llmGateway llm.Gateway, ragService rag.Service, toolExecutor *tools.Executor) *Orchestrator {
	o := new(Orchestrator)
	o.sessions = make(map[string]*SessionOrchestrator)
	o.asr = asrService
	o.tts = ttsService
	o.llm = llmGateway
	o.rag = ragService
	o.tools = toolExecutor
	return o
}
// SessionOrchestrator manages the conversation of a single session: its
// current State, its message history, and the services used to transcribe,
// answer and speak.
type SessionOrchestrator struct {
// Identity of the session; tenantID and userID are forwarded to the RAG,
// LLM and tool calls made on the session's behalf.
sessionID string
tenantID string
userID string
// state is the current conversation phase; read and written under mu.
state State
// mu is held by SetState/GetState, and by StopSpeaking while it swaps
// ctx and cancel.
mu sync.RWMutex
// ctx/cancel form the session-scoped context created in
// GetOrCreateSession. StopSpeaking cancels it and installs a fresh pair;
// Close cancels it.
ctx context.Context
cancel context.CancelFunc
// Service dependencies copied from the parent Orchestrator.
asr asr.Service
tts tts.Service
llm llm.Gateway
rag rag.Service
tools *tools.Executor
// conversation is the message history passed to the LLM as
// ConversationHistory.
conversation []llm.Message
}
// GetOrCreateSession returns the session registered under sessionID, creating
// and registering a new idle one if none exists yet.
//
// tenantID and userID are only used when a session is created; an existing
// session is returned unchanged. The lookup is attempted under the read lock
// first and repeated under the write lock before inserting, so two concurrent
// callers with the same ID end up with the same session.
func (o *Orchestrator) GetOrCreateSession(sessionID, tenantID, userID string) *SessionOrchestrator {
	// Fast path: the session already exists.
	o.mu.RLock()
	existing, found := o.sessions[sessionID]
	o.mu.RUnlock()
	if found {
		return existing
	}

	o.mu.Lock()
	defer o.mu.Unlock()

	// Another caller may have created it between the two locks.
	if existing, found = o.sessions[sessionID]; found {
		return existing
	}

	sessionCtx, cancelSession := context.WithCancel(context.Background())
	created := &SessionOrchestrator{
		sessionID:    sessionID,
		tenantID:     tenantID,
		userID:       userID,
		state:        StateIdle,
		ctx:          sessionCtx,
		cancel:       cancelSession,
		asr:          o.asr,
		tts:          o.tts,
		llm:          o.llm,
		rag:          o.rag,
		tools:        o.tools,
		conversation: []llm.Message{},
	}
	o.sessions[sessionID] = created
	return created
}
// ProcessAudio runs one voice turn: it transcribes audioData, generates a
// reply for the transcript, and then synthesizes and speaks that reply.
//
// If the session is currently speaking, the ongoing speech is stopped first
// (barge-in). The state moves LISTENING -> THINKING -> SPEAKING as the turn
// progresses. When any stage fails the state is reset to IDLE, so the session
// does not keep reporting a stage that is no longer running, and the wrapped
// error is returned.
func (so *SessionOrchestrator) ProcessAudio(ctx context.Context, audioData []byte) error {
	// Handle barge-in: if speaking, stop and switch to listening.
	if so.GetState() == StateSpeaking {
		so.StopSpeaking()
	}

	// Transcribe audio.
	so.SetState(StateListening)
	transcript, err := so.asr.Transcribe(ctx, audioData)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to transcribe: %w", err)
	}

	// Generate the reply.
	so.SetState(StateThinking)
	response, err := so.processTranscript(ctx, transcript)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to process transcript: %w", err)
	}

	// Synthesize and speak the reply.
	so.SetState(StateSpeaking)
	if err := so.speak(ctx, response); err != nil {
		so.SetState(StateIdle)
		return err
	}
	return nil
}
// ProcessText runs one text turn: it generates a reply for text and then
// synthesizes and speaks that reply.
//
// The state moves THINKING -> SPEAKING as the turn progresses. When either
// stage fails the state is reset to IDLE, so the session does not keep
// reporting a stage that is no longer running, and the error is returned.
func (so *SessionOrchestrator) ProcessText(ctx context.Context, text string) error {
	// Generate the reply.
	so.SetState(StateThinking)
	response, err := so.processTranscript(ctx, text)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to process text: %w", err)
	}

	// Synthesize and speak the reply.
	so.SetState(StateSpeaking)
	if err := so.speak(ctx, response); err != nil {
		so.SetState(StateIdle)
		return err
	}
	return nil
}
// processTranscript records transcript as a user message, asks the LLM for a
// reply, executes any tool calls the reply requests, and returns the reply
// text.
//
// The conversation history is shared by every turn of the session, and turns
// can overlap (a new turn may start while an earlier one is still running),
// so every access to so.conversation happens under so.mu and the LLM is given
// a private snapshot of the history.
//
// RAG retrieval and tool execution are best-effort: their errors are logged
// and the turn continues. An LLM error aborts the turn and is returned
// wrapped.
func (so *SessionOrchestrator) processTranscript(ctx context.Context, transcript string) (string, error) {
	// Add the user message and snapshot the history in one critical section
	// so the snapshot always includes this turn's message.
	so.mu.Lock()
	so.conversation = append(so.conversation, llm.Message{
		Role:    "user",
		Content: transcript,
	})
	history := append([]llm.Message(nil), so.conversation...)
	so.mu.Unlock()

	// Retrieve relevant documents from RAG (up to 5 for this tenant).
	var retrievedDocs []rag.RetrievedDoc
	if so.rag != nil {
		docs, err := so.rag.Retrieve(ctx, transcript, so.tenantID, 5)
		if err != nil {
			// Best-effort: answer without RAG context, but do not hide the failure.
			fmt.Printf("RAG retrieval error: %v\n", err)
		} else {
			retrievedDocs = docs
		}
	}

	// Convert retrieved docs to the LLM package's format.
	// NOTE(review): ragDocs is built but never handed to the LLM below, so the
	// retrieved context currently has no effect on the reply. Wire it into the
	// generate options once the intended field is confirmed.
	ragDocs := make([]llm.RetrievedDoc, len(retrievedDocs))
	for i, doc := range retrievedDocs {
		ragDocs[i] = llm.RetrievedDoc{
			Title:   doc.Title,
			Content: doc.Content,
			URL:     doc.URL,
			Score:   doc.Score,
		}
	}

	// Get available tools (would come from tenant config).
	availableTools := []llm.Tool{} // TODO: Get from tenant config

	// Call the LLM.
	options := &llm.GenerateOptions{
		Temperature:         0.7,
		MaxTokens:           500,
		Tools:               availableTools,
		TenantID:            so.tenantID,
		UserID:              so.userID,
		ConversationHistory: history,
	}
	response, err := so.llm.Generate(ctx, transcript, options)
	if err != nil {
		return "", fmt.Errorf("failed to generate response: %w", err)
	}

	// Execute tool calls if any, collecting their results locally so the
	// lock is not held across tool execution.
	var toolMessages []llm.Message
	if so.tools != nil {
		for _, toolCall := range response.Tools {
			result, err := so.tools.Execute(ctx, toolCall.Name, toolCall.Arguments, so.userID, so.tenantID)
			if err != nil {
				// Log error but continue with the remaining tool calls.
				fmt.Printf("Tool execution error: %v\n", err)
				continue
			}
			if result.Success {
				toolMessages = append(toolMessages, llm.Message{
					Role:    "assistant",
					Content: fmt.Sprintf("Tool %s executed successfully: %v", toolCall.Name, result.Data),
				})
			}
		}
	}

	// Append tool results followed by the assistant reply, preserving the
	// original ordering within this turn.
	so.mu.Lock()
	so.conversation = append(so.conversation, toolMessages...)
	so.conversation = append(so.conversation, llm.Message{
		Role:    "assistant",
		Content: response.Text,
	})
	so.mu.Unlock()

	return response.Text, nil
}
// speak synthesizes text and simulates playing it back.
//
// Audio and visemes are requested from the TTS service; a viseme failure is
// logged and otherwise ignored. Delivery to the client is not implemented
// yet, so playback is simulated by waiting 50ms per byte of text.
//
// The wait ends early in two cases:
//   - the session context is cancelled (StopSpeaking or Close). This is not
//     treated as an error, and the state is left alone so that whatever the
//     interrupter set is not overwritten.
//   - ctx is done. The state is reset to IDLE and ctx.Err() is returned.
//
// On normal completion the state is set to IDLE and nil is returned. A
// synthesis failure resets the state to IDLE and returns the wrapped error.
func (so *SessionOrchestrator) speak(ctx context.Context, text string) error {
	// Snapshot the session context up front: StopSpeaking replaces so.ctx with
	// a fresh one, and this call must watch the one that gets cancelled.
	so.mu.RLock()
	sessionCtx := so.ctx
	so.mu.RUnlock()
	if sessionCtx == nil {
		sessionCtx = context.Background()
	}

	// Synthesize audio.
	audioData, err := so.tts.Synthesize(ctx, text)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to synthesize: %w", err)
	}

	// Get visemes for avatar.
	visemes, err := so.tts.GetVisemes(ctx, text)
	if err != nil {
		// Log error but continue: speech without lip-sync is still useful.
		fmt.Printf("Failed to get visemes: %v\n", err)
	}

	// TODO: Send audio and visemes to client via WebRTC/WebSocket
	_ = audioData
	_ = visemes

	// Simulate speaking duration, but stay interruptible.
	playback := time.NewTimer(time.Duration(len(text)*50) * time.Millisecond)
	defer playback.Stop()

	select {
	case <-playback.C:
		// Finished speaking normally.
	case <-sessionCtx.Done():
		// Interrupted by StopSpeaking/Close; do not touch the state.
		return nil
	case <-ctx.Done():
		so.SetState(StateIdle)
		return ctx.Err()
	}

	so.SetState(StateIdle)
	return nil
}
// StopSpeaking interrupts the current speech (barge-in).
//
// It does nothing unless the session is in StateSpeaking. Otherwise it
// cancels the session context, installs a fresh context/cancel pair for later
// use, and moves the session to StateIdle, all under the session mutex.
//
// NOTE(review): whether in-flight synthesis or playback actually stops depends
// on that work observing the session context; confirm at the call sites.
func (so *SessionOrchestrator) StopSpeaking() {
	so.mu.Lock()
	defer so.mu.Unlock()

	if so.state != StateSpeaking {
		return
	}

	// Cancel whatever is bound to the current session context, then replace
	// it so the session remains usable.
	so.cancel()
	so.ctx, so.cancel = context.WithCancel(context.Background())
	so.state = StateIdle
}
// SetState replaces the session's conversation state with state while holding
// the session's write lock.
func (so *SessionOrchestrator) SetState(state State) {
	so.mu.Lock()
	so.state = state
	so.mu.Unlock()
}
// GetState returns the session's current conversation state, read under the
// session's read lock.
func (so *SessionOrchestrator) GetState() State {
	so.mu.RLock()
	current := so.state
	so.mu.RUnlock()
	return current
}
// Close cancels the session's context.
//
// The cancel function is invoked under the session mutex because StopSpeaking
// replaces so.cancel under that same mutex; calling it unguarded would race
// with that write. Close is a no-op on a session that has no cancel function.
//
// NOTE(review): Close does not remove the session from the parent
// Orchestrator's session table; confirm whether that cleanup happens
// elsewhere.
func (so *SessionOrchestrator) Close() {
	so.mu.Lock()
	defer so.mu.Unlock()

	if so.cancel != nil {
		so.cancel()
	}
}