package auth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"golang.org/x/crypto/bcrypt"
)

// Auth handles user authentication, API key management, and product
// subscriptions. All state lives in the database reached through db.
type Auth struct {
	db *pgxpool.Pool
}

// NewAuth creates a new auth handler that runs its queries through db.
func NewAuth(db *pgxpool.Pool) *Auth {
	return &Auth{db: db}
}

// User represents a row of the users table (without the password hash).
type User struct {
	ID        string
	Email     string
	Username  string
	CreatedAt time.Time
}

// APIKeyInfo describes an API key as listed by ListAPIKeys. Neither the
// plaintext key nor its hash is included.
type APIKeyInfo struct {
	ID                 string     `json:"id"`
	Name               string     `json:"name"`
	Tier               string     `json:"tier"`
	ProductSlug        string     `json:"productSlug"`
	Scopes             []string   `json:"scopes"`
	MonthlyQuota       int        `json:"monthlyQuota"`
	RequestsUsed       int        `json:"requestsUsed"`
	Approved           bool       `json:"approved"`
	ApprovedAt         *time.Time `json:"approvedAt"`
	RateLimitPerSecond int        `json:"rateLimitPerSecond"`
	RateLimitPerMinute int        `json:"rateLimitPerMinute"`
	LastUsedAt         *time.Time `json:"lastUsedAt"`
	ExpiresAt          *time.Time `json:"expiresAt"`
	Revoked            bool       `json:"revoked"`
	CreatedAt          time.Time  `json:"createdAt"`
}

// ValidatedAPIKey is returned by ValidateAPIKeyDetailed for a key that passed
// the revoked, approval, and expiry checks.
type ValidatedAPIKey struct {
	UserID             string     `json:"userId"`
	APIKeyID           string     `json:"apiKeyId"`
	Name               string     `json:"name"`
	Tier               string     `json:"tier"`
	ProductSlug        string     `json:"productSlug"`
	Scopes             []string   `json:"scopes"`
	MonthlyQuota       int        `json:"monthlyQuota"`
	RequestsUsed       int        `json:"requestsUsed"`
	RateLimitPerSecond int        `json:"rateLimitPerSecond"`
	RateLimitPerMinute int        `json:"rateLimitPerMinute"`
	LastUsedAt         *time.Time `json:"lastUsedAt"`
	ExpiresAt          *time.Time `json:"expiresAt"`
}

// ProductSubscription is a row of user_product_subscriptions.
type ProductSubscription struct {
	ID               string     `json:"id"`
	ProductSlug      string     `json:"productSlug"`
	Tier             string     `json:"tier"`
	Status           string     `json:"status"`
	MonthlyQuota     int        `json:"monthlyQuota"`
	RequestsUsed     int        `json:"requestsUsed"`
	RequiresApproval bool       `json:"requiresApproval"`
	ApprovedAt       *time.Time `json:"approvedAt"`
	ApprovedBy       *string    `json:"approvedBy"`
	Notes            *string    `json:"notes"`
	CreatedAt        time.Time  `json:"createdAt"`
}

// APIKeyUsageLog is a row of api_key_usage_logs joined with the name of the
// key it belongs to.
type APIKeyUsageLog struct {
	ID           int64     `json:"id"`
	APIKeyID     string    `json:"apiKeyId"`
	KeyName      string    `json:"keyName"`
	ProductSlug  string    `json:"productSlug"`
	MethodName   string    `json:"methodName"`
	RequestCount int       `json:"requestCount"`
	LastIP       *string   `json:"lastIp"`
	CreatedAt    time.Time `json:"createdAt"`
}

// subscriptionColumns is the SELECT/RETURNING column list for
// ProductSubscription rows. Its order must match scanSubscription.
const subscriptionColumns = `
	id, product_slug, tier, status,
	COALESCE(monthly_quota, 0), COALESCE(requests_used, 0),
	requires_approval, approved_at, approved_by, notes, created_at`

