Add full monorepo: virtual-banker, backend, frontend, docs, scripts, deployment

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
defiQUG
2026-02-10 11:32:49 -08:00
commit b4753cef7e
81 changed files with 9255 additions and 0 deletions

35
backend/api/realtime.go Normal file
View File

@@ -0,0 +1,35 @@
package api
import (
"net/http"
"github.com/gorilla/mux"
)
// HandleRealtimeWebSocket serves the realtime endpoint. It checks that the
// session named by the {id} route variable can be loaded and then hands the
// request to the realtime gateway.
//
// Responses written here: 400 when the id is empty, 401 when the session
// lookup fails, 503 when no gateway is configured and 500 when the gateway
// returns an error.
func (s *Server) HandleRealtimeWebSocket(w http.ResponseWriter, r *http.Request) {
	sessionID := mux.Vars(r)["id"]
	if sessionID == "" {
		writeError(w, http.StatusBadRequest, "session_id is required", nil)
		return
	}

	// Only the lookup error matters; the session value itself is not used.
	if _, lookupErr := s.sessionManager.GetSession(r.Context(), sessionID); lookupErr != nil {
		writeError(w, http.StatusUnauthorized, "invalid session", lookupErr)
		return
	}

	if s.realtimeGateway == nil {
		writeError(w, http.StatusServiceUnavailable, "realtime gateway not available", nil)
		return
	}

	upgradeErr := s.realtimeGateway.HandleWebSocket(w, r, sessionID)
	if upgradeErr != nil {
		writeError(w, http.StatusInternalServerError, "failed to upgrade connection", upgradeErr)
	}
}

185
backend/api/routes.go Normal file
View File

@@ -0,0 +1,185 @@
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/explorer/virtual-banker/backend/realtime"
"github.com/explorer/virtual-banker/backend/session"
"github.com/gorilla/mux"
)
// Server handles HTTP requests
// It bundles the mux router with the session manager and the realtime gateway
// that the route handlers call into.
type Server struct {
// sessionManager is used by the session create/refresh/end handlers and for
// the session lookup that precedes a WebSocket upgrade.
sessionManager *session.Manager
// realtimeGateway performs the WebSocket handling; the realtime handler
// answers 503 when it is nil.
realtimeGateway *realtime.Gateway
// router holds the registered routes; ServeHTTP delegates to it.
router *mux.Router
}
// NewServer builds a Server around the given session manager and realtime
// gateway, gives it a fresh mux router and registers all routes before
// returning it.
func NewServer(sessionManager *session.Manager, realtimeGateway *realtime.Gateway) *Server {
	srv := new(Server)
	srv.sessionManager = sessionManager
	srv.realtimeGateway = realtimeGateway
	srv.router = mux.NewRouter()

	srv.setupRoutes()
	return srv
}
// setupRoutes registers every endpoint on the server's router. Session and
// realtime routes live under the /v1 prefix; the health check is registered
// on the root router.
func (s *Server) setupRoutes() {
	v1 := s.router.PathPrefix("/v1").Subrouter()

	// Session lifecycle.
	v1.HandleFunc("/sessions", s.handleCreateSession).Methods(http.MethodPost)
	v1.HandleFunc("/sessions/{id}/refresh-token", s.handleRefreshToken).Methods(http.MethodPost)
	v1.HandleFunc("/sessions/{id}/end", s.handleEndSession).Methods(http.MethodPost)

	// Realtime WebSocket endpoint; no method restriction is applied here.
	v1.HandleFunc("/realtime/{id}", s.HandleRealtimeWebSocket)

	// Health check.
	s.router.HandleFunc("/health", s.handleHealth).Methods(http.MethodGet)
}
// CreateSessionRequest represents a session creation request
// It is the JSON body decoded by handleCreateSession, which rejects the
// request unless tenant_id, user_id and auth_assertion are all non-empty.
type CreateSessionRequest struct {
TenantID string `json:"tenant_id"`
UserID string `json:"user_id"`
AuthAssertion string `json:"auth_assertion"`
// PortalContext is optional free-form data. The handler in this file decodes
// it but does not pass it on to the session manager.
PortalContext map[string]interface{} `json:"portal_context,omitempty"`
}
// CreateSessionResponse represents a session creation response
// handleCreateSession fills every field from the session returned by the
// session manager and sends the value with status 201.
type CreateSessionResponse struct {
SessionID string `json:"session_id"`
EphemeralToken string `json:"ephemeral_token"`
Config *session.TenantConfig `json:"config"`
ExpiresAt time.Time `json:"expires_at"`
}
// handleCreateSession handles POST /v1/sessions.
//
// It decodes a CreateSessionRequest from the body, requires tenant_id,
// user_id and auth_assertion to be non-empty, asks the session manager to
// create the session and replies 201 with a CreateSessionResponse. A body
// that cannot be decoded or is missing a required field yields 400; a session
// manager failure yields 500.
func (s *Server) handleCreateSession(w http.ResponseWriter, r *http.Request) {
	var payload CreateSessionRequest
	decodeErr := json.NewDecoder(r.Body).Decode(&payload)
	if decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid request body", decodeErr)
		return
	}

	hasAllFields := payload.TenantID != "" && payload.UserID != "" && payload.AuthAssertion != ""
	if !hasAllFields {
		writeError(w, http.StatusBadRequest, "tenant_id, user_id, and auth_assertion are required", nil)
		return
	}

	created, createErr := s.sessionManager.CreateSession(r.Context(), payload.TenantID, payload.UserID, payload.AuthAssertion)
	if createErr != nil {
		writeError(w, http.StatusInternalServerError, "failed to create session", createErr)
		return
	}

	writeJSON(w, http.StatusCreated, CreateSessionResponse{
		SessionID:      created.ID,
		EphemeralToken: created.EphemeralToken,
		Config:         created.Config,
		ExpiresAt:      created.ExpiresAt,
	})
}
// RefreshTokenResponse represents a token refresh response
// handleRefreshToken sends it with status 200 after a successful refresh.
type RefreshTokenResponse struct {
// EphemeralToken is the new token returned by the session manager.
EphemeralToken string `json:"ephemeral_token"`
// ExpiresAt is read from the session fetched after the refresh.
ExpiresAt time.Time `json:"expires_at"`
}
// handleRefreshToken handles POST /v1/sessions/{id}/refresh-token.
//
// It asks the session manager for a new token, re-reads the session to learn
// its expiry and replies 200 with a RefreshTokenResponse. An empty id yields
// 400, an expired session 401, and any other failure 500.
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	sessionID := mux.Vars(r)["id"]
	if sessionID == "" {
		writeError(w, http.StatusBadRequest, "session_id is required", nil)
		return
	}

	token, refreshErr := s.sessionManager.RefreshToken(ctx, sessionID)
	switch {
	case refreshErr == nil:
		// Success: continue below and build the response.
	case refreshErr.Error() == "session expired":
		// NOTE(review): expiry is detected by exact message text, so a wrapped
		// or reworded error would fall through to the 500 branch.
		writeError(w, http.StatusUnauthorized, "session expired", refreshErr)
		return
	default:
		writeError(w, http.StatusInternalServerError, "failed to refresh token", refreshErr)
		return
	}

	current, lookupErr := s.sessionManager.GetSession(ctx, sessionID)
	if lookupErr != nil {
		writeError(w, http.StatusInternalServerError, "failed to get session", lookupErr)
		return
	}

	writeJSON(w, http.StatusOK, RefreshTokenResponse{
		EphemeralToken: token,
		ExpiresAt:      current.ExpiresAt,
	})
}
// handleEndSession handles POST /v1/sessions/{id}/end.
//
// It asks the session manager to end the session and replies 200 with
// {"status":"ended"}. An empty id yields 400 and a session manager failure
// yields 500.
func (s *Server) handleEndSession(w http.ResponseWriter, r *http.Request) {
	sessionID := mux.Vars(r)["id"]
	if sessionID == "" {
		writeError(w, http.StatusBadRequest, "session_id is required", nil)
		return
	}

	endErr := s.sessionManager.EndSession(r.Context(), sessionID)
	if endErr != nil {
		writeError(w, http.StatusInternalServerError, "failed to end session", endErr)
		return
	}

	result := map[string]string{"status": "ended"}
	writeJSON(w, http.StatusOK, result)
}
// handleHealth handles GET /health. It always replies 200 with
// {"status":"healthy"} and performs no dependency checks.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	body := map[string]string{"status": "healthy"}
	writeJSON(w, http.StatusOK, body)
}
// ServeHTTP implements http.Handler
// Every request is passed straight to the server's mux router, which makes a
// *Server usable directly as the Handler of an http.Server.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// writeJSON serialises data as JSON and writes it to w with the given status
// code and a Content-Type of application/json.
//
// The payload is encoded before any header is sent. A value that cannot be
// encoded therefore produces a 500 response with a fixed JSON error body
// instead of the requested status followed by an empty body. On success the
// bytes written match what json.Encoder.Encode produces, including the
// trailing newline.
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
	payload, err := json.Marshal(data)
	if err != nil {
		status = http.StatusInternalServerError
		payload = []byte(`{"error":"failed to encode response"}`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// A failed write means the client has gone away; the status line is
	// already committed, so there is nothing further to report.
	_, _ = w.Write(append(payload, '\n'))
}
// ErrorResponse represents an error response
// writeError puts its caller-supplied message in Error and, when an
// underlying error is present, that error's text in Message.
type ErrorResponse struct {
// Error is the short, handler-chosen description of what went wrong.
Error string `json:"error"`
// Message carries err.Error() and is left out of the JSON when empty.
Message string `json:"message,omitempty"`
}
// writeError writes an error response
func writeError(w http.ResponseWriter, status int, message string, err error) {
resp := ErrorResponse{
Error: message,
Message: func() string {
if err != nil {
return err.Error()
}
return ""
}(),
}
writeJSON(w, status, resp)
}

102
backend/asr/service.go Normal file
View File

@@ -0,0 +1,102 @@
package asr
import (
"context"
"fmt"
"io"
"time"
)
// Service provides speech-to-text functionality
// This file contains two implementations: MockASRService, which returns
// canned text, and DeepgramASRService, whose methods are not implemented yet.
type Service interface {
// TranscribeStream is given an audio stream and returns a channel of
// transcription events, or an error.
TranscribeStream(ctx context.Context, audioStream io.Reader) (<-chan TranscriptEvent, error)
// Transcribe turns one buffer of audio into text.
Transcribe(ctx context.Context, audioData []byte) (string, error)
}
// TranscriptEvent represents a transcription event
// It is the element type of the channel returned by Service.TranscribeStream.
type TranscriptEvent struct {
Type string `json:"type"` // "partial" or "final"
Text string `json:"text"`
Confidence float64 `json:"confidence,omitempty"`
// Timestamp is set from time.Now().Unix() by the mock service, i.e. Unix
// seconds there; other producers should be checked.
Timestamp int64 `json:"timestamp"`
Words []Word `json:"words,omitempty"`
}
// Word represents a word with timing information
// Nothing in this file populates Word values yet.
type Word struct {
Word string `json:"word"`
// StartTime and EndTime presumably hold offsets in seconds - TODO confirm
// against the provider integration once it exists.
StartTime float64 `json:"start_time"`
EndTime float64 `json:"end_time"`
Confidence float64 `json:"confidence,omitempty"`
}
// MockASRService is a stateless stand-in for a real speech-to-text provider,
// intended for development.
type MockASRService struct{}

// NewMockASRService returns a ready-to-use MockASRService.
func NewMockASRService() *MockASRService {
	return new(MockASRService)
}
// TranscribeStream returns a channel on which a background goroutine tries to
// deliver a single canned "final" event; the channel is closed afterwards.
// audioStream is never read. If ctx is already done the goroutine may exit
// without sending. The returned error is always nil.
func (s *MockASRService) TranscribeStream(ctx context.Context, audioStream io.Reader) (<-chan TranscriptEvent, error) {
	out := make(chan TranscriptEvent, 10)

	// Mock implementation - in production, integrate with Deepgram, Google STT, etc.
	emit := func() {
		defer close(out)

		canned := TranscriptEvent{
			Type:       "final",
			Text:       "Hello, how can I help you today?",
			Confidence: 0.95,
			Timestamp:  time.Now().Unix(),
		}
		select {
		case out <- canned:
		case <-ctx.Done():
		}
	}
	go emit()

	return out, nil
}
// Transcribe ignores audioData and always returns the same canned sentence
// with a nil error.
func (s *MockASRService) Transcribe(ctx context.Context, audioData []byte) (string, error) {
	const cannedTranscript = "Hello, how can I help you today?"
	return cannedTranscript, nil
}
// DeepgramASRService is the placeholder for a Deepgram-backed Service. It
// only stores the API key; its transcription methods are not implemented.
type DeepgramASRService struct {
	apiKey string
}

// NewDeepgramASRService returns a DeepgramASRService holding apiKey.
func NewDeepgramASRService(apiKey string) *DeepgramASRService {
	svc := new(DeepgramASRService)
	svc.apiKey = apiKey
	return svc
}
// TranscribeStream is the placeholder for Deepgram's streaming API. It is not
// implemented and always returns a nil channel together with an error.
//
// No channel is allocated on this path: nothing would ever send on it or
// close it, so handing one back next to the error could only mislead callers.
func (s *DeepgramASRService) TranscribeStream(ctx context.Context, audioStream io.Reader) (<-chan TranscriptEvent, error) {
	// TODO: Implement Deepgram streaming API integration
	// This would involve:
	// 1. Establishing WebSocket connection to Deepgram
	// 2. Sending audio chunks
	// 3. Receiving partial and final transcripts
	// 4. Converting to TranscriptEvent format
	return nil, fmt.Errorf("not implemented - requires Deepgram API integration")
}
// Transcribe transcribes using Deepgram REST API
// It is not implemented yet: every call returns an empty string and a
// "not implemented" error, and neither ctx nor audioData is used.
func (s *DeepgramASRService) Transcribe(ctx context.Context, audioData []byte) (string, error) {
// TODO: Implement Deepgram REST API integration
return "", fmt.Errorf("not implemented - requires Deepgram API integration")
}

22
backend/go.mod Normal file
View File

@@ -0,0 +1,22 @@
module github.com/explorer/virtual-banker/backend
go 1.21
require (
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.1
github.com/jackc/pgx/v5 v5.5.1
github.com/redis/go-redis/v9 v9.3.0
)
require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/text v0.14.0 // indirect
)

