// Package auth implements user registration, password authentication, and
// API key issuance/validation backed by a PostgreSQL connection pool.
package auth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"golang.org/x/crypto/bcrypt"
)

// Sentinel errors returned by Auth methods; match them with errors.Is.
var (
	// ErrInvalidCredentials is returned by AuthenticateUser when the user
	// lookup fails or the password does not match the stored hash.
	ErrInvalidCredentials = errors.New("invalid credentials")
	// ErrInvalidAPIKey is returned by ValidateAPIKey when no stored key
	// matches (or the lookup fails).
	ErrInvalidAPIKey = errors.New("invalid API key")
	// ErrAPIKeyRevoked is returned by ValidateAPIKey when the matching key
	// row has revoked set.
	ErrAPIKeyRevoked = errors.New("API key revoked")
)

// apiKeyPrefix is prepended to every generated API key.
const apiKeyPrefix = "ek_"

// Auth handles user authentication against the users and api_keys tables.
type Auth struct {
	db *pgxpool.Pool
}

// NewAuth creates a new auth handler that issues its queries through db.
func NewAuth(db *pgxpool.Pool) *Auth {
	return &Auth{db: db}
}

// User represents a user row as returned by the users table (the password
// hash is never exposed).
type User struct {
	ID        string
	Email     string
	Username  string
	CreatedAt time.Time
}

// RegisterUser hashes password with bcrypt (DefaultCost), inserts a new row
// into users, and returns the user as reported back by the database's
// RETURNING clause (including the id and created_at it assigned).
func (a *Auth) RegisterUser(ctx context.Context, email, username, password string) (*User, error) {
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("failed to hash password: %w", err)
	}

	const query = `
		INSERT INTO users (email, username, password_hash)
		VALUES ($1, $2, $3)
		RETURNING id, email, username, created_at
	`

	var user User
	err = a.db.QueryRow(ctx, query, email, username, hashedPassword).Scan(
		&user.ID,
		&user.Email,
		&user.Username,
		&user.CreatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create user: %w", err)
	}

	return &user, nil
}

// AuthenticateUser looks up the user by email and verifies password against
// the stored bcrypt hash. On any failure it returns ErrInvalidCredentials,
// without revealing whether the email exists.
func (a *Auth) AuthenticateUser(ctx context.Context, email, password string) (*User, error) {
	var user User
	var passwordHash string

	const query = `SELECT id, email, username, password_hash, created_at FROM users WHERE email = $1`
	err := a.db.QueryRow(ctx, query, email).Scan(
		&user.ID,
		&user.Email,
		&user.Username,
		&passwordHash,
		&user.CreatedAt,
	)
	if err != nil {
		// Spend bcrypt time on the miss path too, so response latency does not
		// reveal which emails are registered.
		// NOTE(review): every lookup error, including connectivity failures, is
		// reported as invalid credentials; consider distinguishing pgx.ErrNoRows.
		compareWithDummyHash(password)
		return nil, ErrInvalidCredentials
	}

	if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil {
		return nil, ErrInvalidCredentials
	}

	return &user, nil
}

var (
	dummyHashOnce sync.Once
	dummyHash     []byte
)

// compareWithDummyHash runs a bcrypt comparison against a throwaway hash so
// that a failed user lookup costs roughly as much as a real password check.
// The dummy hash is built once, on first use.
func compareWithDummyHash(password string) {
	dummyHashOnce.Do(func() {
		// Error deliberately ignored: if generation fails, dummyHash stays nil,
		// the comparison below returns early, and authentication still fails.
		dummyHash, _ = bcrypt.GenerateFromPassword([]byte("timing-equalizer"), bcrypt.DefaultCost)
	})
	// The result is irrelevant; only the time spent matters.
	_ = bcrypt.CompareHashAndPassword(dummyHash, []byte(password))
}

// GenerateAPIKey creates a random API key for userID, stores only its
// SHA-256 digest together with name, tier and the tier's rate limits, and
// returns the plaintext key. The plaintext is not persisted, so this is the
// only time it is available.
//
// Recognised tiers are "free", "pro" and "enterprise"; any other value is
// stored as "free" with free-tier limits.
func (a *Auth) GenerateAPIKey(ctx context.Context, userID, name string, tier string) (string, error) {
	keyBytes := make([]byte, 32)
	if _, err := rand.Read(keyBytes); err != nil {
		return "", fmt.Errorf("failed to generate key: %w", err)
	}
	apiKey := apiKeyPrefix + hex.EncodeToString(keyBytes)

	// Normalize the tier so the stored tier always agrees with the limits.
	storedTier, perSecond, perMinute := tierRateLimits(tier)

	const query = `
		INSERT INTO api_keys (user_id, key_hash, name, tier, rate_limit_per_second, rate_limit_per_minute)
		VALUES ($1, $2, $3, $4, $5, $6)
	`
	if _, err := a.db.Exec(ctx, query, userID, hashAPIKey(apiKey), name, storedTier, perSecond, perMinute); err != nil {
		return "", fmt.Errorf("failed to store API key: %w", err)
	}

	return apiKey, nil
}

// tierRateLimits maps a requested tier to the tier name to persist and its
// per-second and per-minute rate limits. Unknown tiers fall back to "free".
func tierRateLimits(tier string) (storedTier string, perSecond, perMinute int) {
	switch tier {
	case "pro":
		return "pro", 20, 1000
	case "enterprise":
		return "enterprise", 100, 10000
	default:
		return "free", 5, 100
	}
}

// hashAPIKey returns the hex-encoded SHA-256 digest of key, which is the
// form in which API keys are stored and looked up.
func hashAPIKey(key string) string {
	sum := sha256.Sum256([]byte(key))
	return hex.EncodeToString(sum[:])
}

// ValidateAPIKey looks up apiKey by its SHA-256 digest and returns the
// owning user ID. It returns ErrInvalidAPIKey if no key matches and
// ErrAPIKeyRevoked if the key has been revoked. On success it also records
// last_used_at on a best-effort basis.
func (a *Auth) ValidateAPIKey(ctx context.Context, apiKey string) (string, error) {
	keyHash := hashAPIKey(apiKey)

	var userID string
	var revoked bool

	const query = `SELECT user_id, revoked FROM api_keys WHERE key_hash = $1`
	if err := a.db.QueryRow(ctx, query, keyHash).Scan(&userID, &revoked); err != nil {
		// NOTE(review): lookup errors other than "no rows" (e.g. connectivity
		// failures) are also reported as an invalid key.
		return "", ErrInvalidAPIKey
	}

	if revoked {
		return "", ErrAPIKeyRevoked
	}

	// Recording usage is best-effort bookkeeping: a failure here should not
	// reject an otherwise valid key, so the error is deliberately discarded.
	_, _ = a.db.Exec(ctx, `UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1`, keyHash)

	return userID, nil
}