// scanSubscription reads one row selected with subscriptionColumns. scan is
// the Scan method of a pgx row or row set.
func scanSubscription(scan func(dest ...any) error) (*ProductSubscription, error) {
	var sub ProductSubscription
	if err := scan(
		&sub.ID, &sub.ProductSlug, &sub.Tier, &sub.Status,
		&sub.MonthlyQuota, &sub.RequestsUsed, &sub.RequiresApproval,
		&sub.ApprovedAt, &sub.ApprovedBy, &sub.Notes, &sub.CreatedAt,
	); err != nil {
		return nil, err
	}
	return &sub, nil
}

// querySubscriptions runs a query that selects subscriptionColumns and
// collects every row. It returns an empty, non-nil slice when nothing matches.
func (a *Auth) querySubscriptions(ctx context.Context, query string, args ...any) ([]ProductSubscription, error) {
	rows, err := a.db.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	subs := make([]ProductSubscription, 0)
	for rows.Next() {
		sub, err := scanSubscription(rows.Scan)
		if err != nil {
			return nil, fmt.Errorf("failed to scan subscription: %w", err)
		}
		subs = append(subs, *sub)
	}
	// Errors raised while streaming rows are reported here, not by Next.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return subs, nil
}

// ListAllSubscriptions returns subscriptions of every user, newest first.
// When status is non-empty only subscriptions with that status are returned.
func (a *Auth) ListAllSubscriptions(ctx context.Context, status string) ([]ProductSubscription, error) {
	query := `SELECT` + subscriptionColumns + `
		FROM user_product_subscriptions`
	var args []any
	if status != "" {
		query += ` WHERE status = $1`
		args = append(args, status)
	}
	query += ` ORDER BY created_at DESC`

	subs, err := a.querySubscriptions(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to list all subscriptions: %w", err)
	}
	return subs, nil
}

// UpdateSubscriptionStatus sets the status of subscription subscriptionID and
// returns the updated row.
//
// When status is "active", approved_at is set to NOW() and approved_by to
// approvedBy (NULL if empty); otherwise both keep their current values.
// notes replaces the stored notes only when non-empty. An unknown
// subscriptionID produces an error because no row is returned.
func (a *Auth) UpdateSubscriptionStatus(
	ctx context.Context,
	subscriptionID string,
	status string,
	approvedBy string,
	notes string,
) (*ProductSubscription, error) {
	const query = `
		UPDATE user_product_subscriptions
		SET status = $2,
			approved_at = CASE WHEN $2 = 'active' THEN NOW() ELSE approved_at END,
			approved_by = CASE WHEN $2 = 'active' THEN NULLIF($3, '') ELSE approved_by END,
			notes = CASE WHEN NULLIF($4, '') IS NOT NULL THEN $4 ELSE notes END,
			updated_at = NOW()
		WHERE id = $1
		RETURNING` + subscriptionColumns

	row := a.db.QueryRow(ctx, query, subscriptionID, status, approvedBy, notes)
	sub, err := scanSubscription(row.Scan)
	if err != nil {
		return nil, fmt.Errorf("failed to update subscription: %w", err)
	}
	return sub, nil
}

// RegisterUser registers a new user, storing a bcrypt hash (cost
// bcrypt.DefaultCost) of password, and returns the created row.
func (a *Auth) RegisterUser(ctx context.Context, email, username, password string) (*User, error) {
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("failed to hash password: %w", err)
	}

	const query = `
		INSERT INTO users (email, username, password_hash)
		VALUES ($1, $2, $3)
		RETURNING id, email, username, created_at
	`
	var user User
	if err := a.db.QueryRow(ctx, query, email, username, hashedPassword).Scan(
		&user.ID, &user.Email, &user.Username, &user.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}
	return &user, nil
}

var (
	loginTimingHashOnce sync.Once
	loginTimingHash     []byte
)

// burnPasswordCheck runs one bcrypt comparison of password against a
// throwaway hash of cost bcrypt.DefaultCost (the cost RegisterUser uses) and
// discards the result.
func burnPasswordCheck(password string) {
	loginTimingHashOnce.Do(func() {
		// On error the hash stays nil and the comparison fails fast; that only
		// weakens the timing mitigation, it never authenticates anyone.
		loginTimingHash, _ = bcrypt.GenerateFromPassword([]byte("login-timing-equalizer"), bcrypt.DefaultCost)
	})
	_ = bcrypt.CompareHashAndPassword(loginTimingHash, []byte(password))
}