44
backend/go.sum Normal file
View File

@@ -0,0 +1,44 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

102
backend/llm/gateway.go Normal file
View File

@@ -0,0 +1,102 @@
package llm
import (
"context"
"fmt"
)
// Gateway provides LLM functionality
// This file contains two implementations: MockLLMGateway, which returns a
// fixed reply, and OpenAIGateway, which is not implemented yet.
type Gateway interface {
// Generate produces a response for prompt, optionally tuned by options.
Generate(ctx context.Context, prompt string, options *GenerateOptions) (*GenerateResponse, error)
}
// GenerateOptions contains options for generation
// Neither gateway implementation in this file reads these fields yet, so
// their exact semantics are defined by future provider integrations.
type GenerateOptions struct {
Temperature float64
MaxTokens int
// Tools lists the tools offered to the model for this call.
Tools []Tool
TenantID string
UserID string
// ConversationHistory holds earlier turns of the conversation.
ConversationHistory []Message
}
// Tool represents a callable tool/function
// It describes a tool to the model; ToolCall is the matching request type
// returned in a GenerateResponse.
type Tool struct {
Name string
Description string
// Parameters is a free-form description of the tool's arguments; presumably
// a JSON-schema-like map - TODO confirm with the provider integration.
Parameters map[string]interface{}
}
// Message represents a conversation message
// BuildPrompt renders each message as "<Role>: <Content>" with the role
// title-cased.
type Message struct {
Role string // "user" or "assistant"
Content string
}
// GenerateResponse contains the LLM response
type GenerateResponse struct {
// Text is the reply text.
Text string
// Tools lists tool calls requested by the model; the mock gateway leaves it
// nil.
Tools []ToolCall
// Emotion is optional avatar emotion data.
Emotion *Emotion
// Gestures names avatar gestures; the mock gateway returns "nod".
Gestures []string
}
// ToolCall represents a tool call request
// It appears in GenerateResponse.Tools.
type ToolCall struct {
// Name identifies the tool to invoke.
Name string
// Arguments holds the call's arguments keyed by parameter name.
Arguments map[string]interface{}
}
// Emotion represents emotional state for avatar
// The mock gateway returns Valence 0.5 and Arousal 0.3.
type Emotion struct {
Valence float64 // -1.0 to 1.0
Arousal float64 // 0.0 to 1.0
}
// MockLLMGateway is a stateless stand-in for a real LLM provider, intended
// for development.
type MockLLMGateway struct{}

// NewMockLLMGateway returns a ready-to-use MockLLMGateway.
func NewMockLLMGateway() *MockLLMGateway {
	return new(MockLLMGateway)
}
// Generate ignores ctx, prompt and options and returns the same canned reply
// every time: a fixed sentence, a mildly positive emotion and a single "nod"
// gesture. The error is always nil.
func (g *MockLLMGateway) Generate(ctx context.Context, prompt string, options *GenerateOptions) (*GenerateResponse, error) {
	mood := &Emotion{Valence: 0.5, Arousal: 0.3}

	reply := new(GenerateResponse)
	reply.Text = "I understand. How can I assist you with your banking needs today?"
	reply.Emotion = mood
	reply.Gestures = []string{"nod"}
	return reply, nil
}
// OpenAIGateway is the placeholder for an OpenAI-backed Gateway. It only
// stores the API key and model name; Generate is not implemented.
type OpenAIGateway struct {
	apiKey string
	model  string
}

// NewOpenAIGateway returns an OpenAIGateway holding apiKey and model.
func NewOpenAIGateway(apiKey, model string) *OpenAIGateway {
	gw := new(OpenAIGateway)
	gw.apiKey = apiKey
	gw.model = model
	return gw
}
// Generate generates using OpenAI API
// It is not implemented yet: every call returns a nil response and a
// "not implemented" error, and none of the arguments is used.
func (g *OpenAIGateway) Generate(ctx context.Context, prompt string, options *GenerateOptions) (*GenerateResponse, error) {
// TODO: Implement OpenAI API integration
// This would involve:
// 1. Building the prompt with system message, conversation history
// 2. Adding tool definitions if tools are provided
// 3. Making API call to OpenAI
// 4. Parsing response and extracting tool calls
// 5. Mapping to GenerateResponse format
return nil, fmt.Errorf("not implemented - requires OpenAI API integration")
}

124
backend/llm/prompt.go Normal file
View File

@@ -0,0 +1,124 @@
package llm
import (
"fmt"
"strings"
)
// BuildPrompt assembles the full text prompt sent to the LLM. Sections are
// joined with newlines in this order: the system message built from
// tenantConfig, an optional "## Context:" section listing retrievedDocs, an
// optional "## Conversation History:" section, the current user input and a
// trailing "## Assistant:" cue.
func BuildPrompt(tenantConfig *TenantConfig, conversationHistory []Message, userInput string, retrievedDocs []RetrievedDoc) string {
	sections := []string{buildSystemMessage(tenantConfig)}

	// Retrieved documents (RAG context), numbered from 1.
	if len(retrievedDocs) != 0 {
		sections = append(sections, "\n## Context:")
		for idx := range retrievedDocs {
			doc := retrievedDocs[idx]
			sections = append(sections,
				fmt.Sprintf("\n[Document %d]", idx+1),
				"Title: "+doc.Title,
				"Content: "+doc.Content,
			)
			if doc.URL != "" {
				sections = append(sections, "Source: "+doc.URL)
			}
		}
	}

	// Earlier turns, one line each.
	if len(conversationHistory) != 0 {
		sections = append(sections, "\n## Conversation History:")
		for _, turn := range conversationHistory {
			// NOTE(review): strings.Title is deprecated; it is kept so the
			// rendered role text stays exactly as before.
			sections = append(sections, strings.Title(turn.Role)+": "+turn.Content)
		}
	}

	sections = append(sections, "\n## User: "+userInput, "\n## Assistant:")
	return strings.Join(sections, "\n")
}
// TenantConfig holds tenant-specific configuration
// buildSystemMessage reads Tone, Disclaimers and AllowedTools; Greeting is
// not used by the prompt builders in this file.
type TenantConfig struct {
Greeting string
Tone string // "professional", "friendly", "formal"
// Disclaimers are listed one per line in the system message.
Disclaimers []string
// AllowedTools are listed by name in the system message.
AllowedTools []string
}
// RetrievedDoc represents a retrieved document from RAG
// The prompt builders print Title and Content, and URL when it is non-empty;
// Score is not used by them.
type RetrievedDoc struct {
Title string
Content string
URL string
Score float64
}
// BuildPromptWithRAG builds a prompt with RAG context.
//
// Its output is identical to BuildPrompt for the same arguments. The body
// used to be a line-for-line copy of BuildPrompt; it now delegates so that
// the two entry points cannot drift apart.
func BuildPromptWithRAG(tenantConfig *TenantConfig, conversationHistory []Message, userInput string, retrievedDocs []RetrievedDoc) string {
	return BuildPrompt(tenantConfig, conversationHistory, userInput, retrievedDocs)
}
// buildSystemMessage renders the system portion of the prompt: a fixed
// introduction, the tenant's tone, disclaimers and allowed tools when
// present, and two fixed closing instructions, joined with newlines.
//
// A nil config is treated as an empty configuration, so callers without
// tenant settings get the generic message instead of a nil-pointer panic.
func buildSystemMessage(config *TenantConfig) string {
	if config == nil {
		config = &TenantConfig{}
	}

	var parts []string
	parts = append(parts, "You are a helpful Virtual Banker assistant.")

	if config.Tone != "" {
		parts = append(parts, fmt.Sprintf("Your tone should be %s.", config.Tone))
	}

	if len(config.Disclaimers) > 0 {
		parts = append(parts, "\nImportant disclaimers:")
		for _, disclaimer := range config.Disclaimers {
			parts = append(parts, fmt.Sprintf("- %s", disclaimer))
		}
	}

	if len(config.AllowedTools) > 0 {
		parts = append(parts, "\nYou have access to the following tools:")
		for _, tool := range config.AllowedTools {
			parts = append(parts, fmt.Sprintf("- %s", tool))
		}
	}

	parts = append(parts, "\nAlways be helpful, accurate, and respectful.")
	parts = append(parts, "If you don't know something, say so and offer to help find the answer.")
	return strings.Join(parts, "\n")
}

136
backend/main.go Normal file
View File

@@ -0,0 +1,136 @@
package main
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/explorer/virtual-banker/backend/api"
"github.com/explorer/virtual-banker/backend/asr"
"github.com/explorer/virtual-banker/backend/llm"
"github.com/explorer/virtual-banker/backend/orchestrator"
"github.com/explorer/virtual-banker/backend/rag"
"github.com/explorer/virtual-banker/backend/realtime"
"github.com/explorer/virtual-banker/backend/session"
"github.com/explorer/virtual-banker/backend/tools"
"github.com/explorer/virtual-banker/backend/tools/banking"
"github.com/explorer/virtual-banker/backend/tts"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
)
// main wires the Virtual Banker backend together and runs it until SIGINT or
// SIGTERM arrives. It reads configuration from the environment, connects to
// PostgreSQL and Redis, builds the service graph, serves the API over HTTP
// and finally shuts the server down with a 30 second grace period.
func main() {
// Load configuration from environment
dbURL := getEnv("DATABASE_URL", "postgres://explorer:changeme@localhost:5432/explorer?sslmode=disable")
redisURL := getEnv("REDIS_URL", "redis://localhost:6379")
port := getEnv("PORT", "8081")
// Initialize database connection
db, err := pgxpool.New(context.Background(), dbURL)
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
defer db.Close()
// Initialize Redis connection
opt, err := redis.ParseURL(redisURL)
if err != nil {
log.Fatalf("Failed to parse Redis URL: %v", err)
}
redisClient := redis.NewClient(opt)
defer redisClient.Close()
// Test connections
// NOTE(review): log.Fatalf exits the process via os.Exit, so the deferred
// Close calls above do not run on these failure paths. The pings also use a
// background context without a timeout.
if err := db.Ping(context.Background()); err != nil {
log.Fatalf("Database ping failed: %v", err)
}
if err := redisClient.Ping(context.Background()).Err(); err != nil {
log.Fatalf("Redis ping failed: %v", err)
}
// Initialize services
sessionManager := session.NewManager(db, redisClient)
// Initialize ASR/TTS (using mocks for now)
asrService := asr.NewMockASRService()
ttsService := tts.NewMockTTSService()
// Initialize LLM (using mock for now)
llmGateway := llm.NewMockLLMGateway()
// Initialize RAG
ragService := rag.NewRAGService(db)
// Initialize tools
toolRegistry := tools.NewRegistry()
toolRegistry.Register(banking.NewAccountStatusTool())
toolRegistry.Register(banking.NewCreateTicketTool())
toolRegistry.Register(banking.NewScheduleAppointmentTool())
toolRegistry.Register(banking.NewSubmitPaymentTool())
auditLogger := &tools.MockAuditLogger{}
toolExecutor := tools.NewExecutor(toolRegistry, auditLogger)
// Initialize orchestrator
convOrchestrator := orchestrator.NewOrchestrator(
asrService,
ttsService,
llmGateway,
ragService,
toolExecutor,
)
// Initialize realtime gateway
realtimeGateway := realtime.NewGateway()
// Initialize API server
apiServer := api.NewServer(sessionManager, realtimeGateway)
// Store orchestrator reference (would be used by handlers)
// The blank assignment only keeps the compiler from rejecting the unused
// variable; the orchestrator is not connected to any handler here.
_ = convOrchestrator
// Create HTTP server
srv := &http.Server{
Addr: ":" + port,
Handler: apiServer,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
// Start server in goroutine
// ListenAndServe blocks, so it runs in the background while main waits for
// a signal. http.ErrServerClosed is the expected result of Shutdown and is
// not treated as a failure.
go func() {
log.Printf("Virtual Banker API server starting on port %s", port)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed to start: %v", err)
}
}()
// Wait for interrupt signal
// The channel is buffered with capacity 1 so a signal delivered before the
// receive below is not dropped.
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("Shutting down server...")
// Graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
}
log.Println("Server exited")
}
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}

163
backend/memory/service.go Normal file
View File

@@ -0,0 +1,163 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
)
// Service manages user memory and preferences
// MemoryService in this file is the PostgreSQL-backed implementation.
type Service interface {
// GetProfile loads the profile for a user within a tenant.
GetProfile(ctx context.Context, userID, tenantID string) (*UserProfile, error)
// SaveProfile stores the given profile.
SaveProfile(ctx context.Context, profile *UserProfile) error
// GetHistory returns up to limit history entries for a user within a tenant.
GetHistory(ctx context.Context, userID, tenantID string, limit int) ([]ConversationHistory, error)
// SaveHistory stores one history entry.
SaveHistory(ctx context.Context, history *ConversationHistory) error
}
// UserProfile represents user preferences and memory
// Preferences and Context are stored as JSON in the user_profiles table.
type UserProfile struct {
UserID string
TenantID string
Preferences map[string]interface{}
Context map[string]interface{}
// CreatedAt and UpdatedAt are scanned straight from the created_at and
// updated_at columns; SaveProfile ignores them and uses NOW() instead.
CreatedAt string
UpdatedAt string
}
// ConversationHistory represents a conversation history entry
// Messages is stored as JSON in the conversation_history table.
type ConversationHistory struct {
ID string
UserID string
TenantID string
SessionID string
Messages []Message
// CreatedAt is scanned from the created_at column; SaveHistory ignores it and
// uses NOW() instead.
CreatedAt string
}
// Message represents a message in history
// Values are serialised to JSON as part of ConversationHistory.Messages.
type Message struct {
Role string
Content string
// Timestamp is kept as a string; its format is not fixed by this file.
Timestamp string
}
// MemoryService is the PostgreSQL-backed implementation of Service. All of
// its queries run against the supplied pgx connection pool.
type MemoryService struct {
	db *pgxpool.Pool
}

