// Package session creates, looks up, refreshes and ends Virtual Banker
// sessions. PostgreSQL is the source of truth; Redis holds a per-session
// marker key whose TTL tracks the session expiry.
package session

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/redis/go-redis/v9"
)

// defaultSessionDuration is the session lifetime used when the tenant policy
// does not provide a positive MaxSessionDuration.
const defaultSessionDuration = 30 * time.Minute

// Session represents a Virtual Banker session.
type Session struct {
	ID             string
	TenantID       string
	UserID         string
	EphemeralToken string
	Config         *TenantConfig
	CreatedAt      time.Time
	ExpiresAt      time.Time
	LastActivityAt time.Time
}

// TenantConfig holds tenant-specific configuration.
type TenantConfig struct {
	Theme         map[string]interface{} `json:"theme"`
	AvatarEnabled bool                   `json:"avatar_enabled"`
	Greeting      string                 `json:"greeting"`
	AllowedTools  []string               `json:"allowed_tools"`
	Policy        *PolicyConfig          `json:"policy"`
}

// PolicyConfig holds policy settings.
type PolicyConfig struct {
	MaxSessionDuration time.Duration `json:"max_session_duration"`
	RateLimitPerMinute int           `json:"rate_limit_per_minute"`
	RequireConsent     bool          `json:"require_consent"`
}

// Manager manages sessions using a PostgreSQL pool and a Redis client.
type Manager struct {
	db    *pgxpool.Pool
	redis *redis.Client
}

// NewManager creates a new session manager backed by db and redisClient.
func NewManager(db *pgxpool.Pool, redisClient *redis.Client) *Manager {
	return &Manager{
		db:    db,
		redis: redisClient,
	}
}

// CreateSession creates a new session for userID under tenantID, persists it
// to the database and then writes the Redis marker key.
//
// It returns an error when authAssertion is empty, when the tenant config
// cannot be loaded, when random generation fails, or when the database write
// fails. A Redis failure does not fail the call: the row is already stored
// and GetSession falls back to the database.
func (m *Manager) CreateSession(ctx context.Context, tenantID, userID string, authAssertion string) (*Session, error) {
	// Simplified check: only non-emptiness is verified here. The assertion
	// should be validated against the tenant's JWKs.
	if authAssertion == "" {
		return nil, errors.New("auth assertion required")
	}

	config, err := m.loadTenantConfig(ctx, tenantID)
	if err != nil {
		return nil, fmt.Errorf("failed to load tenant config: %w", err)
	}

	sessionID, err := generateSessionID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate session ID: %w", err)
	}

	ephemeralToken, err := generateEphemeralToken()
	if err != nil {
		return nil, fmt.Errorf("failed to generate ephemeral token: %w", err)
	}

	// Guard against a config without a policy, and against a zero or
	// negative duration that would yield an already-expired session.
	sessionDuration := defaultSessionDuration
	if config != nil && config.Policy != nil && config.Policy.MaxSessionDuration > 0 {
		sessionDuration = config.Policy.MaxSessionDuration
	}

	now := time.Now()
	session := &Session{
		ID:             sessionID,
		TenantID:       tenantID,
		UserID:         userID,
		EphemeralToken: ephemeralToken,
		Config:         config,
		CreatedAt:      now,
		ExpiresAt:      now.Add(sessionDuration),
		LastActivityAt: now,
	}

	if err := m.saveSessionToDB(ctx, session); err != nil {
		return nil, fmt.Errorf("failed to save session: %w", err)
	}

	// Best-effort, matching GetSession and RefreshToken: returning an error
	// here would leave a stored session row that the caller never received.
	_ = m.cacheSession(ctx, session)

	return session, nil
}

// GetSession retrieves a session by ID. It first tries the Redis-backed path
// and falls back to a direct database lookup, re-writing the marker key after
// a successful fallback. The returned session is not checked for expiry.
func (m *Manager) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	if cached, err := m.getSessionFromCache(ctx, sessionID); err == nil && cached != nil {
		return cached, nil
	}

	session, err := m.getSessionFromDB(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("session not found: %w", err)
	}

	// Best-effort: a cache write failure must not hide a session that was
	// found in the database.
	_ = m.cacheSession(ctx, session)

	return session, nil
}

// RefreshToken issues a new ephemeral token for the session and records the
// activity time. It returns an error when the session cannot be loaded, has
// expired, or cannot be written back to the database.
func (m *Manager) RefreshToken(ctx context.Context, sessionID string) (string, error) {
	session, err := m.GetSession(ctx, sessionID)
	if err != nil {
		return "", err
	}

	// One timestamp serves both the expiry check and the activity update.
	now := time.Now()
	if now.After(session.ExpiresAt) {
		return "", errors.New("session expired")
	}

	newToken, err := generateEphemeralToken()
	if err != nil {
		return "", fmt.Errorf("failed to generate token: %w", err)
	}

	session.EphemeralToken = newToken
	session.LastActivityAt = now

	if err := m.saveSessionToDB(ctx, session); err != nil {
		return "", fmt.Errorf("failed to update session: %w", err)
	}

	// Best-effort refresh of the marker key; the database already holds the
	// new token.
	_ = m.cacheSession(ctx, session)

	return newToken, nil
}

// EndSession removes the session's Redis marker key and sets ended_at on the
// database row. Only the database error is returned.
func (m *Manager) EndSession(ctx context.Context, sessionID string) error {
	// Best-effort: getSessionFromDB filters on ended_at IS NULL, so a marker
	// key that survives a failed delete does not make the session readable.
	_ = m.redis.Del(ctx, sessionCacheKey(sessionID)).Err()

	query := `UPDATE sessions SET ended_at = $1 WHERE id = $2`
	_, err := m.db.Exec(ctx, query, time.Now(), sessionID)
	return err
}

