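// Package orchestrator coordinates the conversation flow of a session:
// speech recognition, response generation (LLM with optional document
// retrieval and tool execution) and speech synthesis.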
package orchestrator
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/explorer/virtual-banker/backend/asr"
|
|
"github.com/explorer/virtual-banker/backend/llm"
|
|
"github.com/explorer/virtual-banker/backend/rag"
|
|
"github.com/explorer/virtual-banker/backend/tools"
|
|
"github.com/explorer/virtual-banker/backend/tts"
|
|
)
|
|
|
|
// State identifies the phase a session's conversation is currently in.
type State string

// Conversation states. The string values are part of the observable state
// and are kept exactly as declared.
const (
	// StateIdle is the state of a freshly created session and the state
	// restored once speaking finishes or is interrupted.
	StateIdle State = "IDLE"
	// StateListening is set while incoming audio is being transcribed.
	StateListening State = "LISTENING"
	// StateThinking is set while a response is being generated.
	StateThinking State = "THINKING"
	// StateSpeaking is set while a response is being synthesized and played.
	StateSpeaking State = "SPEAKING"
)
|
|
|
|
// Orchestrator orchestrates conversation flow
|
|
type Orchestrator struct {
|
|
sessions map[string]*SessionOrchestrator
|
|
mu sync.RWMutex
|
|
asr asr.Service
|
|
tts tts.Service
|
|
llm llm.Gateway
|
|
rag rag.Service
|
|
tools *tools.Executor
|
|
}
|
|
|
|
// NewOrchestrator creates a new orchestrator
|
|
func NewOrchestrator(asrService asr.Service, ttsService tts.Service, llmGateway llm.Gateway, ragService rag.Service, toolExecutor *tools.Executor) *Orchestrator {
|
|
return &Orchestrator{
|
|
sessions: make(map[string]*SessionOrchestrator),
|
|
asr: asrService,
|
|
tts: ttsService,
|
|
llm: llmGateway,
|
|
rag: ragService,
|
|
tools: toolExecutor,
|
|
}
|
|
}
|
|
|
|
// SessionOrchestrator manages a single session's conversation
|
|
type SessionOrchestrator struct {
|
|
sessionID string
|
|
tenantID string
|
|
userID string
|
|
state State
|
|
mu sync.RWMutex
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
asr asr.Service
|
|
tts tts.Service
|
|
llm llm.Gateway
|
|
rag rag.Service
|
|
tools *tools.Executor
|
|
conversation []llm.Message
|
|
}
|
|
|
|
// GetOrCreateSession gets or creates a session orchestrator
|
|
func (o *Orchestrator) GetOrCreateSession(sessionID, tenantID, userID string) *SessionOrchestrator {
|
|
o.mu.RLock()
|
|
sess, ok := o.sessions[sessionID]
|
|
o.mu.RUnlock()
|
|
|
|
if ok {
|
|
return sess
|
|
}
|
|
|
|
o.mu.Lock()
|
|
defer o.mu.Unlock()
|
|
|
|
// Double-check
|
|
if sess, ok := o.sessions[sessionID]; ok {
|
|
return sess
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
sess = &SessionOrchestrator{
|
|
sessionID: sessionID,
|
|
tenantID: tenantID,
|
|
userID: userID,
|
|
state: StateIdle,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
asr: o.asr,
|
|
tts: o.tts,
|
|
llm: o.llm,
|
|
rag: o.rag,
|
|
tools: o.tools,
|
|
conversation: []llm.Message{},
|
|
}
|
|
|
|
o.sessions[sessionID] = sess
|
|
return sess
|
|
}
|
|
|
|
// ProcessAudio processes incoming audio
|
|
func (so *SessionOrchestrator) ProcessAudio(ctx context.Context, audioData []byte) error {
|
|
so.mu.Lock()
|
|
currentState := so.state
|
|
so.mu.Unlock()
|
|
|
|
// Handle barge-in: if speaking, stop and switch to listening
|
|
if currentState == StateSpeaking {
|
|
so.StopSpeaking()
|
|
}
|
|
|
|
so.SetState(StateListening)
|
|
|
|
// Transcribe audio
|
|
transcript, err := so.asr.Transcribe(ctx, audioData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to transcribe: %w", err)
|
|
}
|
|
|
|
// Process transcript
|
|
so.SetState(StateThinking)
|
|
response, err := so.processTranscript(ctx, transcript)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process transcript: %w", err)
|
|
}
|
|
|
|
// Synthesize response
|
|
so.SetState(StateSpeaking)
|
|
return so.speak(ctx, response)
|
|
}
|
|
|
|
// ProcessText processes incoming text message
|
|
func (so *SessionOrchestrator) ProcessText(ctx context.Context, text string) error {
|
|
so.SetState(StateThinking)
|
|
|
|
// Process text
|
|
response, err := so.processTranscript(ctx, text)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process text: %w", err)
|
|
}
|
|
|
|
// Synthesize response
|
|
so.SetState(StateSpeaking)
|
|
return so.speak(ctx, response)
|
|
}
|
|
|
|
// processTranscript processes a transcript and generates a response
|
|
func (so *SessionOrchestrator) processTranscript(ctx context.Context, transcript string) (string, error) {
|
|
// Add user message to conversation
|
|
so.conversation = append(so.conversation, llm.Message{
|
|
Role: "user",
|
|
Content: transcript,
|
|
})
|
|
|
|
// Retrieve relevant documents from RAG
|
|
var retrievedDocs []rag.RetrievedDoc
|
|
if so.rag != nil {
|
|
docs, err := so.rag.Retrieve(ctx, transcript, so.tenantID, 5)
|
|
if err == nil {
|
|
retrievedDocs = docs
|
|
}
|
|
}
|
|
|
|
// Build prompt with RAG context
|
|
// Convert retrieved docs to LLM format
|
|
ragDocs := make([]llm.RetrievedDoc, len(retrievedDocs))
|
|
for i, doc := range retrievedDocs {
|
|
ragDocs[i] = llm.RetrievedDoc{
|
|
Title: doc.Title,
|
|
Content: doc.Content,
|
|
URL: doc.URL,
|
|
Score: doc.Score,
|
|
}
|
|
}
|
|
|
|
// Get available tools (would come from tenant config)
|
|
availableTools := []llm.Tool{} // TODO: Get from tenant config
|
|
|
|
// Call LLM
|
|
options := &llm.GenerateOptions{
|
|
Temperature: 0.7,
|
|
MaxTokens: 500,
|
|
Tools: availableTools,
|
|
TenantID: so.tenantID,
|
|
UserID: so.userID,
|
|
ConversationHistory: so.conversation,
|
|
}
|
|
|
|
response, err := so.llm.Generate(ctx, transcript, options)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to generate response: %w", err)
|
|
}
|
|
|
|
// Execute tool calls if any
|
|
if len(response.Tools) > 0 && so.tools != nil {
|
|
for _, toolCall := range response.Tools {
|
|
result, err := so.tools.Execute(ctx, toolCall.Name, toolCall.Arguments, so.userID, so.tenantID)
|
|
if err != nil {
|
|
// Log error but continue
|
|
fmt.Printf("Tool execution error: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
// Add tool result to conversation
|
|
if result.Success {
|
|
so.conversation = append(so.conversation, llm.Message{
|
|
Role: "assistant",
|
|
Content: fmt.Sprintf("Tool %s executed successfully: %v", toolCall.Name, result.Data),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add assistant response to conversation
|
|
so.conversation = append(so.conversation, llm.Message{
|
|
Role: "assistant",
|
|
Content: response.Text,
|
|
})
|
|
|
|
return response.Text, nil
|
|
}
|
|
|
|
// speak synthesizes and plays audio
|
|
func (so *SessionOrchestrator) speak(ctx context.Context, text string) error {
|
|
// Synthesize audio
|
|
audioData, err := so.tts.Synthesize(ctx, text)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to synthesize: %w", err)
|
|
}
|
|
|
|
// Get visemes for avatar
|
|
visemes, err := so.tts.GetVisemes(ctx, text)
|
|
if err != nil {
|
|
// Log error but continue
|
|
fmt.Printf("Failed to get visemes: %v\n", err)
|
|
}
|
|
|
|
// TODO: Send audio and visemes to client via WebRTC/WebSocket
|
|
_ = audioData
|
|
_ = visemes
|
|
|
|
// Simulate speaking duration
|
|
time.Sleep(time.Duration(len(text)*50) * time.Millisecond)
|
|
|
|
so.SetState(StateIdle)
|
|
return nil
|
|
}
|
|
|
|
// StopSpeaking stops current speech (barge-in)
|
|
func (so *SessionOrchestrator) StopSpeaking() {
|
|
so.mu.Lock()
|
|
defer so.mu.Unlock()
|
|
|
|
if so.state == StateSpeaking {
|
|
// Cancel current TTS synthesis
|
|
so.cancel()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
so.ctx = ctx
|
|
so.cancel = cancel
|
|
so.state = StateIdle
|
|
}
|
|
}
|
|
|
|
// SetState sets the conversation state
|
|
func (so *SessionOrchestrator) SetState(state State) {
|
|
so.mu.Lock()
|
|
defer so.mu.Unlock()
|
|
so.state = state
|
|
}
|
|
|
|
// GetState gets the current conversation state
|
|
func (so *SessionOrchestrator) GetState() State {
|
|
so.mu.RLock()
|
|
defer so.mu.RUnlock()
|
|
return so.state
|
|
}
|
|
|
|
// Close closes the session orchestrator
|
|
func (so *SessionOrchestrator) Close() {
|
|
so.cancel()
|
|
}
|