// NewMemoryService returns a MemoryService that uses db for all queries.
func NewMemoryService(db *pgxpool.Pool) *MemoryService {
	svc := new(MemoryService)
	svc.db = db
	return svc
}
// GetProfile loads the profile row for userID within tenantID.
//
// When the row cannot be read, a default profile with empty maps is returned
// and the error is nil. On success the JSON columns are decoded into
// Preferences and Context. Both maps are guaranteed to be non-nil in the
// result: a column that fails to decode, or that holds JSON null (which
// SaveProfile writes for a nil map), is replaced by an empty map so callers
// can assign into it safely.
func (s *MemoryService) GetProfile(ctx context.Context, userID, tenantID string) (*UserProfile, error) {
	query := `
SELECT user_id, tenant_id, preferences, context, created_at, updated_at
FROM user_profiles
WHERE user_id = $1 AND tenant_id = $2
`
	var (
		profile     UserProfile
		prefsJSON   []byte
		contextJSON []byte
	)
	err := s.db.QueryRow(ctx, query, userID, tenantID).Scan(
		&profile.UserID,
		&profile.TenantID,
		&prefsJSON,
		&contextJSON,
		&profile.CreatedAt,
		&profile.UpdatedAt,
	)
	if err != nil {
		// NOTE(review): every scan error lands here, not only "no rows", so a
		// database failure is indistinguishable from a missing profile.
		return &UserProfile{
			UserID:      userID,
			TenantID:    tenantID,
			Preferences: make(map[string]interface{}),
			Context:     make(map[string]interface{}),
		}, nil
	}

	if err := json.Unmarshal(prefsJSON, &profile.Preferences); err != nil || profile.Preferences == nil {
		profile.Preferences = make(map[string]interface{})
	}
	if err := json.Unmarshal(contextJSON, &profile.Context); err != nil || profile.Context == nil {
		profile.Context = make(map[string]interface{})
	}
	return &profile, nil
}
// SaveProfile inserts the profile, or updates preferences, context and
// updated_at when a row for the same (user_id, tenant_id) already exists.
//
// Preferences and Context are stored as JSON. A value that cannot be encoded
// is reported as an error; previously the encoding error was discarded and
// nil bytes were written for that column.
func (s *MemoryService) SaveProfile(ctx context.Context, profile *UserProfile) error {
	prefsJSON, err := json.Marshal(profile.Preferences)
	if err != nil {
		return fmt.Errorf("failed to marshal preferences: %w", err)
	}
	contextJSON, err := json.Marshal(profile.Context)
	if err != nil {
		return fmt.Errorf("failed to marshal context: %w", err)
	}

	query := `
INSERT INTO user_profiles (user_id, tenant_id, preferences, context, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (user_id, tenant_id) DO UPDATE SET
preferences = $3,
context = $4,
updated_at = NOW()
`
	_, err = s.db.Exec(ctx, query, profile.UserID, profile.TenantID, prefsJSON, contextJSON)
	return err
}
// GetHistory returns the most recent conversation history entries for userID
// within tenantID, newest first. A limit of zero or less is replaced by 10.
//
// Rows that fail to scan are skipped, and a messages column that fails to
// decode yields an empty message list for that entry. An error raised while
// iterating the result set is returned instead of being dropped, so a
// truncated list is never reported as success.
func (s *MemoryService) GetHistory(ctx context.Context, userID, tenantID string, limit int) ([]ConversationHistory, error) {
	if limit <= 0 {
		limit = 10
	}
	query := `
SELECT id, user_id, tenant_id, session_id, messages, created_at
FROM conversation_history
WHERE user_id = $1 AND tenant_id = $2
ORDER BY created_at DESC
LIMIT $3
`
	rows, err := s.db.Query(ctx, query, userID, tenantID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query: %w", err)
	}
	defer rows.Close()

	var histories []ConversationHistory
	for rows.Next() {
		var history ConversationHistory
		var messagesJSON []byte
		if err := rows.Scan(&history.ID, &history.UserID, &history.TenantID, &history.SessionID, &messagesJSON, &history.CreatedAt); err != nil {
			// Best effort: one unreadable row does not fail the whole call.
			continue
		}
		if err := json.Unmarshal(messagesJSON, &history.Messages); err != nil {
			history.Messages = []Message{}
		}
		histories = append(histories, history)
	}
	// Next returns false both at the end of the data and on failure; Err
	// tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read rows: %w", err)
	}
	return histories, nil
}
// SaveHistory inserts one conversation history row. Messages are stored as
// JSON and created_at is set by the database with NOW().
//
// The JSON encoding error is checked rather than discarded, so an encoding
// failure can never result in nil bytes being written silently.
func (s *MemoryService) SaveHistory(ctx context.Context, history *ConversationHistory) error {
	messagesJSON, err := json.Marshal(history.Messages)
	if err != nil {
		return fmt.Errorf("failed to marshal messages: %w", err)
	}

	query := `
INSERT INTO conversation_history (id, user_id, tenant_id, session_id, messages, created_at)
VALUES ($1, $2, $3, $4, $5, NOW())
`
	_, err = s.db.Exec(ctx, query, history.ID, history.UserID, history.TenantID, history.SessionID, messagesJSON)
	return err
}

View File

@@ -0,0 +1,73 @@
package observability
import (
"sync/atomic"
"time"
)
// Metrics is a set of service counters and latency gauges. GetMetrics hands
// out copies of the process-wide instance.
type Metrics struct {
	SessionCreations  int64
	ActiveSessions    int64
	MessagesProcessed int64
	ASRLatency        int64 // microseconds
	TTSLatency        int64 // microseconds
	LLMLatency        int64 // microseconds
	Errors            int64
}

// globalMetrics is the process-wide instance; every access below goes through
// sync/atomic.
var globalMetrics = &Metrics{}

// GetMetrics returns a freshly allocated snapshot of the current values. Each
// field is loaded atomically on its own.
func GetMetrics() *Metrics {
	snapshot := new(Metrics)
	snapshot.SessionCreations = atomic.LoadInt64(&globalMetrics.SessionCreations)
	snapshot.ActiveSessions = atomic.LoadInt64(&globalMetrics.ActiveSessions)
	snapshot.MessagesProcessed = atomic.LoadInt64(&globalMetrics.MessagesProcessed)
	snapshot.ASRLatency = atomic.LoadInt64(&globalMetrics.ASRLatency)
	snapshot.TTSLatency = atomic.LoadInt64(&globalMetrics.TTSLatency)
	snapshot.LLMLatency = atomic.LoadInt64(&globalMetrics.LLMLatency)
	snapshot.Errors = atomic.LoadInt64(&globalMetrics.Errors)
	return snapshot
}

// IncrementSessionCreations adds one to the session creation counter.
func IncrementSessionCreations() {
	addToCounter(&globalMetrics.SessionCreations, 1)
}

// IncrementActiveSessions adds one to the active session gauge.
func IncrementActiveSessions() {
	addToCounter(&globalMetrics.ActiveSessions, 1)
}

// DecrementActiveSessions subtracts one from the active session gauge.
func DecrementActiveSessions() {
	addToCounter(&globalMetrics.ActiveSessions, -1)
}

// IncrementMessagesProcessed adds one to the processed message counter.
func IncrementMessagesProcessed() {
	addToCounter(&globalMetrics.MessagesProcessed, 1)
}

// RecordASRLatency stores duration, in microseconds, as the latest ASR
// latency, replacing the previous value.
func RecordASRLatency(duration time.Duration) {
	storeLatencyMicros(&globalMetrics.ASRLatency, duration)
}

// RecordTTSLatency stores duration, in microseconds, as the latest TTS
// latency, replacing the previous value.
func RecordTTSLatency(duration time.Duration) {
	storeLatencyMicros(&globalMetrics.TTSLatency, duration)
}

// RecordLLMLatency stores duration, in microseconds, as the latest LLM
// latency, replacing the previous value.
func RecordLLMLatency(duration time.Duration) {
	storeLatencyMicros(&globalMetrics.LLMLatency, duration)
}

// IncrementErrors adds one to the error counter.
func IncrementErrors() {
	addToCounter(&globalMetrics.Errors, 1)
}

// addToCounter atomically adds delta to the counter at target.
func addToCounter(target *int64, delta int64) {
	atomic.AddInt64(target, delta)
}

// storeLatencyMicros atomically overwrites the gauge at target with d
// expressed in microseconds.
func storeLatencyMicros(target *int64, d time.Duration) {
	atomic.StoreInt64(target, d.Microseconds())
}

View File

@@ -0,0 +1,48 @@
package observability
import (
"context"
"fmt"
)
// Tracer provides distributed tracing
// MockTracer in this file is a no-op implementation for development.
type Tracer interface {
// StartSpan begins a span called name and returns it together with the
// context to use for work done inside the span.
StartSpan(ctx context.Context, name string) (context.Context, Span)
}
// Span represents a tracing span
// MockSpan in this file implements it with empty methods.
type Span interface {
// End marks the span as finished.
End()
// SetAttribute attaches a key/value pair to the span.
SetAttribute(key string, value interface{})
// SetError records err on the span.
SetError(err error)
}
// MockTracer is a no-op Tracer for development.
type MockTracer struct{}

// StartSpan returns ctx unchanged together with a new MockSpan; name is
// ignored.
func (t *MockTracer) StartSpan(ctx context.Context, name string) (context.Context, Span) {
	noop := new(MockSpan)
	return ctx, noop
}
// MockSpan is a Span whose methods all do nothing.
type MockSpan struct{}

// End does nothing.
func (m *MockSpan) End() {
}

// SetAttribute discards the key/value pair.
func (m *MockSpan) SetAttribute(key string, value interface{}) {
}

// SetError discards err.
func (m *MockSpan) SetError(err error) {
}
// TraceConversation starts a "conversation.turn" span on tracer and tags it
// with the session ID, the user ID and the length of input in bytes. It
// returns the span's context and the span; ending the span is left to the
// caller.
func TraceConversation(ctx context.Context, tracer Tracer, sessionID, userID string, input string) (context.Context, Span) {
	spanCtx, turnSpan := tracer.StartSpan(ctx, "conversation.turn")

	attributes := []struct {
		key   string
		value interface{}
	}{
		{"session_id", sessionID},
		{"user_id", userID},
		{"input_length", len(input)},
	}
	for _, attr := range attributes {
		turnSpan.SetAttribute(attr.key, attr.value)
	}

	return spanCtx, turnSpan
}

View File