// AuthenticateUser checks password against the stored bcrypt hash for email.
// It returns the user on success. Otherwise it returns an "invalid
// credentials" error, whether the email is unknown, the lookup failed, or the
// password is wrong. On success last_login_at is updated best-effort.
func (a *Auth) AuthenticateUser(ctx context.Context, email, password string) (*User, error) {
	var user User
	var passwordHash string
	err := a.db.QueryRow(ctx,
		`SELECT id, email, username, password_hash, created_at FROM users WHERE email = $1`,
		email,
	).Scan(&user.ID, &user.Email, &user.Username, &passwordHash, &user.CreatedAt)
	if err != nil {
		// Pay for a bcrypt comparison anyway so that an unknown email is not
		// distinguishable from a wrong password by response latency.
		burnPasswordCheck(password)
		return nil, errors.New("invalid credentials")
	}

	if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
		return nil, errors.New("invalid credentials")
	}

	// Best-effort: failing to record the login time must not fail the login.
	_, _ = a.db.Exec(ctx, `UPDATE users SET last_login_at = NOW(), updated_at = NOW() WHERE id = $1`, user.ID)
	return &user, nil
}

// hashAPIKey returns the hex-encoded SHA-256 digest under which an API key is
// stored and looked up.
func hashAPIKey(apiKey string) string {
	sum := sha256.Sum256([]byte(apiKey))
	return hex.EncodeToString(sum[:])
}

// rateLimitsForTier returns the per-second and per-minute limits stored on a
// new key. Unrecognized tiers get the "free" limits.
func rateLimitsForTier(tier string) (perSecond, perMinute int) {
	switch tier {
	case "pro":
		return 20, 1000
	case "enterprise":
		return 100, 10000
	default: // "free" and anything unrecognized
		return 5, 100
	}
}

// checkAPIKeyState returns the error that rejects a key in the given state,
// or nil when the key is usable now. Checks run in order: revoked, not
// approved, expired.
func checkAPIKeyState(revoked, approved bool, expiresAt *time.Time) error {
	switch {
	case revoked:
		return errors.New("API key revoked")
	case !approved:
		return errors.New("API key pending approval")
	case expiresAt != nil && time.Now().After(*expiresAt):
		return errors.New("API key expired")
	}
	return nil
}

// GenerateAPIKey generates a new API key for a user.
//
// The key has no product, scopes, quota, or expiry. It is stored with
// approved=false, so the validation methods reject it with "API key pending
// approval" until its approved flag is set. Nothing in this file sets that
// flag.
func (a *Auth) GenerateAPIKey(ctx context.Context, userID, name string, tier string) (string, error) {
	return a.GenerateScopedAPIKey(ctx, userID, name, tier, "", nil, 0, false, 0)
}

// GenerateScopedAPIKey creates an API key and returns its plaintext form:
// "ek_" followed by 64 hex characters.
//
// Only the SHA-256 digest is stored, so the plaintext cannot be recovered
// later. Rate limits are derived from tier. When approved is true, approved_at
// is set to NOW(). When expiresDays > 0 the key expires that many 24-hour
// periods from now; otherwise it does not expire.
func (a *Auth) GenerateScopedAPIKey(ctx context.Context, userID, name string, tier string, productSlug string, scopes []string, monthlyQuota int, approved bool, expiresDays int) (string, error) {
	keyBytes := make([]byte, 32)
	if _, err := rand.Read(keyBytes); err != nil {
		return "", fmt.Errorf("failed to generate key: %w", err)
	}
	apiKey := "ek_" + hex.EncodeToString(keyBytes)

	rateLimitPerSecond, rateLimitPerMinute := rateLimitsForTier(tier)

	var expiresAt *time.Time
	if expiresDays > 0 {
		expires := time.Now().Add(time.Duration(expiresDays) * 24 * time.Hour)
		expiresAt = &expires
	}

	const query = `
		INSERT INTO api_keys (
			user_id, key_hash, name, tier, product_slug, scopes, monthly_quota,
			rate_limit_per_second, rate_limit_per_minute, approved, approved_at, expires_at
		)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, CASE WHEN $10 THEN NOW() ELSE NULL END, $11)
	`
	if _, err := a.db.Exec(ctx, query,
		userID, hashAPIKey(apiKey), name, tier, productSlug, scopes, monthlyQuota,
		rateLimitPerSecond, rateLimitPerMinute, approved, expiresAt,
	); err != nil {
		return "", fmt.Errorf("failed to store API key: %w", err)
	}
	return apiKey, nil
}

