// Package state persists conversation state in PostgreSQL via pgx.
package state

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// StateManager reads and writes ConversationState rows through a pgx
// connection pool.
type StateManager struct {
	db *pgxpool.Pool
}

// NewStateManager returns a StateManager that runs its queries on db.
func NewStateManager(db *pgxpool.Pool) *StateManager {
	return &StateManager{db: db}
}

// ConversationState mirrors one row of the conversation_states table.
type ConversationState struct {
	// SessionID is the lookup key in GetState and the conflict target of the
	// upsert in SaveState.
	SessionID string
	// UserID is written when the row is first inserted; SaveState does not
	// overwrite it on conflict.
	UserID string
	// Workflow and Step are stored as-is and overwritten on every save.
	Workflow string
	Step     string
	// Context is stored as JSON. It may be nil after GetState when the stored
	// value is JSON null or the column is empty.
	Context map[string]interface{}
	// CreatedAt is written when the row is first inserted; SaveState does not
	// overwrite it on conflict.
	CreatedAt time.Time
	// UpdatedAt is populated by GetState. SaveState ignores the value held
	// here and writes the current time instead.
	UpdatedAt time.Time
	// ExpiresAt is stored and overwritten on every save. Nothing in this file
	// enforces expiry.
	ExpiresAt time.Time
}

// SaveState upserts state keyed by its SessionID.
//
// A new row is inserted with every field of state. When a row with the same
// session_id already exists, only workflow, step, context, updated_at and
// expires_at are replaced; user_id and created_at keep their stored values.
// updated_at is always set to time.Now(), regardless of state.UpdatedAt.
//
// It returns an error if state is nil, if state.Context cannot be encoded as
// JSON, or if the database statement fails. Underlying errors are wrapped
// with %w, so errors.Is and errors.As still see them.
func (s *StateManager) SaveState(ctx context.Context, state *ConversationState) error {
	// Guard against a nil pointer so callers get an error instead of a panic.
	if state == nil {
		return fmt.Errorf("failed to save state: state is nil")
	}

	contextJSON, err := json.Marshal(state.Context)
	if err != nil {
		return fmt.Errorf("failed to marshal context: %w", err)
	}

	const query = `
		INSERT INTO conversation_states (
			session_id, user_id, workflow, step, context, created_at, updated_at, expires_at
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
		ON CONFLICT (session_id) DO UPDATE SET
			workflow = $3,
			step = $4,
			context = $5,
			updated_at = $7,
			expires_at = $8
	`

	_, err = s.db.Exec(ctx, query,
		state.SessionID,
		state.UserID,
		state.Workflow,
		state.Step,
		contextJSON,
		state.CreatedAt,
		time.Now(), // updated_at: stamped here, not taken from state
		state.ExpiresAt,
	)
	if err != nil {
		// Wrapped for consistency with the other error paths in this file.
		return fmt.Errorf("failed to save state: %w", err)
	}
	return nil
}

// GetState loads the conversation state stored under sessionID.
//
// It returns a wrapped error when the query or scan fails. That includes the
// driver's no-rows error when no row matches, which callers can detect with
// errors.Is. It also returns a wrapped error when the stored context is not
// valid JSON for a map. An empty or NULL context column is not treated as an
// error; Context is left nil in that case, as it is for a stored JSON null.
func (s *StateManager) GetState(ctx context.Context, sessionID string) (*ConversationState, error) {
	const query = `
		SELECT session_id, user_id, workflow, step, context, created_at, updated_at, expires_at
		FROM conversation_states
		WHERE session_id = $1
	`

	var (
		state       ConversationState
		contextJSON []byte
	)
	err := s.db.QueryRow(ctx, query, sessionID).Scan(
		&state.SessionID,
		&state.UserID,
		&state.Workflow,
		&state.Step,
		&contextJSON,
		&state.CreatedAt,
		&state.UpdatedAt,
		&state.ExpiresAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to get state: %w", err)
	}

	// json.Unmarshal rejects zero-length input, so skip decoding when the
	// column came back empty (presumably how a NULL scans into []byte —
	// TODO confirm against the pgx version in use).
	if len(contextJSON) == 0 {
		return &state, nil
	}
	if err := json.Unmarshal(contextJSON, &state.Context); err != nil {
		return nil, fmt.Errorf("failed to unmarshal context: %w", err)
	}
	return &state, nil
}