@@ -0,0 +1,284 @@
package orchestrator
import (
"context"
"fmt"
"sync"
"time"
"github.com/explorer/virtual-banker/backend/asr"
"github.com/explorer/virtual-banker/backend/llm"
"github.com/explorer/virtual-banker/backend/rag"
"github.com/explorer/virtual-banker/backend/tools"
"github.com/explorer/virtual-banker/backend/tts"
)
// State represents the conversation state
// New session orchestrators start in StateIdle.
type State string
// The states a session's conversation can be in.
const (
StateIdle State = "IDLE"
StateListening State = "LISTENING"
StateThinking State = "THINKING"
StateSpeaking State = "SPEAKING"
)
// Orchestrator orchestrates conversation flow
// It keeps one SessionOrchestrator per session ID and passes each of them the
// shared ASR, TTS, LLM, RAG and tool-execution dependencies.
type Orchestrator struct {
// sessions maps session ID to its per-session orchestrator; access is
// guarded by mu.
sessions map[string]*SessionOrchestrator
mu sync.RWMutex
asr asr.Service
tts tts.Service
llm llm.Gateway
rag rag.Service
tools *tools.Executor
}
// NewOrchestrator returns an Orchestrator with an empty session table that
// passes the given ASR, TTS, LLM, RAG and tool services to every session it
// creates.
func NewOrchestrator(asrService asr.Service, ttsService tts.Service, llmGateway llm.Gateway, ragService rag.Service, toolExecutor *tools.Executor) *Orchestrator {
	orch := new(Orchestrator)
	orch.sessions = make(map[string]*SessionOrchestrator)
	orch.asr = asrService
	orch.tts = ttsService
	orch.llm = llmGateway
	orch.rag = ragService
	orch.tools = toolExecutor
	return orch
}
// SessionOrchestrator drives the conversation of a single session: it keeps
// the current State, the message history and the services used to process a
// turn.
type SessionOrchestrator struct {
	sessionID string
	tenantID string
	userID string
	// state is read and written under mu (see SetState, GetState, StopSpeaking).
	state State
	mu sync.RWMutex
	// ctx/cancel form the session-scoped context; StopSpeaking cancels it and
	// installs a fresh pair, Close cancels it.
	ctx context.Context
	cancel context.CancelFunc
	asr asr.Service
	tts tts.Service
	llm llm.Gateway
	rag rag.Service
	tools *tools.Executor
	// conversation is the running message history; processTranscript appends
	// to it without holding mu.
	conversation []llm.Message
}
// GetOrCreateSession returns the orchestrator registered for sessionID,
// creating and registering one in the idle state when none exists yet.
//
// When the session already exists it is returned as-is; tenantID and userID
// are only used when a new session is created.
func (o *Orchestrator) GetOrCreateSession(sessionID, tenantID, userID string) *SessionOrchestrator {
	// Look up under the read lock first.
	o.mu.RLock()
	existing, found := o.sessions[sessionID]
	o.mu.RUnlock()
	if found {
		return existing
	}

	o.mu.Lock()
	defer o.mu.Unlock()

	// The session may have been registered between releasing the read lock
	// and acquiring the write lock, so look again.
	if existing, found = o.sessions[sessionID]; found {
		return existing
	}

	sessionCtx, stop := context.WithCancel(context.Background())
	created := &SessionOrchestrator{
		sessionID:    sessionID,
		tenantID:     tenantID,
		userID:       userID,
		state:        StateIdle,
		ctx:          sessionCtx,
		cancel:       stop,
		asr:          o.asr,
		tts:          o.tts,
		llm:          o.llm,
		rag:          o.rag,
		tools:        o.tools,
		conversation: []llm.Message{},
	}
	o.sessions[sessionID] = created
	return created
}
// ProcessAudio runs one spoken turn: it transcribes audioData, generates a
// reply and speaks it, moving the session through LISTENING, THINKING and
// SPEAKING on the way.
//
// If the session is currently speaking, the ongoing speech is stopped first
// (barge-in). When any stage fails the session is put back into IDLE before
// the wrapped error is returned, so a failed turn cannot leave the session
// stuck in an intermediate state.
func (so *SessionOrchestrator) ProcessAudio(ctx context.Context, audioData []byte) error {
	// Handle barge-in: if speaking, stop and switch to listening.
	if so.GetState() == StateSpeaking {
		so.StopSpeaking()
	}
	so.SetState(StateListening)

	// Transcribe audio.
	transcript, err := so.asr.Transcribe(ctx, audioData)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to transcribe: %w", err)
	}

	// Generate the reply.
	so.SetState(StateThinking)
	response, err := so.processTranscript(ctx, transcript)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to process transcript: %w", err)
	}

	// Speak the reply.
	so.SetState(StateSpeaking)
	if err := so.speak(ctx, response); err != nil {
		so.SetState(StateIdle)
		return err
	}
	return nil
}
// ProcessText runs one typed turn: it generates a reply to text and speaks
// it, moving the session through THINKING and SPEAKING.
//
// When generation or speech fails the session is put back into IDLE before
// the error is returned, so a failed turn cannot leave the session stuck in
// an intermediate state.
func (so *SessionOrchestrator) ProcessText(ctx context.Context, text string) error {
	so.SetState(StateThinking)

	response, err := so.processTranscript(ctx, text)
	if err != nil {
		so.SetState(StateIdle)
		return fmt.Errorf("failed to process text: %w", err)
	}

	so.SetState(StateSpeaking)
	if err := so.speak(ctx, response); err != nil {
		so.SetState(StateIdle)
		return err
	}
	return nil
}
// processTranscript records transcript as a user turn, asks the LLM for a
// reply, runs any tool calls the reply contains and returns the reply text.
//
// The user turn, a note for each successful tool call and the assistant reply
// are appended to the session's conversation history. A failure of the LLM
// call is returned wrapped; tool failures are printed and skipped.
//
// NOTE(review): so.conversation is modified here without holding so.mu.
func (so *SessionOrchestrator) processTranscript(ctx context.Context, transcript string) (string, error) {
	// The user turn goes into the history first, so it is part of what is
	// handed to the model below.
	userTurn := llm.Message{Role: "user", Content: transcript}
	so.conversation = append(so.conversation, userTurn)

	// Retrieval is best-effort: without a RAG service, or when it fails, the
	// turn proceeds with no documents.
	var retrievedDocs []rag.RetrievedDoc
	if so.rag != nil {
		if found, ragErr := so.rag.Retrieve(ctx, transcript, so.tenantID, 5); ragErr == nil {
			retrievedDocs = found
		}
	}

	// Convert the retrieved documents to the LLM package's representation.
	// NOTE(review): ragDocs is filled in but never passed to the LLM call in
	// this function, so retrieval currently has no effect on the reply.
	ragDocs := make([]llm.RetrievedDoc, len(retrievedDocs))
	for idx := range retrievedDocs {
		src := retrievedDocs[idx]
		ragDocs[idx] = llm.RetrievedDoc{
			Title:   src.Title,
			Content: src.Content,
			URL:     src.URL,
			Score:   src.Score,
		}
	}

	availableTools := []llm.Tool{} // TODO: Get from tenant config

	reply, err := so.llm.Generate(ctx, transcript, &llm.GenerateOptions{
		Temperature:         0.7,
		MaxTokens:           500,
		Tools:               availableTools,
		TenantID:            so.tenantID,
		UserID:              so.userID,
		ConversationHistory: so.conversation,
	})
	if err != nil {
		return "", fmt.Errorf("failed to generate response: %w", err)
	}

	// Run requested tools; only successful results are added to the history.
	if so.tools != nil {
		for _, call := range reply.Tools {
			outcome, execErr := so.tools.Execute(ctx, call.Name, call.Arguments, so.userID, so.tenantID)
			if execErr != nil {
				// Log error but continue with the remaining calls.
				fmt.Printf("Tool execution error: %v\n", execErr)
				continue
			}
			if !outcome.Success {
				continue
			}
			so.conversation = append(so.conversation, llm.Message{
				Role:    "assistant",
				Content: fmt.Sprintf("Tool %s executed successfully: %v", call.Name, outcome.Data),
			})
		}
	}

	so.conversation = append(so.conversation, llm.Message{
		Role:    "assistant",
		Content: reply.Text,
	})
	return reply.Text, nil
}
// speak synthesizes text to audio, fetches the matching visemes and then
// waits for the simulated speaking time (50ms per byte of text) before
// returning the session to IDLE.
//
// The wait can be interrupted in two ways:
//   - the session context is cancelled (StopSpeaking on barge-in, or Close):
//     speak returns nil and leaves the state alone, because whoever cancelled
//     it is now in charge of the state;
//   - ctx is cancelled: the state is reset to IDLE and ctx.Err() is returned.
//
// A synthesis failure is returned wrapped; a viseme failure is only printed.
func (so *SessionOrchestrator) speak(ctx context.Context, text string) error {
	// Snapshot the session context under the lock; StopSpeaking replaces it.
	so.mu.RLock()
	sessionCtx := so.ctx
	so.mu.RUnlock()

	// Synthesize audio.
	audioData, err := so.tts.Synthesize(ctx, text)
	if err != nil {
		return fmt.Errorf("failed to synthesize: %w", err)
	}

	// Get visemes for avatar.
	visemes, err := so.tts.GetVisemes(ctx, text)
	if err != nil {
		// Log error but continue.
		fmt.Printf("Failed to get visemes: %v\n", err)
	}

	// TODO: Send audio and visemes to client via WebRTC/WebSocket
	_ = audioData
	_ = visemes

	// Simulate speaking duration, but stay interruptible: a plain sleep would
	// keep "speaking" after a barge-in and then overwrite the state of the
	// turn that interrupted it.
	timer := time.NewTimer(time.Duration(len(text)*50) * time.Millisecond)
	defer timer.Stop()

	select {
	case <-timer.C:
	case <-sessionCtx.Done():
		return nil
	case <-ctx.Done():
		so.SetState(StateIdle)
		return ctx.Err()
	}

	so.SetState(StateIdle)
	return nil
}
// StopSpeaking interrupts the current speech (barge-in). When the session is
// in the SPEAKING state it cancels the session context, installs a fresh one
// and moves the session to IDLE; in any other state it does nothing.
func (so *SessionOrchestrator) StopSpeaking() {
	so.mu.Lock()
	defer so.mu.Unlock()

	if so.state != StateSpeaking {
		return
	}

	// Cancel whatever is bound to the current session context, then start a
	// new one for subsequent work.
	so.cancel()
	so.ctx, so.cancel = context.WithCancel(context.Background())
	so.state = StateIdle
}
// SetState stores state as the session's current conversation state while
// holding the write lock.
func (so *SessionOrchestrator) SetState(state State) {
	so.mu.Lock()
	so.state = state
	so.mu.Unlock()
}
// GetState returns the session's current conversation state, read under the
// read lock.
func (so *SessionOrchestrator) GetState() State {
	so.mu.RLock()
	current := so.state
	so.mu.RUnlock()
	return current
}
// Close cancels the session context.
//
// The cancel function is read under the lock because StopSpeaking replaces
// it while holding the same lock.
//
// NOTE(review): Close does not remove the session from the Orchestrator's
// session table; nothing in this block has access to it.
func (so *SessionOrchestrator) Close() {
	so.mu.RLock()
	cancel := so.cancel
	so.mu.RUnlock()

	cancel()
}

110
backend/rag/service.go Normal file
View File

@@ -0,0 +1,110 @@
package rag
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
)
// Service provides RAG (Retrieval-Augmented Generation) functionality:
// looking up documents for a query and adding documents to the store.
type Service interface {
	// Retrieve returns up to topK documents of tenantID for query.
	Retrieve(ctx context.Context, query string, tenantID string, topK int) ([]RetrievedDoc, error)
	// Ingest stores doc so that later Retrieve calls can find it.
	Ingest(ctx context.Context, doc *Document) error
}
// RetrievedDoc is one search hit returned by Service.Retrieve.
type RetrievedDoc struct {
	ID string
	Title string
	Content string
	// URL is empty when the stored metadata has no "url" entry.
	URL string
	// Score is the ranking value the results are ordered by (highest first).
	Score float64
}
// Document is the input to Service.Ingest.
type Document struct {
	ID string
	TenantID string
	Title string
	Content string
	// URL is stored inside the metadata under the key "url".
	URL string
	// Metadata holds additional key/value pairs stored alongside the document.
	Metadata map[string]interface{}
}
// RAGService implements Service on top of a PostgreSQL connection pool.
// The original design targets pgvector; the visible queries currently use
// full-text ranking instead of embeddings.
type RAGService struct {
	db *pgxpool.Pool
}
// NewRAGService returns a RAGService that runs its queries on db.
func NewRAGService(db *pgxpool.Pool) *RAGService {
	svc := new(RAGService)
	svc.db = db
	return svc
}
// Retrieve returns up to topK documents of tenantID ranked by full-text
// relevance to query, best match first. A topK of zero or less selects the
// default of 5.
//
// It returns an error when the query fails, when a row cannot be scanned or
// when iteration ends with an error; in those cases no documents are
// returned, so callers never see a silently truncated result.
//
// NOTE(review): the SQL only filters by tenant, so documents that do not
// match the query at all can still be returned (with a low score).
func (s *RAGService) Retrieve(ctx context.Context, query string, tenantID string, topK int) ([]RetrievedDoc, error) {
	if topK <= 0 {
		topK = 5
	}

	// TODO: Generate embedding for query
	// For now, use simple text search
	querySQL := `
SELECT id, title, content, metadata->>'url' as url,
ts_rank(to_tsvector('english', content), plainto_tsquery('english', $1)) as score
FROM knowledge_base
WHERE tenant_id = $2
ORDER BY score DESC
LIMIT $3
`
	rows, err := s.db.Query(ctx, querySQL, query, tenantID, topK)
	if err != nil {
		return nil, fmt.Errorf("failed to query: %w", err)
	}
	defer rows.Close()

	var docs []RetrievedDoc
	for rows.Next() {
		var doc RetrievedDoc
		// The url column is NULL when the metadata has no "url" key.
		var url *string
		if err := rows.Scan(&doc.ID, &doc.Title, &doc.Content, &url, &doc.Score); err != nil {
			return nil, fmt.Errorf("failed to scan document: %w", err)
		}
		if url != nil {
			doc.URL = *url
		}
		docs = append(docs, doc)
	}
	// Next returning false can mean "no more rows" or "iteration failed".
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read documents: %w", err)
	}

	return docs, nil
}
// Ingest inserts doc into the knowledge base, or updates title, content and
// metadata when a row with the same ID already exists.
//
// The stored metadata is doc.Metadata plus a "url" entry taken from doc.URL;
// a "url" key inside doc.Metadata takes precedence over doc.URL.
//
// It returns an error when doc is nil or when the statement fails.
func (s *RAGService) Ingest(ctx context.Context, doc *Document) error {
	if doc == nil {
		return fmt.Errorf("document is required")
	}

	// TODO: Generate embedding for document content
	// For now, just insert without embedding
	query := `
INSERT INTO knowledge_base (id, tenant_id, title, content, metadata)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (id) DO UPDATE SET
title = $3,
content = $4,
metadata = $5,
updated_at = NOW()
`
	metadata := make(map[string]interface{}, len(doc.Metadata)+1)
	metadata["url"] = doc.URL
	for k, v := range doc.Metadata {
		metadata[k] = v
	}

	if _, err := s.db.Exec(ctx, query, doc.ID, doc.TenantID, doc.Title, doc.Content, metadata); err != nil {
		return fmt.Errorf("failed to ingest document %q: %w", doc.ID, err)
	}
	return nil
}

198
backend/realtime/gateway.go Normal file
View File