// ValidateAPIKey returns the owning user ID of apiKey if the key exists, is
// not revoked, is approved, and has not expired. last_used_at is updated
// best-effort.
func (a *Auth) ValidateAPIKey(ctx context.Context, apiKey string) (string, error) {
	keyHash := hashAPIKey(apiKey)

	var (
		userID            string
		revoked, approved bool
		expiresAt         *time.Time
	)
	err := a.db.QueryRow(ctx,
		`SELECT user_id, revoked, approved, expires_at FROM api_keys WHERE key_hash = $1`,
		keyHash,
	).Scan(&userID, &revoked, &approved, &expiresAt)
	if err != nil {
		// A missing key and a database error map to the same message.
		return "", errors.New("invalid API key")
	}
	if err := checkAPIKeyState(revoked, approved, expiresAt); err != nil {
		return "", err
	}

	// Best-effort: a failed timestamp update must not reject a valid key.
	_, _ = a.db.Exec(ctx, `UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1`, keyHash)
	return userID, nil
}

// normalizeInetAddr converts a client address into text PostgreSQL's inet
// type accepts, or "" (stored as NULL) when it cannot be parsed.
//
// Both a bare IP ("203.0.113.7", "2001:db8::1") and an address with a port
// ("203.0.113.7:443", "[2001:db8::1]:443") are accepted. IPv6 zones are
// stripped because inet has no zone component.
func normalizeInetAddr(raw string) string {
	raw = strings.TrimSpace(raw)
	if addr, err := netip.ParseAddr(raw); err == nil {
		return addr.WithZone("").String()
	}
	if addrPort, err := netip.ParseAddrPort(raw); err == nil {
		return addrPort.Addr().WithZone("").String()
	}
	return ""
}

// ValidateAPIKeyDetailed validates apiKey like ValidateAPIKey, records usage,
// and returns the key's details.
//
// Usage is recorded by adding requestCount (minimum 1) to requests_used and
// by appending an api_key_usage_logs row for methodName. lastIPAddress may be
// a bare IP or "ip:port"; unparseable values are stored as NULL. Usage writes
// are best-effort. The returned RequestsUsed includes this call. LastUsedAt
// holds the value read before this call's update.
func (a *Auth) ValidateAPIKeyDetailed(ctx context.Context, apiKey string, methodName string, requestCount int, lastIPAddress string) (*ValidatedAPIKey, error) {
	keyHash := hashAPIKey(apiKey)

	const query = `
		SELECT id, user_id, COALESCE(name, ''), tier, COALESCE(product_slug, ''),
			COALESCE(scopes, ARRAY[]::TEXT[]), COALESCE(monthly_quota, 0), COALESCE(requests_used, 0),
			approved, COALESCE(rate_limit_per_second, 0), COALESCE(rate_limit_per_minute, 0),
			last_used_at, expires_at, revoked
		FROM api_keys
		WHERE key_hash = $1
	`
	var validated ValidatedAPIKey
	var approved, revoked bool
	if err := a.db.QueryRow(ctx, query, keyHash).Scan(
		&validated.APIKeyID, &validated.UserID, &validated.Name, &validated.Tier,
		&validated.ProductSlug, &validated.Scopes, &validated.MonthlyQuota,
		&validated.RequestsUsed, &approved, &validated.RateLimitPerSecond,
		&validated.RateLimitPerMinute, &validated.LastUsedAt, &validated.ExpiresAt,
		&revoked,
	); err != nil {
		return nil, errors.New("invalid API key")
	}
	if err := checkAPIKeyState(revoked, approved, validated.ExpiresAt); err != nil {
		return nil, err
	}

	if requestCount <= 0 {
		requestCount = 1
	}
	// The ::inet casts below fail on input such as "ip:port". Such a failure
	// would abort both statements and silently drop the usage accounting, so
	// the address is normalized here first.
	ip := normalizeInetAddr(lastIPAddress)

	// Best-effort usage accounting: failures must not reject a valid key.
	_, _ = a.db.Exec(ctx, `
		UPDATE api_keys
		SET last_used_at = NOW(),
			requests_used = COALESCE(requests_used, 0) + $2,
			last_ip_address = NULLIF($3, '')::inet
		WHERE key_hash = $1
	`, keyHash, requestCount, ip)
	_, _ = a.db.Exec(ctx, `
		INSERT INTO api_key_usage_logs (api_key_id, product_slug, method_name, request_count, window_start, window_end, last_ip_address)
		VALUES ($1, NULLIF($2, ''), NULLIF($3, ''), $4, NOW(), NOW(), NULLIF($5, '')::inet)
	`, validated.APIKeyID, validated.ProductSlug, methodName, requestCount, ip)

	validated.RequestsUsed += requestCount
	return &validated, nil
}