// loadTenantConfig loads the configuration row for tenantID.
//
// When the query fails and the context is still live, a built-in default
// configuration is returned instead of an error. When the context has been
// cancelled or has timed out, the query error is returned so callers do not
// continue with a substituted config.
//
// NOTE(review): any other query error (not only "no rows") still yields the
// default config; narrowing that needs the driver's no-rows sentinel.
func (m *Manager) loadTenantConfig(ctx context.Context, tenantID string) (*TenantConfig, error) {
	query := `
		SELECT theme, avatar_enabled, greeting, allowed_tools, policy
		FROM tenants
		WHERE id = $1
	`

	var config TenantConfig
	var themeJSON, policyJSON []byte
	var allowedToolsJSON []byte

	err := m.db.QueryRow(ctx, query, tenantID).Scan(
		&themeJSON,
		&config.AvatarEnabled,
		&config.Greeting,
		&allowedToolsJSON,
		&policyJSON,
	)
	if err != nil {
		if ctx.Err() != nil {
			return nil, err
		}
		// Fall back to the default config (intended for unknown tenants).
		return &TenantConfig{
			Theme:         map[string]interface{}{"primaryColor": "#0066cc"},
			AvatarEnabled: true,
			Greeting:      "Hello! How can I help you today?",
			AllowedTools:  []string{},
			Policy:        defaultPolicyConfig(),
		}, nil
	}

	// NOTE(review): themeJSON, allowedToolsJSON and policyJSON are scanned but
	// not decoded, so Theme and AllowedTools stay nil and Policy is always the
	// default. Decoding needs the column formats (and the unit used for
	// max_session_duration) confirmed first.
	config.Policy = defaultPolicyConfig()

	return &config, nil
}

// saveSessionToDB inserts the session row, or on an id conflict updates only
// ephemeral_token and last_activity_at.
func (m *Manager) saveSessionToDB(ctx context.Context, session *Session) error {
	query := `
		INSERT INTO sessions (id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7)
		ON CONFLICT (id) DO UPDATE SET
			ephemeral_token = $4,
			last_activity_at = $7
	`

	_, err := m.db.Exec(ctx, query,
		session.ID,
		session.TenantID,
		session.UserID,
		session.EphemeralToken,
		session.CreatedAt,
		session.ExpiresAt,
		session.LastActivityAt,
	)
	return err
}

// getSessionFromDB loads a session row that has not been ended
// (ended_at IS NULL) and attaches the tenant configuration to it.
func (m *Manager) getSessionFromDB(ctx context.Context, sessionID string) (*Session, error) {
	query := `
		SELECT id, tenant_id, user_id, ephemeral_token, created_at, expires_at, last_activity_at
		FROM sessions
		WHERE id = $1 AND ended_at IS NULL
	`

	var session Session
	err := m.db.QueryRow(ctx, query, sessionID).Scan(
		&session.ID,
		&session.TenantID,
		&session.UserID,
		&session.EphemeralToken,
		&session.CreatedAt,
		&session.ExpiresAt,
		&session.LastActivityAt,
	)
	if err != nil {
		return nil, err
	}

	config, err := m.loadTenantConfig(ctx, session.TenantID)
	if err != nil {
		return nil, err
	}
	session.Config = config

	return &session, nil
}

// cacheSession writes the session's marker key to Redis with a TTL equal to
// the time remaining until ExpiresAt. Sessions that have already expired are
// not written and nil is returned.
func (m *Manager) cacheSession(ctx context.Context, session *Session) error {
	ttl := time.Until(session.ExpiresAt)
	if ttl <= 0 {
		return nil
	}

	// Only the session ID is stored as the value (simplified); the full
	// session is always read back from the database.
	return m.redis.Set(ctx, sessionCacheKey(session.ID), session.ID, ttl).Err()
}

// getSessionFromCache checks the Redis marker key for sessionID and, when it
// is present and matches, loads the full session from the database.
func (m *Manager) getSessionFromCache(ctx context.Context, sessionID string) (*Session, error) {
	val, err := m.redis.Get(ctx, sessionCacheKey(sessionID)).Result()
	if err != nil {
		return nil, err
	}
	if val != sessionID {
		return nil, errors.New("cache mismatch")
	}

	// The cache holds only a marker; the record itself lives in the database.
	return m.getSessionFromDB(ctx, sessionID)
}

// generateSessionID returns a session ID built from 16 random bytes, encoded
// as padded URL-safe base64.
func generateSessionID() (string, error) {
	return randomToken(16)
}

// generateEphemeralToken returns an ephemeral token built from 32 random
// bytes, encoded as padded URL-safe base64.
func generateEphemeralToken() (string, error) {
	return randomToken(32)
}

// defaultPolicyConfig returns a fresh copy of the built-in policy so callers
// never share a mutable instance.
func defaultPolicyConfig() *PolicyConfig {
	return &PolicyConfig{
		MaxSessionDuration: defaultSessionDuration,
		RateLimitPerMinute: 10,
		RequireConsent:     true,
	}
}

// sessionCacheKey returns the Redis key under which a session marker is kept.
func sessionCacheKey(sessionID string) string {
	return fmt.Sprintf("session:%s", sessionID)
}

// randomToken reads n bytes from crypto/rand and encodes them with
// base64.URLEncoding.
func randomToken(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b), nil
}