@@ -0,0 +1,198 @@
package realtime
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// upgrader performs the HTTP-to-WebSocket upgrade for HandleWebSocket.
//
// NOTE(review): CheckOrigin accepts every request regardless of its Origin
// header, so any web page can open a socket to this endpoint from a
// visitor's browser. Restrict it to the expected origins before production.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		// In production, validate origin properly
		return true
	},
}
// Gateway handles WebRTC signaling over WebSocket connections and keeps
// track of the open connection of each session.
type Gateway struct {
	// connections maps a session ID to its connection; guarded by mu.
	connections map[string]*Connection
	mu sync.RWMutex
}
// NewGateway returns a Gateway with no registered connections.
func NewGateway() *Gateway {
	g := new(Gateway)
	g.connections = map[string]*Connection{}
	return g
}
// Connection represents the signaling WebSocket of one session.
type Connection struct {
	sessionID string
	ws *websocket.Conn
	// send queues outbound messages; writePump drains it onto the socket.
	send chan []byte
	// ctx is done once the connection is shut down; cancel triggers that.
	ctx context.Context
	cancel context.CancelFunc
}
// HandleWebSocket upgrades the request to a WebSocket, registers the
// connection under sessionID and starts its read and write pumps. It returns
// as soon as the pumps are started; they keep running on their own
// goroutines until the connection is closed.
//
// It returns an error only when the upgrade fails.
//
// NOTE(review): registering a second connection for the same sessionID
// replaces the map entry without closing the previous connection.
func (g *Gateway) HandleWebSocket(w http.ResponseWriter, r *http.Request, sessionID string) error {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return fmt.Errorf("failed to upgrade connection: %w", err)
	}

	// The connection outlives the HTTP handler, so its context must not be
	// derived from r.Context(): net/http cancels the request context when the
	// handler returns, which would stop writePump and close the socket right
	// after the upgrade.
	ctx, cancel := context.WithCancel(context.Background())
	conn := &Connection{
		sessionID: sessionID,
		ws:        ws,
		send:      make(chan []byte, 256),
		ctx:       ctx,
		cancel:    cancel,
	}

	g.mu.Lock()
	g.connections[sessionID] = conn
	g.mu.Unlock()

	// Start goroutines
	go conn.writePump()
	go conn.readPump(g)

	return nil
}
// SendMessage JSON-encodes message and queues it for delivery on the
// connection registered for sessionID.
//
// It returns an error when no connection is registered, when encoding fails,
// or when the connection is shut down before the message could be queued.
// If the queue is full the call blocks until there is room or the connection
// is shut down.
func (g *Gateway) SendMessage(sessionID string, message interface{}) error {
	g.mu.RLock()
	target, found := g.connections[sessionID]
	g.mu.RUnlock()
	if !found {
		return fmt.Errorf("connection not found for session: %s", sessionID)
	}

	payload, err := json.Marshal(message)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	select {
	case <-target.ctx.Done():
		return fmt.Errorf("connection closed")
	case target.send <- payload:
		return nil
	}
}
// CloseConnection shuts down and unregisters the connection of sessionID:
// it cancels the connection context, closes the socket and removes the map
// entry. It does nothing when no connection is registered.
func (g *Gateway) CloseConnection(sessionID string) {
	g.mu.Lock()
	defer g.mu.Unlock()

	conn, registered := g.connections[sessionID]
	if !registered {
		return
	}

	conn.cancel()
	conn.ws.Close()
	delete(g.connections, sessionID)
}
// readPump reads signaling messages from the WebSocket until a read fails
// (including the 60s read deadline, which is renewed whenever a pong
// arrives), then shuts the connection down.
//
// On exit it cancels the connection context, closes the socket and removes
// the connection from gateway, but only if gateway still maps the session ID
// to this connection. Without that check, a connection that was replaced by
// a newer one for the same session would tear down its replacement when it
// finally exits.
//
// Messages must be JSON objects with a string "type" field; anything else is
// skipped. The known types are recognized but not acted upon yet.
func (c *Connection) readPump(gateway *Gateway) {
	defer func() {
		gateway.mu.Lock()
		if gateway.connections[c.sessionID] == c {
			delete(gateway.connections, c.sessionID)
		}
		gateway.mu.Unlock()

		c.cancel()
		c.ws.Close()
	}()

	c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
	c.ws.SetPongHandler(func(string) error {
		c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
		return nil
	})

	for {
		_, message, err := c.ws.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error: %v", err)
			}
			break
		}

		// Handle incoming message (ICE candidates, SDP offers/answers, etc.)
		var msg map[string]interface{}
		if err := json.Unmarshal(message, &msg); err != nil {
			log.Printf("Failed to unmarshal message: %v", err)
			continue
		}

		// Route message based on type
		msgType, ok := msg["type"].(string)
		if !ok {
			continue
		}

		switch msgType {
		case "ice-candidate":
			// Handle ICE candidate
		case "offer":
			// Handle SDP offer
		case "answer":
			// Handle SDP answer
		default:
			log.Printf("Unknown message type: %s", msgType)
		}
	}
}
// writePump is the single writer of the WebSocket. It forwards queued
// messages as text frames, sends a ping every 54 seconds, and stops when a
// write fails, the send channel is closed or the connection context is done.
// The socket is closed on exit.
//
// Messages already waiting in the queue when a frame is opened are written
// into that same frame, separated by newlines.
func (c *Connection) writePump() {
	const (
		pingInterval = 54 * time.Second
		writeTimeout = 10 * time.Second
	)

	pinger := time.NewTicker(pingInterval)
	defer c.ws.Close()
	defer pinger.Stop()

	for {
		select {
		case <-c.ctx.Done():
			return

		case <-pinger.C:
			c.ws.SetWriteDeadline(time.Now().Add(writeTimeout))
			if err := c.ws.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}

		case payload, open := <-c.send:
			c.ws.SetWriteDeadline(time.Now().Add(writeTimeout))
			if !open {
				// Channel closed: tell the peer we are going away.
				c.ws.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			frame, err := c.ws.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			frame.Write(payload)

			// Add queued messages
			for pending := len(c.send); pending > 0; pending-- {
				frame.Write([]byte{'\n'})
				frame.Write(<-c.send)
			}

			if err := frame.Close(); err != nil {
				return
			}
		}
	}
}

68
backend/safety/filter.go Normal file
View File

@@ -0,0 +1,68 @@
package safety
import (
"context"
"strings"
)
// Filter is the content-safety check applied to a piece of text.
type Filter interface {
	// Filter inspects text and reports the verdict in a FilterResult.
	Filter(ctx context.Context, text string) (*FilterResult, error)
}
// FilterResult is the verdict of a Filter call.
type FilterResult struct {
	// Allowed and Blocked are set to opposite values by ContentFilter.
	Allowed bool
	Blocked bool
	// Redacted is the input text after PII redaction.
	Redacted string
	Categories []string // e.g., "profanity", "pii", "abuse"
}
// ContentFilter implements Filter with a list of blocked words.
type ContentFilter struct {
	// blockedWords are matched case-insensitively as substrings of the input.
	blockedWords []string
}
// NewContentFilter returns a ContentFilter with an empty blocked-word list,
// i.e. one that currently allows every text.
func NewContentFilter() *ContentFilter {
	filter := new(ContentFilter)
	// Add blocked words/phrases here.
	filter.blockedWords = []string{}
	return filter
}
// Filter checks text against the blocked-word list. The comparison is a
// case-insensitive substring match; the first hit produces a blocked result
// tagged with the "profanity" category. Text without a hit is allowed.
//
// In both cases Redacted carries the text as returned by redactPII. The
// returned error is always nil.
func (f *ContentFilter) Filter(ctx context.Context, text string) (*FilterResult, error) {
	haystack := strings.ToLower(text)

	for i := range f.blockedWords {
		needle := strings.ToLower(f.blockedWords[i])
		if !strings.Contains(haystack, needle) {
			continue
		}
		blocked := &FilterResult{
			Allowed:    false,
			Blocked:    true,
			Redacted:   f.redactPII(text),
			Categories: []string{"profanity"},
		}
		return blocked, nil
	}

	// TODO: Add more sophisticated filtering (ML models, etc.)
	allowed := &FilterResult{
		Allowed:  true,
		Blocked:  false,
		Redacted: f.redactPII(text),
	}
	return allowed, nil
}
// redactPII is the hook for removing personally identifiable information
// from text. It is not implemented yet and returns text unchanged, so the
// Redacted field of a FilterResult currently equals the input.
func (f *ContentFilter) redactPII(text string) string {
	// TODO: Implement PII detection and redaction
	// For now, return as-is
	return text
}

View File

@@ -0,0 +1,59 @@
package safety
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// RateLimiter implements rate limiting backed by Redis sorted sets.
type RateLimiter struct {
	redis *redis.Client
}
// NewRateLimiter returns a RateLimiter that keeps its counters in redisClient.
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
	limiter := new(RateLimiter)
	limiter.redis = redisClient
	return limiter
}
// Check reports whether one more request for key is allowed under a limit of
// limit requests per window, and records the request when it is.
//
// Requests are kept in a Redis sorted set scored by their Unix timestamp
// (second resolution). Members older than the window are removed on every
// call, so the set stays bounded by limit even for a key that is hit
// continuously. The key also gets a TTL of window so idle keys disappear.
//
// It returns (false, nil) when the limit is reached and (false, err) when a
// Redis command fails.
//
// NOTE(review): counting and adding are separate commands, so concurrent
// callers can slightly exceed the limit.
func (r *RateLimiter) Check(ctx context.Context, key string, limit int, window time.Duration) (bool, error) {
	// Use sliding window log algorithm
	now := time.Now()
	windowStart := now.Add(-window)
	minScore := fmt.Sprintf("%d", windowStart.Unix())
	maxScore := fmt.Sprintf("%d", now.Unix())

	// Drop members that fell out of the window ("(" makes the bound exclusive,
	// matching the inclusive lower bound of the count below).
	if err := r.redis.ZRemRangeByScore(ctx, key, "-inf", "("+minScore).Err(); err != nil {
		return false, err
	}

	// Count requests in window
	count, err := r.redis.ZCount(ctx, key, minScore, maxScore).Result()
	if err != nil {
		return false, err
	}
	if count >= int64(limit) {
		return false, nil
	}

	// Add current request; the nanosecond member keeps entries distinct.
	err = r.redis.ZAdd(ctx, key, redis.Z{
		Score:  float64(now.Unix()),
		Member: fmt.Sprintf("%d", now.UnixNano()),
	}).Err()
	if err != nil {
		return false, err
	}

	// Let idle keys expire. Best-effort: the request is already recorded, and
	// the TTL is set again by the next call.
	_ = r.redis.Expire(ctx, key, window).Err()

	return true, nil
}
// CheckUser applies Check to the per-user key
// "ratelimit:user:<tenantID>:<userID>".
func (r *RateLimiter) CheckUser(ctx context.Context, tenantID, userID string, limit int, window time.Duration) (bool, error) {
	userKey := "ratelimit:user:" + tenantID + ":" + userID
	return r.Check(ctx, userKey, limit, window)
}

316
backend/session/session.go Normal file
View File