// ListAPIKeys returns all API keys owned by userID, newest first, including
// revoked and expired ones.
func (a *Auth) ListAPIKeys(ctx context.Context, userID string) ([]APIKeyInfo, error) {
	rows, err := a.db.Query(ctx, `
		SELECT id, COALESCE(name, ''), tier, COALESCE(product_slug, ''), COALESCE(scopes, ARRAY[]::TEXT[]),
			COALESCE(monthly_quota, 0), COALESCE(requests_used, 0), approved, approved_at,
			COALESCE(rate_limit_per_second, 0), COALESCE(rate_limit_per_minute, 0),
			last_used_at, expires_at, revoked, created_at
		FROM api_keys
		WHERE user_id = $1
		ORDER BY created_at DESC
	`, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list API keys: %w", err)
	}
	defer rows.Close()

	keys := make([]APIKeyInfo, 0)
	for rows.Next() {
		var key APIKeyInfo
		if err := rows.Scan(
			&key.ID, &key.Name, &key.Tier, &key.ProductSlug, &key.Scopes,
			&key.MonthlyQuota, &key.RequestsUsed, &key.Approved, &key.ApprovedAt,
			&key.RateLimitPerSecond, &key.RateLimitPerMinute,
			&key.LastUsedAt, &key.ExpiresAt, &key.Revoked, &key.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan API key: %w", err)
		}
		keys = append(keys, key)
	}
	// Errors raised while streaming rows are reported here, not by Next.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to list API keys: %w", err)
	}
	return keys, nil
}

// usageLogSelect selects APIKeyUsageLog rows joined to their key. Its column
// order must match queryUsageLogs. Callers append WHERE/ORDER BY/LIMIT.
const usageLogSelect = `
	SELECT logs.id, logs.api_key_id, COALESCE(keys.name, ''),
		COALESCE(logs.product_slug, ''), COALESCE(logs.method_name, ''),
		logs.request_count,
		CASE WHEN logs.last_ip_address IS NOT NULL THEN host(logs.last_ip_address) ELSE NULL END,
		logs.created_at
	FROM api_key_usage_logs logs
	INNER JOIN api_keys keys ON keys.id = logs.api_key_id`

// queryUsageLogs runs a query built on usageLogSelect and collects every row.
// It returns an empty, non-nil slice when nothing matches.
func (a *Auth) queryUsageLogs(ctx context.Context, query string, args ...any) ([]APIKeyUsageLog, error) {
	rows, err := a.db.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	entries := make([]APIKeyUsageLog, 0)
	for rows.Next() {
		var entry APIKeyUsageLog
		if err := rows.Scan(
			&entry.ID, &entry.APIKeyID, &entry.KeyName, &entry.ProductSlug,
			&entry.MethodName, &entry.RequestCount, &entry.LastIP, &entry.CreatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan usage log: %w", err)
		}
		entries = append(entries, entry)
	}
	// Errors raised while streaming rows are reported here, not by Next.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return entries, nil
}