@@ -0,0 +1,316 @@
package session
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
)
// Session represents a Virtual Banker session.
type Session struct {
	ID string
	TenantID string
	UserID string
	// EphemeralToken is replaced by RefreshToken.
	EphemeralToken string
	// Config is the tenant configuration loaded when the session is created
	// or read back from the database.
	Config *TenantConfig
	CreatedAt time.Time
	// ExpiresAt is CreatedAt plus the tenant's maximum session duration.
	ExpiresAt time.Time
	LastActivityAt time.Time
}
// TenantConfig holds tenant-specific configuration attached to a session.
type TenantConfig struct {
	Theme map[string]interface{} `json:"theme"`
	AvatarEnabled bool `json:"avatar_enabled"`
	Greeting string `json:"greeting"`
	AllowedTools []string `json:"allowed_tools"`
	Policy *PolicyConfig `json:"policy"`
}
// PolicyConfig holds per-tenant policy settings.
type PolicyConfig struct {
	// MaxSessionDuration is the session lifetime; CreateSession falls back to
	// 30 minutes when it is zero.
	MaxSessionDuration time.Duration `json:"max_session_duration"`
	RateLimitPerMinute int `json:"rate_limit_per_minute"`
	RequireConsent bool `json:"require_consent"`
}
// Manager creates, looks up, refreshes and ends sessions. Sessions are
// persisted in PostgreSQL and tracked in Redis as a cache.
type Manager struct {
	db *pgxpool.Pool
	redis *redis.Client
}
// NewManager returns a Manager that persists sessions in db and caches them
// in redisClient.
func NewManager(db *pgxpool.Pool, redisClient *redis.Client) *Manager {
	mgr := new(Manager)
	mgr.db, mgr.redis = db, redisClient
	return mgr
}
// CreateSession starts a new session for userID under tenantID, stores it in
// the database and the cache, and returns it.
//
// The session gets a random ID and ephemeral token and expires after the
// tenant's maximum session duration (30 minutes when that is zero).
//
// NOTE(review): authAssertion is only checked for being non-empty; it is not
// verified (it should be validated against the tenant's JWKs).
//
// An error is returned when the assertion is empty, or when loading the
// tenant config, generating the identifiers, saving or caching fails.
func (m *Manager) CreateSession(ctx context.Context, tenantID, userID string, authAssertion string) (*Session, error) {
	if authAssertion == "" {
		return nil, errors.New("auth assertion required")
	}

	config, err := m.loadTenantConfig(ctx, tenantID)
	if err != nil {
		return nil, fmt.Errorf("failed to load tenant config: %w", err)
	}

	sessionID, err := generateSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	ephemeralToken, err := generateEphemeralToken()
	if err != nil {
		return nil, fmt.Errorf("failed to generate ephemeral token: %w", err)
	}

	createdAt := time.Now()
	lifetime := 30 * time.Minute // default
	if configured := config.Policy.MaxSessionDuration; configured != 0 {
		lifetime = configured
	}

	created := &Session{
		ID:             sessionID,
		TenantID:       tenantID,
		UserID:         userID,
		EphemeralToken: ephemeralToken,
		Config:         config,
		CreatedAt:      createdAt,
		ExpiresAt:      createdAt.Add(lifetime),
		LastActivityAt: createdAt,
	}

	if err = m.saveSessionToDB(ctx, created); err != nil {
		return nil, fmt.Errorf("failed to save session: %w", err)
	}
	if err = m.cacheSession(ctx, created); err != nil {
		return nil, fmt.Errorf("failed to cache session: %w", err)
	}

	return created, nil
}
// GetSession returns the active session with the given ID.
//
// The cache is consulted first; on a miss or cache error the session is read
// from the database and the cache is re-populated (best-effort).
//
// It returns an error when the session does not exist, has been ended, or
// is past its ExpiresAt. The expiry check matters because the database
// lookup only filters on ended_at, and callers use GetSession to decide
// whether a session ID is still valid.
func (m *Manager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	// Try Redis first
	session, err := m.getSessionFromCache(ctx, sessionID)
	if err != nil || session == nil {
		// Fallback to database
		session, err = m.getSessionFromDB(ctx, sessionID)
		if err != nil {
			return nil, fmt.Errorf("session not found: %w", err)
		}
		// Cache it; a cache failure must not fail the lookup.
		_ = m.cacheSession(ctx, session)
	}

	if time.Now().After(session.ExpiresAt) {
		return nil, errors.New("session expired")
	}

	return session, nil
}
// RefreshToken issues a new ephemeral token for the session, updates its
// last-activity time, and stores the change in the database and (best-effort)
// the cache.
//
// It returns the new token, or an error when the session cannot be found,
// has expired, or cannot be saved.
func (m *Manager) RefreshToken(ctx context.Context, sessionID string) (string, error) {
	current, err := m.GetSession(ctx, sessionID)
	if err != nil {
		return "", err
	}

	if current.ExpiresAt.Before(time.Now()) {
		return "", errors.New("session expired")
	}

	rotated, err := generateEphemeralToken()
	if err != nil {
		return "", fmt.Errorf("failed to generate token: %w", err)
	}
	current.EphemeralToken = rotated
	current.LastActivityAt = time.Now()

	if saveErr := m.saveSessionToDB(ctx, current); saveErr != nil {
		return "", fmt.Errorf("failed to update session: %w", saveErr)
	}
	// The database is authoritative; a cache failure is ignored.
	_ = m.cacheSession(ctx, current)

	return rotated, nil
}
// EndSession drops the session from the cache and stamps ended_at on its
// database row. Only the database error is reported; the result of the cache
// deletion is ignored.
func (m *Manager) EndSession(ctx context.Context, sessionID string) error {
	cacheKey := "session:" + sessionID
	_ = m.redis.Del(ctx, cacheKey)

	_, err := m.db.Exec(ctx, `UPDATE sessions SET ended_at = $1 WHERE id = $2`, time.Now(), sessionID)
	return err
}
// loadTenantConfig reads the configuration of tenantID from the tenants
// table. The returned error is always nil.
//
// If the row cannot be read for any reason (not only a missing tenant), a
// built-in default configuration is returned.
//
// NOTE(review): on success only avatar_enabled and greeting come from the
// row. The theme, allowed_tools and policy columns are scanned but not
// decoded yet, so Theme and AllowedTools stay empty and Policy is always the
// built-in default.
func (m *Manager) loadTenantConfig(ctx context.Context, tenantID string) (*TenantConfig, error) {
	query := `
SELECT theme, avatar_enabled, greeting, allowed_tools, policy
FROM tenants
WHERE id = $1
`
	// Built fresh on every use so callers never share one PolicyConfig.
	defaultPolicy := func() *PolicyConfig {
		return &PolicyConfig{
			MaxSessionDuration: 30 * time.Minute,
			RateLimitPerMinute: 10,
			RequireConsent:     true,
		}
	}

	var (
		config                                  TenantConfig
		themeJSON, policyJSON, allowedToolsJSON []byte
	)

	row := m.db.QueryRow(ctx, query, tenantID)
	if err := row.Scan(
		&themeJSON,
		&config.AvatarEnabled,
		&config.Greeting,
		&allowedToolsJSON,
		&policyJSON,
	); err != nil {
		// Fall back to the default configuration.
		fallback := &TenantConfig{
			Theme:         map[string]interface{}{"primaryColor": "#0066cc"},
			AvatarEnabled: true,
			Greeting:      "Hello! How can I help you today?",
			AllowedTools:  []string{},
			Policy:        defaultPolicy(),
		}
		return fallback, nil
	}

	// Parse JSON fields (simplified - should use json.Unmarshal)
	config.Policy = defaultPolicy()
	return &config, nil
}
// saveSessionToDB upserts session into the sessions table. For an existing
// row only the ephemeral token and the last-activity time are updated.
func (m *Manager) saveSessionToDB(ctx context.Context, session *Session) error {
	query := `
INSERT INTO sessions (id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (id) DO UPDATE SET
ephemeral_token = $4,
last_activity_at = $7
`
	// Positional arguments, in the order of the $n placeholders.
	args := []interface{}{
		session.ID,
		session.TenantID,
		session.UserID,
		session.EphemeralToken,
		session.CreatedAt,
		session.ExpiresAt,
		session.LastActivityAt,
	}

	_, err := m.db.Exec(ctx, query, args...)
	return err
}
// getSessionFromDB loads the session row with the given ID, ignoring
// sessions that have been ended, and attaches the tenant configuration.
// The query does not filter on expires_at.
func (m *Manager) getSessionFromDB(ctx context.Context, sessionID string) (*Session, error) {
	query := `
SELECT id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at
FROM sessions
WHERE id = $1 AND ended_at IS NULL
`
	loaded := new(Session)
	row := m.db.QueryRow(ctx, query, sessionID)
	if err := row.Scan(
		&loaded.ID,
		&loaded.TenantID,
		&loaded.UserID,
		&loaded.EphemeralToken,
		&loaded.CreatedAt,
		&loaded.ExpiresAt,
		&loaded.LastActivityAt,
	); err != nil {
		return nil, err
	}

	tenantConfig, err := m.loadTenantConfig(ctx, loaded.TenantID)
	if err != nil {
		return nil, err
	}
	loaded.Config = tenantConfig

	return loaded, nil
}
// cacheSession marks the session as known in Redis under "session:<id>" for
// the time remaining until it expires. Only the session ID is stored as the
// value, not the session itself. Sessions that are already expired are not
// cached and nil is returned.
func (m *Manager) cacheSession(ctx context.Context, session *Session) error {
	remaining := time.Until(session.ExpiresAt)
	if remaining > 0 {
		// Store the ID only (simplified - should serialize the session).
		return m.redis.Set(ctx, "session:"+session.ID, session.ID, remaining).Err()
	}
	return nil
}
// getSessionFromCache checks Redis for the "session:<id>" marker and, when it
// is present and holds the same ID, loads the full session from the database.
// A missing key, a Redis error or a mismatching value is returned as an error.
func (m *Manager) getSessionFromCache(ctx context.Context, sessionID string) (*Session, error) {
	cached, err := m.redis.Get(ctx, "session:"+sessionID).Result()
	switch {
	case err != nil:
		return nil, err
	case cached != sessionID:
		return nil, errors.New("cache mismatch")
	}

	// The cache only holds the ID; the session itself lives in the database.
	return m.getSessionFromDB(ctx, sessionID)
}
// generateSessionID returns 16 bytes from crypto/rand encoded as a padded
// URL-safe base64 string (24 characters). It fails only when the random
// source fails.
func generateSessionID() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
// generateEphemeralToken returns 32 bytes from crypto/rand encoded as a
// padded URL-safe base64 string (44 characters). It fails only when the
// random source fails.
func generateEphemeralToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw[:])
	return encoded, nil
}

View File

@@ -0,0 +1,68 @@
package banking
import (
"context"
"github.com/explorer/virtual-banker/backend/tools"
)
// AccountStatusTool is the tool that looks up the status of a bank account
// through the banking service client.
type AccountStatusTool struct {
	client *BankingClient
}
// NewAccountStatusTool returns an AccountStatusTool with its own banking
// client pointed at the URL from getBankingAPIURL.
func NewAccountStatusTool() *AccountStatusTool {
	client := NewBankingClient(getBankingAPIURL())
	return &AccountStatusTool{client: client}
}
// getBankingAPIURL returns the base URL of the banking API.
//
// NOTE(review): the value is hard-coded; no environment variable or other
// configuration is consulted yet, so every deployment talks to localhost.
func getBankingAPIURL() string {
	// Default to main API URL
	return "http://localhost:8080"
}
// Name returns the identifier under which this tool is registered and
// invoked: "get_account_status".
func (t *AccountStatusTool) Name() string {
	return "get_account_status"
}
// Description returns the human-readable summary of what the tool does.
func (t *AccountStatusTool) Description() string {
	return "Get the status of a bank account including balance, transactions, and account details"
}
// Execute looks up the account named by the "account_id" parameter, which
// must be a non-empty string; otherwise an unsuccessful result is returned.
// The returned error is always nil.
//
// NOTE(review): when the banking service call fails, the tool still reports
// success and returns fixed placeholder data (balance 10000.00 USD) with only
// a "note" field marking it as fallback. A consumer that ignores the note
// would present a made-up balance as real.
func (t *AccountStatusTool) Execute(ctx context.Context, params map[string]interface{}) (*tools.ToolResult, error) {
	accountID, isString := params["account_id"].(string)
	if !isString || accountID == "" {
		return &tools.ToolResult{
			Success: false,
			Error:   "account_id is required",
		}, nil
	}

	// Call banking service
	if data, err := t.client.GetAccountStatus(ctx, accountID); err == nil {
		return &tools.ToolResult{Success: true, Data: data}, nil
	}

	// Fallback to mock data if service unavailable
	fallback := map[string]interface{}{
		"account_id": accountID,
		"balance":    10000.00,
		"currency":   "USD",
		"status":     "active",
		"type":       "checking",
		"note":       "Using fallback data - banking service unavailable",
	}
	return &tools.ToolResult{Success: true, Data: fallback}, nil
}

View File

@@ -0,0 +1,66 @@
package banking
import (
"context"
"fmt"
"github.com/explorer/virtual-banker/backend/tools"
)
// CreateTicketTool is the tool that opens a support ticket through the
// banking service client.
type CreateTicketTool struct {
	client *BankingClient
}
// NewCreateTicketTool returns a CreateTicketTool with its own banking client
// pointed at the URL from getBankingAPIURL.
func NewCreateTicketTool() *CreateTicketTool {
	client := NewBankingClient(getBankingAPIURL())
	return &CreateTicketTool{client: client}
}
// Name returns the identifier under which this tool is registered and
// invoked: "create_support_ticket".
func (t *CreateTicketTool) Name() string {
	return "create_support_ticket"
}
// Description returns the human-readable summary of what the tool does.
func (t *CreateTicketTool) Description() string {
	return "Create a support ticket for customer service"
}
// Execute opens a support ticket from the "subject" and "details" string
// parameters. A missing or empty subject yields an unsuccessful result;
// details is optional. The returned error is always nil.
//
// NOTE(review): when the banking service call fails, the tool still reports
// success with a fixed placeholder ticket ("TKT-12345"), although no ticket
// was created.
func (t *CreateTicketTool) Execute(ctx context.Context, params map[string]interface{}) (*tools.ToolResult, error) {
	subject, _ := params["subject"].(string)
	details, _ := params["details"].(string)

	if subject == "" {
		return &tools.ToolResult{
			Success: false,
			Error:   "subject is required",
		}, nil
	}

	// Call banking service
	if data, err := t.client.CreateTicket(ctx, subject, details); err == nil {
		return &tools.ToolResult{
			Success:              true,
			Data:                 data,
			RequiresConfirmation: false,
		}, nil
	}

	// Fallback to mock data if service unavailable
	fallback := map[string]interface{}{
		"ticket_id": fmt.Sprintf("TKT-%d", 12345),
		"subject":   subject,
		"status":    "open",
		"note":      "Using fallback data - banking service unavailable",
	}
	return &tools.ToolResult{
		Success:              true,
		Data:                 fallback,
		RequiresConfirmation: false,
	}, nil
}

View File

@@ -0,0 +1,91 @@
package banking
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"time"
)
// BankingClient provides access to backend banking services over HTTP.
type BankingClient struct {
	// baseURL is prepended to every request path.
	baseURL string
	httpClient *http.Client
}
// NewBankingClient returns a client for the banking service at baseURL whose
// HTTP requests time out after 10 seconds.
func NewBankingClient(baseURL string) *BankingClient {
	transport := &http.Client{Timeout: 10 * time.Second}
	return &BankingClient{
		baseURL:    baseURL,
		httpClient: transport,
	}
}
// GetAccountStatus fetches the account with the given ID from
// GET {baseURL}/api/v1/banking/accounts/{accountID} and returns the decoded
// JSON object.
//
// accountID is path-escaped before it is placed in the URL: it originates
// from tool parameters, and an unescaped value containing "/", "?" or "#"
// would change which endpoint is called.
//
// It returns an error when the request cannot be built or sent, when the
// status is not 200 OK, or when the body is not a JSON object.
func (c *BankingClient) GetAccountStatus(ctx context.Context, accountID string) (map[string]interface{}, error) {
	endpoint := fmt.Sprintf("%s/api/v1/banking/accounts/%s", c.baseURL, url.PathEscape(accountID))

	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
	}

	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}
	return result, nil
}
// CreateTicket posts subject and details as a JSON object to
// {baseURL}/api/v1/banking/tickets and returns the decoded JSON response.
//
// It returns an error when the request cannot be built or sent, when the
// status is not 201 Created, or when the body is not a JSON object.
func (c *BankingClient) CreateTicket(ctx context.Context, subject, details string) (map[string]interface{}, error) {
	body, err := json.Marshal(map[string]string{
		"subject": subject,
		"details": details,
	})
	if err != nil {
		return nil, err
	}

	endpoint := fmt.Sprintf("%s/api/v1/banking/tickets", c.baseURL)
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
	}

	var ticket map[string]interface{}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&ticket); decodeErr != nil {
		return nil, decodeErr
	}
	return ticket, nil
}

View File

@@ -0,0 +1,60 @@
package banking
import (
"context"
"github.com/explorer/virtual-banker/backend/tools"
)
// SubmitPaymentTool is the tool for submitting a payment. It holds no state
// and currently answers with mock data (see Execute).
type SubmitPaymentTool struct{}
// NewSubmitPaymentTool returns a ready-to-use SubmitPaymentTool.
func NewSubmitPaymentTool() *SubmitPaymentTool {
	return new(SubmitPaymentTool)
}
// Name returns the identifier under which this tool is registered and
// invoked: "submit_payment".
func (t *SubmitPaymentTool) Name() string {
	return "submit_payment"
}
// Description returns the human-readable summary of what the tool does.
func (t *SubmitPaymentTool) Description() string {
	return "Submit a payment transaction (requires confirmation)"
}
// Execute validates a payment request and returns a pending payment that
// still requires confirmation. The returned error is always nil.
//
// Parameters: "amount" must be a float64 greater than zero (any other type
// is treated as 0 and rejected) and "method" must be a non-empty string.
//
// No payment service is called yet: the payment and transaction IDs in the
// result are fixed mock values.
func (t *SubmitPaymentTool) Execute(ctx context.Context, params map[string]interface{}) (*tools.ToolResult, error) {
	amount, _ := params["amount"].(float64)
	method, _ := params["method"].(string)

	switch {
	case amount <= 0:
		return &tools.ToolResult{
			Success: false,
			Error:   "amount must be greater than 0",
		}, nil
	case method == "":
		return &tools.ToolResult{
			Success: false,
			Error:   "payment method is required",
		}, nil
	}

	// TODO: Call backend/banking/payments/ service
	pending := map[string]interface{}{
		"payment_id":     "PAY-11111",
		"amount":         amount,
		"method":         method,
		"status":         "pending_confirmation",
		"transaction_id": "TXN-22222",
	}
	return &tools.ToolResult{
		Success:              true,
		Data:                 pending,
		RequiresConfirmation: true, // Payments always require confirmation
	}, nil
}

View File

@@ -0,0 +1,62 @@
package banking
import (
"context"
"time"
"github.com/explorer/virtual-banker/backend/tools"
)
// ScheduleAppointmentTool is the tool for booking an appointment. It holds
// no state and currently answers with mock data (see Execute).
type ScheduleAppointmentTool struct{}
// NewScheduleAppointmentTool returns a ready-to-use ScheduleAppointmentTool.
func NewScheduleAppointmentTool() *ScheduleAppointmentTool {
	return new(ScheduleAppointmentTool)
}
// Name returns the identifier under which this tool is registered and
// invoked: "schedule_appointment".
func (t *ScheduleAppointmentTool) Name() string {
	return "schedule_appointment"
}
// Description returns the human-readable summary of what the tool does.
func (t *ScheduleAppointmentTool) Description() string {
	return "Schedule an appointment with a bank representative"
}
// Execute validates an appointment request and returns a scheduled
// appointment that still requires confirmation. The returned error is always
// nil.
//
// Parameters: "datetime" must be a non-empty RFC 3339 timestamp string;
// "reason" is an optional string echoed back in the result.
//
// No scheduling service is called yet: the appointment ID in the result is a
// fixed mock value.
func (t *ScheduleAppointmentTool) Execute(ctx context.Context, params map[string]interface{}) (*tools.ToolResult, error) {
	datetime, _ := params["datetime"].(string)
	reason, _ := params["reason"].(string)

	if datetime == "" {
		return &tools.ToolResult{
			Success: false,
			Error:   "datetime is required",
		}, nil
	}

	// Only the format is validated; the parsed time itself is not used.
	if _, parseErr := time.Parse(time.RFC3339, datetime); parseErr != nil {
		return &tools.ToolResult{
			Success: false,
			Error:   "invalid datetime format (use RFC3339)",
		}, nil
	}

	// TODO: Call backend/banking/ service to schedule appointment
	booking := map[string]interface{}{
		"appointment_id": "APT-67890",
		"datetime":       datetime,
		"reason":         reason,
		"status":         "scheduled",
	}
	return &tools.ToolResult{
		Success:              true,
		Data:                 booking,
		RequiresConfirmation: true, // Appointments require confirmation
	}, nil
}

89
backend/tools/executor.go Normal file
View File

@@ -0,0 +1,89 @@
package tools
import (
"context"
"fmt"
)
// Executor runs tools looked up in a Registry and records every run in an
// audit log.
type Executor struct {
	registry *Registry
	auditLog AuditLogger
}
// NewExecutor returns an Executor that resolves tools in registry and
// reports executions to auditLog.
func NewExecutor(registry *Registry, auditLog AuditLogger) *Executor {
	exec := new(Executor)
	exec.registry, exec.auditLog = registry, auditLog
	return exec
}
// Execute runs the tool registered as toolName with params on behalf of
// userID in tenantID and returns its result.
//
// Each run is written to the audit log twice: once with status "executing"
// before the tool is called, and once with the outcome. The outcome is
// "completed" only when the tool reports success; it is "failed" when the
// tool returns an error, returns no result at all, or returns a result with
// Success == false (the tool's own error text is logged in that case).
//
// It returns an error when the tool is unknown, when the tool returns an
// error, or when the tool returns neither a result nor an error. An
// unsuccessful result is returned as-is with a nil error.
func (e *Executor) Execute(ctx context.Context, toolName string, params map[string]interface{}, userID, tenantID string) (*ToolResult, error) {
	tool, err := e.registry.Get(toolName)
	if err != nil {
		return nil, err
	}

	// audit writes one log entry with the fields shared by all entries of
	// this run already filled in.
	audit := func(status, errMsg string, result interface{}) {
		e.auditLog.LogToolExecution(ctx, &ToolExecutionLog{
			ToolName: toolName,
			UserID:   userID,
			TenantID: tenantID,
			Params:   params,
			Status:   status,
			Error:    errMsg,
			Result:   result,
		})
	}

	// Log execution attempt
	audit("executing", "", nil)

	result, err := tool.Execute(ctx, params)
	if err != nil {
		audit("failed", err.Error(), nil)
		return nil, err
	}
	if result == nil {
		// Guard against tools that return (nil, nil); dereferencing the
		// result below would panic.
		err := fmt.Errorf("tool %s returned no result", toolName)
		audit("failed", err.Error(), nil)
		return nil, err
	}
	if !result.Success {
		audit("failed", result.Error, result.Data)
		return result, nil
	}

	audit("completed", "", result.Data)
	return result, nil
}
// AuditLogger receives one entry for every stage of a tool execution that
// the Executor reports.
type AuditLogger interface {
	// LogToolExecution records log. It has no return value, so
	// implementations cannot report failures to the Executor.
	LogToolExecution(ctx context.Context, log *ToolExecutionLog)
}
// ToolExecutionLog represents a tool execution log entry.
type ToolExecutionLog struct {
	ToolName string
	UserID string
	TenantID string
	// Params are the arguments the tool was invoked with.
	Params map[string]interface{}
	// Status is the stage of the run as set by the Executor, e.g.
	// "executing", "failed" or "completed".
	Status string
	// Error is the failure text; empty unless the run failed.
	Error string
	// Result is the data returned by the tool, if any.
	Result interface{}
}
// MockAuditLogger is a stateless AuditLogger stand-in that only prints entries
// instead of persisting them.
type MockAuditLogger struct{}
// LogToolExecution prints a one-line summary of the entry (tool, user, tenant
// and status) to standard output. ctx is not used.
func (m *MockAuditLogger) LogToolExecution(ctx context.Context, log *ToolExecutionLog) {
	// Stand-in for a persistent audit sink; in production, write to database.
	summary := fmt.Sprintf("Tool execution: %s by %s (%s) - %s\n", log.ToolName, log.UserID, log.TenantID, log.Status)
	fmt.Print(summary)
}

73
backend/tools/registry.go Normal file
View File

@@ -0,0 +1,73 @@
package tools
import (
"context"
"fmt"
)
// Tool is a named, self-describing action that can be executed with a
// parameter map.
type Tool interface {
// Name returns the identifier the tool is registered and looked up under.
Name() string
// Description returns human-readable text describing the tool.
Description() string
// Execute runs the tool with params and returns its result, or an error
// when the run could not be carried out.
Execute(ctx context.Context, params map[string]interface{}) (*ToolResult, error)
}
// ToolResult is the outcome a Tool reports from Execute.
type ToolResult struct {
// Success reports whether the tool achieved what was asked.
Success bool
// Data is the tool-specific payload; its shape is defined by each tool.
Data interface{}
// Error is a message describing why the tool did not succeed.
Error string
// RequiresConfirmation flags results that need an explicit confirmation
// step before they are acted on.
RequiresConfirmation bool
}
// Registry holds the available tools, indexed by name.
// The map is accessed without any locking in this file.
// NOTE(review): confirm that registration finishes before concurrent lookups begin.
type Registry struct {
// tools maps Tool.Name() to the tool registered under that name.
tools map[string]Tool
}
// NewRegistry returns an empty Registry that is ready to accept tools.
func NewRegistry() *Registry {
	registry := new(Registry)
	registry.tools = map[string]Tool{}
	return registry
}
// Register stores tool under the name it reports. Registering a second tool
// with the same name replaces the earlier one.
func (r *Registry) Register(tool Tool) {
	key := tool.Name()
	r.tools[key] = tool
}
// Get returns the tool registered under name, or an error naming the missing
// tool when nothing is registered under it.
func (r *Registry) Get(name string) (Tool, error) {
	if found, exists := r.tools[name]; exists {
		return found, nil
	}
	return nil, fmt.Errorf("tool not found: %s", name)
}
// List returns every registered tool in a newly allocated slice. The order
// follows map iteration and is therefore unspecified.
func (r *Registry) List() []Tool {
	all := make([]Tool, len(r.tools))
	next := 0
	for _, registered := range r.tools {
		all[next] = registered
		next++
	}
	return all
}
// GetAllowedTools returns the registered tools whose names appear in
// allowedNames (for example a tenant's allow-list).
//
// Tools are returned in the order their names first appear in allowedNames,
// so the same allow-list always yields the same sequence. Names that are not
// registered are skipped, a name listed more than once yields its tool once,
// and the result is nil when nothing matches.
func (r *Registry) GetAllowedTools(allowedNames []string) []Tool {
	var allowed []Tool
	seen := make(map[string]struct{}, len(allowedNames))
	for _, name := range allowedNames {
		if _, duplicate := seen[name]; duplicate {
			continue
		}
		seen[name] = struct{}{}
		// Tools are keyed by Name() at registration, so a direct lookup finds
		// the same tool a scan comparing Name() would, without walking the map.
		if tool, ok := r.tools[name]; ok {
			allowed = append(allowed, tool)
		}
	}
	return allowed
}

View File