// ListUsageLogs returns usage log entries for keys owned by userID, newest
// first. At most limit entries are returned; limit defaults to 20 when <= 0.
func (a *Auth) ListUsageLogs(ctx context.Context, userID string, limit int) ([]APIKeyUsageLog, error) {
	if limit <= 0 {
		limit = 20
	}
	entries, err := a.queryUsageLogs(ctx, usageLogSelect+`
		WHERE keys.user_id = $1
		ORDER BY logs.created_at DESC
		LIMIT $2`, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to list usage logs: %w", err)
	}
	return entries, nil
}

// ListAllUsageLogs returns usage log entries across all users, newest first.
// Entries are filtered to productSlug when it is non-empty. At most limit
// entries are returned; limit defaults to 50 when <= 0.
func (a *Auth) ListAllUsageLogs(ctx context.Context, productSlug string, limit int) ([]APIKeyUsageLog, error) {
	if limit <= 0 {
		limit = 50
	}
	query := usageLogSelect
	var args []any
	if productSlug != "" {
		query += ` WHERE logs.product_slug = $1`
		args = append(args, productSlug)
	}
	// The LIMIT placeholder number depends on whether the filter was added.
	query += fmt.Sprintf(" ORDER BY logs.created_at DESC LIMIT $%d", len(args)+1)
	args = append(args, limit)

	entries, err := a.queryUsageLogs(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to list all usage logs: %w", err)
	}
	return entries, nil
}

// RevokeAPIKey marks key keyID as revoked, but only if it belongs to userID.
// It returns an "api key not found" error when no such key exists for that user.
func (a *Auth) RevokeAPIKey(ctx context.Context, userID, keyID string) error {
	tag, err := a.db.Exec(ctx, `UPDATE api_keys SET revoked = true WHERE id = $1 AND user_id = $2`, keyID, userID)
	if err != nil {
		return fmt.Errorf("failed to revoke API key: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return errors.New("api key not found")
	}
	return nil
}

// UpsertProductSubscription creates or replaces the subscription of userID to
// productSlug and returns the stored row.
//
// When status is "active", approved_at is set to NOW(), including on every
// repeated upsert. Otherwise a new row gets NULL and an existing row keeps
// its value. On conflict, approved_by and notes are overwritten, becoming
// NULL when the arguments are empty.
func (a *Auth) UpsertProductSubscription(
	ctx context.Context,
	userID, productSlug, tier, status string,
	monthlyQuota int,
	requiresApproval bool,
	approvedBy string,
	notes string,
) (*ProductSubscription, error) {
	const query = `
		INSERT INTO user_product_subscriptions (
			user_id, product_slug, tier, status, monthly_quota, requires_approval,
			approved_at, approved_by, notes
		)
		VALUES ($1, $2, $3, $4, $5, $6,
			CASE WHEN $4 = 'active' THEN NOW() ELSE NULL END,
			NULLIF($7, ''), NULLIF($8, ''))
		ON CONFLICT (user_id, product_slug) DO UPDATE SET
			tier = EXCLUDED.tier,
			status = EXCLUDED.status,
			monthly_quota = EXCLUDED.monthly_quota,
			requires_approval = EXCLUDED.requires_approval,
			approved_at = CASE WHEN EXCLUDED.status = 'active' THEN NOW() ELSE user_product_subscriptions.approved_at END,
			approved_by = NULLIF(EXCLUDED.approved_by, ''),
			notes = NULLIF(EXCLUDED.notes, ''),
			updated_at = NOW()
		RETURNING` + subscriptionColumns

	row := a.db.QueryRow(ctx, query,
		userID, productSlug, tier, status, monthlyQuota, requiresApproval, approvedBy, notes,
	)
	sub, err := scanSubscription(row.Scan)
	if err != nil {
		return nil, fmt.Errorf("failed to save subscription: %w", err)
	}
	return sub, nil
}

// ListSubscriptions returns all subscriptions of userID, newest first.
func (a *Auth) ListSubscriptions(ctx context.Context, userID string) ([]ProductSubscription, error) {
	const query = `SELECT` + subscriptionColumns + `
		FROM user_product_subscriptions
		WHERE user_id = $1
		ORDER BY created_at DESC`

	subs, err := a.querySubscriptions(ctx, query, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list subscriptions: %w", err)
	}
	return subs, nil
}