@@ -0,0 +1,329 @@
package tts
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ElevenLabsTTSService is a text-to-speech client for the ElevenLabs HTTP API.
type ElevenLabsTTSService struct {
// apiKey is sent in the xi-api-key request header.
apiKey string
// voiceID selects the voice; it is part of the request URL path.
voiceID string
// modelID is sent as model_id in the request body.
modelID string
// baseURL is the API root that endpoint paths are appended to.
baseURL string
// httpClient performs all API requests.
httpClient *http.Client
// defaultVoiceConfig is used when a call supplies no voice configuration.
defaultVoiceConfig *VoiceConfig
}
// VoiceConfig holds the ElevenLabs voice settings sent as the voice_settings
// object of a synthesis request.
// NOTE(review): valid ranges and exact meaning of each value are defined by the
// ElevenLabs API and are not established in this file.
type VoiceConfig struct {
// Stability is serialized as "stability".
Stability float64 `json:"stability"`
// SimilarityBoost is serialized as "similarity_boost".
SimilarityBoost float64 `json:"similarity_boost"`
// Style is serialized as "style" and omitted when zero.
Style float64 `json:"style,omitempty"`
// UseSpeakerBoost is serialized as "use_speaker_boost" and omitted when false.
UseSpeakerBoost bool `json:"use_speaker_boost,omitempty"`
}
// ElevenLabsRequest is the JSON body posted to the ElevenLabs text-to-speech
// endpoints.
type ElevenLabsRequest struct {
// Text is the text to synthesize.
Text string `json:"text"`
// ModelID is omitted from the body when empty.
ModelID string `json:"model_id,omitempty"`
// VoiceSettings is a struct value; encoding/json's omitempty does not omit
// struct values, so voice_settings is always present in the body.
VoiceSettings VoiceConfig `json:"voice_settings,omitempty"`
}
// NewElevenLabsTTSService returns a client for the public ElevenLabs API that
// authenticates with apiKey and speaks with the voice identified by voiceID.
// It starts with the "eleven_multilingual_v2" model, a 30-second HTTP client
// timeout and a default voice configuration; SetModelID and SetVoiceConfig
// change the model and the voice configuration afterwards.
func NewElevenLabsTTSService(apiKey, voiceID string) *ElevenLabsTTSService {
	client := &http.Client{Timeout: 30 * time.Second}

	voiceDefaults := &VoiceConfig{
		Stability:       0.5,
		SimilarityBoost: 0.75,
		UseSpeakerBoost: true,
	}

	service := &ElevenLabsTTSService{
		apiKey:             apiKey,
		voiceID:            voiceID,
		modelID:            "eleven_multilingual_v2", // Default model
		baseURL:            "https://api.elevenlabs.io/v1",
		httpClient:         client,
		defaultVoiceConfig: voiceDefaults,
	}
	return service
}
// SetModelID replaces the model ID sent as model_id in later synthesis
// requests. The value is stored as given; an empty string causes model_id to
// be left out of the request body, since that field is tagged omitempty.
func (s *ElevenLabsTTSService) SetModelID(modelID string) {
s.modelID = modelID
}
// SetVoiceConfig replaces the default voice configuration used by Synthesize,
// SynthesizeStream and by the *WithConfig variants when they are given nil.
//
// The pointer is stored as given, so later changes the caller makes through it
// are visible to the service. A nil config is ignored and the current default
// is kept: the synthesis methods dereference the default, so storing nil would
// make every later call fail.
func (s *ElevenLabsTTSService) SetVoiceConfig(config *VoiceConfig) {
	if config == nil {
		return
	}
	s.defaultVoiceConfig = config
}
// Synthesize converts text to audio with the service's default voice
// configuration. It is shorthand for SynthesizeWithConfig with that default.
func (s *ElevenLabsTTSService) Synthesize(ctx context.Context, text string) ([]byte, error) {
	defaults := s.defaultVoiceConfig
	return s.SynthesizeWithConfig(ctx, text, defaults)
}
// SynthesizeWithConfig converts text to audio through the ElevenLabs REST API
// (requested as audio/mpeg) and returns the complete audio payload.
//
// config selects the voice settings; when nil, the service's default voice
// configuration is used. An error is returned when the API key, voice ID, text
// or voice configuration is missing, when the request cannot be built, when the
// API answers with a non-200 status, or when the audio cannot be read.
//
// Transport errors and 5xx responses are retried, up to three attempts in
// total, waiting 1s and then 2s between attempts. The wait ends early with an
// error when ctx is cancelled. Other non-200 responses fail immediately and
// the error includes the response body returned by the API.
func (s *ElevenLabsTTSService) SynthesizeWithConfig(ctx context.Context, text string, config *VoiceConfig) ([]byte, error) {
	if s.apiKey == "" {
		return nil, fmt.Errorf("ElevenLabs API key not configured")
	}
	if s.voiceID == "" {
		return nil, fmt.Errorf("ElevenLabs voice ID not configured")
	}
	if text == "" {
		return nil, fmt.Errorf("text cannot be empty")
	}

	// Use default config if none provided
	if config == nil {
		config = s.defaultVoiceConfig
	}
	if config == nil {
		// Neither the caller nor the service supplied settings; fail cleanly
		// rather than dereferencing nil below.
		return nil, fmt.Errorf("ElevenLabs voice configuration not set")
	}

	// Prepare request body
	jsonBody, err := json.Marshal(ElevenLabsRequest{
		Text:          text,
		ModelID:       s.modelID,
		VoiceSettings: *config,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	// Build request URL
	url := fmt.Sprintf("%s/text-to-speech/%s", s.baseURL, s.voiceID)

	// pause waits out the linear backoff that follows the given attempt, or
	// returns the context error as soon as ctx is done.
	pause := func(attempt int) error {
		timer := time.NewTimer(time.Duration(attempt+1) * time.Second)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			return nil
		}
	}

	const maxRetries = 3
	for attempt := 0; attempt < maxRetries; attempt++ {
		lastAttempt := attempt == maxRetries-1

		// A request's body reader is consumed by Do, so every attempt gets a
		// fresh request over the same marshalled payload.
		req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
		if err != nil {
			return nil, fmt.Errorf("failed to create request: %w", err)
		}
		req.Header.Set("Accept", "audio/mpeg")
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("xi-api-key", s.apiKey)

		resp, err := s.httpClient.Do(req)
		if err != nil {
			if lastAttempt {
				return nil, fmt.Errorf("failed to call ElevenLabs API after %d retries: %w", maxRetries, err)
			}
			if waitErr := pause(attempt); waitErr != nil {
				return nil, fmt.Errorf("failed to call ElevenLabs API: %w", waitErr)
			}
			continue
		}

		if resp.StatusCode == http.StatusOK {
			// Read audio data
			audioData, readErr := io.ReadAll(resp.Body)
			resp.Body.Close()
			if readErr != nil {
				return nil, fmt.Errorf("failed to read audio data: %w", readErr)
			}
			return audioData, nil
		}

		// Read the error payload before closing the body; it is diagnostic
		// only, so it is size-capped and a read failure is deliberately ignored.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		resp.Body.Close()

		// Retry on 5xx errors
		if resp.StatusCode >= 500 && !lastAttempt {
			if waitErr := pause(attempt); waitErr != nil {
				return nil, fmt.Errorf("failed to call ElevenLabs API: %w", waitErr)
			}
			continue
		}
		return nil, fmt.Errorf("ElevenLabs API error: status %d, body: %s", resp.StatusCode, string(bodyBytes))
	}

	// Not reached: the final attempt always returns from inside the loop.
	return nil, fmt.Errorf("failed to call ElevenLabs API after %d retries", maxRetries)
}
// SynthesizeStream converts text to an audio stream with the service's default
// voice configuration. It is shorthand for SynthesizeStreamWithConfig with
// that default.
func (s *ElevenLabsTTSService) SynthesizeStream(ctx context.Context, text string) (io.Reader, error) {
	defaults := s.defaultVoiceConfig
	return s.SynthesizeStreamWithConfig(ctx, text, defaults)
}
// SynthesizeStreamWithConfig converts text to audio through the ElevenLabs
// streaming endpoint (requested as audio/mpeg) and returns the response body
// so audio can be consumed as it arrives.
//
// config selects the voice settings; when nil, the service's default voice
// configuration is used. An error is returned when the API key, voice ID, text
// or voice configuration is missing, when the request fails, or when the API
// answers with a non-200 status; in that last case the error includes the
// response body returned by the API. No retries are attempted.
//
// The returned reader is the HTTP response body. The caller is responsible for
// closing it once done; it also implements io.Closer. Any timeout configured
// on the HTTP client also bounds reading from the returned stream.
func (s *ElevenLabsTTSService) SynthesizeStreamWithConfig(ctx context.Context, text string, config *VoiceConfig) (io.Reader, error) {
	if s.apiKey == "" {
		return nil, fmt.Errorf("ElevenLabs API key not configured")
	}
	if s.voiceID == "" {
		return nil, fmt.Errorf("ElevenLabs voice ID not configured")
	}
	if text == "" {
		return nil, fmt.Errorf("text cannot be empty")
	}

	// Use default config if none provided
	if config == nil {
		config = s.defaultVoiceConfig
	}
	if config == nil {
		// Neither the caller nor the service supplied settings; fail cleanly
		// rather than dereferencing nil below.
		return nil, fmt.Errorf("ElevenLabs voice configuration not set")
	}

	// Prepare request body
	jsonBody, err := json.Marshal(ElevenLabsRequest{
		Text:          text,
		ModelID:       s.modelID,
		VoiceSettings: *config,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	// Build request URL for streaming
	url := fmt.Sprintf("%s/text-to-speech/%s/stream", s.baseURL, s.voiceID)

	// Create HTTP request
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "audio/mpeg")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("xi-api-key", s.apiKey)

	// Execute request
	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call ElevenLabs streaming API: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Read the error payload before closing the body; it is diagnostic
		// only, so it is size-capped and a read failure is deliberately ignored.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
		resp.Body.Close()
		return nil, fmt.Errorf("ElevenLabs streaming API error: status %d, body: %s", resp.StatusCode, string(bodyBytes))
	}

	// Return stream reader (caller is responsible for closing)
	return resp.Body, nil
}
// GetVisemes returns approximate viseme (lip shape) events for text, for use
// in lip sync. The events are derived locally from the text by
// generateVisemesFromText; no API request is made and ctx is not used.
// It returns an error when text is empty.
func (s *ElevenLabsTTSService) GetVisemes(ctx context.Context, text string) ([]VisemeEvent, error) {
	if len(text) == 0 {
		return nil, fmt.Errorf("text cannot be empty")
	}
	// The mapping is a simplification; a dedicated phoneme-to-viseme service
	// or library would give better results in production.
	return s.generateVisemesFromText(text), nil
}
// generateVisemesFromText builds a rough viseme timeline for text without any
// real phoneme analysis.
//
// The text is lower-cased and split on whitespace. The timeline starts with
// 0.1s of silence; each word then contributes one 0.3s event whose viseme is
// chosen from the word's first character, followed by a 0.05s silence; a final
// 0.1s silence closes the timeline. Each word event carries the whole word in
// its Phoneme field.
//
// This is a simplified implementation. For production, consider using:
// - A dedicated phoneme-to-viseme mapping service
// - A TTS provider that provides phoneme timing data (e.g., Azure TTS with SSML)
// - Integration with a speech analysis library
func (s *ElevenLabsTTSService) generateVisemesFromText(text string) []VisemeEvent {
	const (
		silence        = "sil"
		fallbackViseme = "aa" // used for any first character without a table entry
		leadIn         = 0.1
		wordLength     = 0.3 // approximate duration per word in seconds
		wordGap        = 0.05
		leadOut        = 0.1
	)

	// Basic phoneme-to-viseme mapping. Lookups below use a single character,
	// so only the one-letter keys can ever match.
	phonemeToViseme := map[string]string{
		// Vowels
		"aa": "aa", "ae": "aa", "ah": "aa", "ao": "aa", "aw": "aa",
		"ay": "aa", "eh": "ee", "er": "er", "ey": "ee", "ih": "ee",
		"iy": "ee", "ow": "oh", "oy": "oh", "uh": "ou", "uw": "ou",
		// Consonants
		"b": "aa", "p": "aa", "m": "aa",
		"f": "ee", "v": "ee",
		"th": "ee",
		"d":  "aa", "t": "aa", "n": "aa", "l": "aa",
		"k": "aa", "g": "aa", "ng": "aa",
		"s": "ee", "z": "ee",
		"sh": "ee", "zh": "ee", "ch": "ee", "jh": "ee",
		"y":   "ee",
		"w":   "ou",
		"r":   "er",
		"h":   "sil",
		"sil": "sil", "sp": "sil",
	}

	events := []VisemeEvent{}
	clock := 0.0

	// emit appends one event starting at the current clock and advances the
	// clock to the event's end.
	emit := func(viseme, phoneme string, length float64) {
		end := clock + length
		events = append(events, VisemeEvent{
			Viseme:    viseme,
			StartTime: clock,
			EndTime:   end,
			Phoneme:   phoneme,
		})
		clock = end
	}

	// Initial silence
	emit(silence, silence, leadIn)

	// strings.Fields never yields an empty word, so slicing the first byte is safe.
	for _, word := range strings.Fields(strings.ToLower(text)) {
		shape, known := phonemeToViseme[word[:1]]
		if !known {
			shape = fallbackViseme
		}
		emit(shape, word, wordLength)
		// Small pause between words
		emit(silence, silence, wordGap)
	}

	// Final silence
	emit(silence, silence, leadOut)
	return events
}

58
backend/tts/service.go Normal file
View File

@@ -0,0 +1,58 @@
package tts
import (
"context"
"fmt"
"io"
)
// Service is the text-to-speech contract implemented by the TTS backends in
// this package.
type Service interface {
// SynthesizeStream converts text to audio and returns it as a stream.
SynthesizeStream(ctx context.Context, text string) (io.Reader, error)
// Synthesize converts text to audio and returns the whole payload.
Synthesize(ctx context.Context, text string) ([]byte, error)
// GetVisemes returns the viseme timeline to use for lip sync of text.
GetVisemes(ctx context.Context, text string) ([]VisemeEvent, error)
}
// VisemeEvent is one entry of a lip-sync timeline: a viseme (lip shape) held
// from StartTime to EndTime.
type VisemeEvent struct {
// Viseme is the lip shape identifier.
Viseme string `json:"viseme"` // e.g., "sil", "aa", "ee", "oh", "ou"
// StartTime is when the viseme begins.
// NOTE(review): presumably seconds from the start of the audio; confirm with consumers.
StartTime float64 `json:"start_time"`
// EndTime is when the viseme ends, on the same time scale as StartTime.
EndTime float64 `json:"end_time"`
// Phoneme is the source sound or text for the viseme; omitted from JSON when empty.
Phoneme string `json:"phoneme,omitempty"`
}
// MockTTSService is a stateless stand-in for a TTS backend, intended for
// development. It returns fixed data and never contacts an external service.
type MockTTSService struct{}
// NewMockTTSService returns a ready-to-use MockTTSService.
func NewMockTTSService() *MockTTSService {
	return new(MockTTSService)
}
// SynthesizeStream returns an empty audio stream: the first Read reports
// io.EOF. ctx and text are ignored.
//
// The stream is built from io.MultiReader with no sources rather than from a
// nil io.Reader, because reading through a wrapped nil reader panics instead
// of reporting end of stream. The result is wrapped with io.NopCloser so
// callers that close streams returned by real backends can close this one too.
func (s *MockTTSService) SynthesizeStream(ctx context.Context, text string) (io.Reader, error) {
	// Mock implementation - in production, integrate with ElevenLabs, Azure TTS, etc.
	return io.NopCloser(io.MultiReader()), nil
}
// Synthesize returns an empty, non-nil audio payload and no error. ctx and
// text are ignored.
func (s *MockTTSService) Synthesize(ctx context.Context, text string) ([]byte, error) {
	// Mock implementation: there is no audio to produce.
	return make([]byte, 0), nil
}
// GetVisemes returns a fixed three-event timeline ("sil", "aa", "ee") covering
// 0.0 to 0.5, regardless of the input. ctx and text are ignored.
func (s *MockTTSService) GetVisemes(ctx context.Context, text string) ([]VisemeEvent, error) {
	// Mock implementation: consecutive shapes share their boundary times.
	shapes := [...]string{"sil", "aa", "ee"}
	bounds := [...]float64{0.0, 0.1, 0.3, 0.5}

	events := make([]VisemeEvent, len(shapes))
	for i, shape := range shapes {
		events[i] = VisemeEvent{Viseme: shape, StartTime: bounds[i], EndTime: bounds[i+1]}
	}
	return events, nil
}
// Note: ElevenLabsTTSService, which integrates with ElevenLabs, is not declared
// in this file; its actual implementation is in elevenlabs-adapter.go.
// This note is kept for backwards compatibility with earlier versions of this file.