package auth

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/ethereum/go-ethereum/accounts"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/golang-jwt/jwt/v4"
	"github.com/jackc/pgx/v5/pgxpool"
)

// Sentinel errors returned by WalletAuth. Compare with errors.Is.
var (
	// ErrWalletAuthStorageNotInitialized means the wallet_nonces table does
	// not exist yet.
	ErrWalletAuthStorageNotInitialized = errors.New("wallet authentication storage is not initialized; run migration 0010_track_schema")

	// ErrWalletNonceNotFoundOrExpired means no usable nonce is stored for
	// the wallet.
	ErrWalletNonceNotFoundOrExpired = errors.New("nonce not found or expired")

	// ErrWalletNonceExpired means the stored nonce is past its expiry time.
	ErrWalletNonceExpired = errors.New("nonce expired")

	// ErrWalletNonceInvalid means the submitted nonce differs from the
	// stored one.
	ErrWalletNonceInvalid = errors.New("invalid nonce")

	// ErrJWTRevoked means the token's jti is listed in jwt_revocations.
	ErrJWTRevoked = errors.New("token has been revoked")

	// ErrJWTRevocationStorageMissing means the jwt_revocations table does
	// not exist yet.
	ErrJWTRevocationStorageMissing = errors.New("jwt_revocations table missing; run migration 0016_jwt_revocations")
)

// tokenTTLs maps each track to its maximum JWT lifetime. Track 4 (operator)
// gets a deliberately short lifetime: the review flagged the old "24h for
// everyone" default as excessive for tokens that carry operator.write.*
// permissions. Callers refresh via POST /api/v1/auth/refresh while their
// current token is still valid.
var tokenTTLs = map[int]time.Duration{
	1: 12 * time.Hour,
	2: 8 * time.Hour,
	3: 4 * time.Hour,
	4: 60 * time.Minute,
}

// defaultTokenTTL is used for any track not explicitly listed in tokenTTLs.
const defaultTokenTTL = 12 * time.Hour

// tokenTTLFor returns the JWT lifetime for tokens issued to the given track,
// falling back to defaultTokenTTL for tracks that have no entry in
// tokenTTLs.
func tokenTTLFor(track int) time.Duration {
	if ttl, ok := tokenTTLs[track]; ok {
		return ttl
	}
	return defaultTokenTTL
}

// isMissingJWTRevocationTableError reports whether err is Postgres's
// "relation does not exist" error for the jwt_revocations table.
// NOTE(review): matching on the message text depends on the server's
// wording/locale; checking SQLSTATE 42P01 on a *pgconn.PgError would be
// sturdier.
func isMissingJWTRevocationTableError(err error) bool {
	return err != nil && strings.Contains(err.Error(), `relation "jwt_revocations" does not exist`)
}

// newJTI returns a random JWT ID used for revocation tracking. 16 random
// bytes = 128 bits of entropy, hex-encoded.
func newJTI() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("generate jti: %w", err) } return hex.EncodeToString(b), nil } // WalletAuth handles wallet-based authentication type WalletAuth struct { db *pgxpool.Pool jwtSecret []byte } // NewWalletAuth creates a new wallet auth handler func NewWalletAuth(db *pgxpool.Pool, jwtSecret []byte) *WalletAuth { return &WalletAuth{ db: db, jwtSecret: jwtSecret, } } func isMissingWalletNonceTableError(err error) bool { return err != nil && strings.Contains(err.Error(), `relation "wallet_nonces" does not exist`) } // NonceRequest represents a nonce request type NonceRequest struct { Address string `json:"address"` } // NonceResponse represents a nonce response type NonceResponse struct { Nonce string `json:"nonce"` ExpiresAt time.Time `json:"expires_at"` } // WalletAuthRequest represents a wallet authentication request type WalletAuthRequest struct { Address string `json:"address"` Signature string `json:"signature"` Nonce string `json:"nonce"` } // WalletAuthResponse represents a wallet authentication response type WalletAuthResponse struct { Token string `json:"token"` ExpiresAt time.Time `json:"expires_at"` Track int `json:"track"` Permissions []string `json:"permissions"` } // GenerateNonce generates a random nonce for wallet authentication func (w *WalletAuth) GenerateNonce(ctx context.Context, address string) (*NonceResponse, error) { // Validate address format if !common.IsHexAddress(address) { return nil, fmt.Errorf("invalid address format") } // Normalize address to checksum format addr := common.HexToAddress(address) normalizedAddr := addr.Hex() // Generate random nonce nonceBytes := make([]byte, 32) if _, err := rand.Read(nonceBytes); err != nil { return nil, fmt.Errorf("failed to generate nonce: %w", err) } nonce := hex.EncodeToString(nonceBytes) // Store nonce in database with expiration (5 minutes) expiresAt := time.Now().Add(5 * time.Minute) query := ` 
INSERT INTO wallet_nonces (address, nonce, expires_at) VALUES ($1, $2, $3) ON CONFLICT (address) DO UPDATE SET nonce = EXCLUDED.nonce, expires_at = EXCLUDED.expires_at, created_at = NOW() ` _, err := w.db.Exec(ctx, query, normalizedAddr, nonce, expiresAt) if err != nil { if isMissingWalletNonceTableError(err) { return nil, ErrWalletAuthStorageNotInitialized } return nil, fmt.Errorf("failed to store nonce: %w", err) } return &NonceResponse{ Nonce: nonce, ExpiresAt: expiresAt, }, nil } // AuthenticateWallet authenticates a wallet using signature func (w *WalletAuth) AuthenticateWallet(ctx context.Context, req *WalletAuthRequest) (*WalletAuthResponse, error) { // Validate address format if !common.IsHexAddress(req.Address) { return nil, fmt.Errorf("invalid address format") } // Normalize address addr := common.HexToAddress(req.Address) normalizedAddr := addr.Hex() // Verify nonce var storedNonce string var expiresAt time.Time query := `SELECT nonce, expires_at FROM wallet_nonces WHERE address = $1` err := w.db.QueryRow(ctx, query, normalizedAddr).Scan(&storedNonce, &expiresAt) if err != nil { if isMissingWalletNonceTableError(err) { return nil, ErrWalletAuthStorageNotInitialized } return nil, ErrWalletNonceNotFoundOrExpired } if time.Now().After(expiresAt) { return nil, ErrWalletNonceExpired } if storedNonce != req.Nonce { return nil, ErrWalletNonceInvalid } // Verify signature message := fmt.Sprintf("Sign this message to authenticate with SolaceScan.\n\nNonce: %s", req.Nonce) messageHash := accounts.TextHash([]byte(message)) sigBytes, err := decodeWalletSignature(req.Signature) if err != nil { return nil, fmt.Errorf("invalid signature format: %w", err) } // Recover public key from signature if sigBytes[64] >= 27 { sigBytes[64] -= 27 } pubKey, err := crypto.SigToPub(messageHash, sigBytes) if err != nil { return nil, fmt.Errorf("failed to recover public key: %w", err) } recoveredAddr := crypto.PubkeyToAddress(*pubKey) if recoveredAddr.Hex() != normalizedAddr { return 
nil, fmt.Errorf("signature does not match address") } // Get or create user and track level track, err := w.getUserTrack(ctx, normalizedAddr) if err != nil { return nil, fmt.Errorf("failed to get user track: %w", err) } // Generate JWT token token, expiresAt, err := w.generateJWT(normalizedAddr, track) if err != nil { return nil, fmt.Errorf("failed to generate token: %w", err) } // Delete used nonce w.db.Exec(ctx, `DELETE FROM wallet_nonces WHERE address = $1`, normalizedAddr) // Get permissions for track permissions := getPermissionsForTrack(track) return &WalletAuthResponse{ Token: token, ExpiresAt: expiresAt, Track: track, Permissions: permissions, }, nil } // getUserTrack gets the track level for a user address func (w *WalletAuth) getUserTrack(ctx context.Context, address string) (int, error) { // Check if user exists in operator_roles (Track 4) var track int var approved bool query := `SELECT track_level, approved FROM operator_roles WHERE address = $1` err := w.db.QueryRow(ctx, query, address).Scan(&track, &approved) if err == nil && approved { return track, nil } // Check if user is approved for Track 2 or 3 // For now, default to Track 1 (public) // In production, you'd have an approval table return 1, nil } // generateJWT generates a JWT token with track, jti, exp, and iat claims. // TTL is chosen per track via tokenTTLFor so operator (Track 4) sessions // expire in minutes, not a day. 
func (w *WalletAuth) generateJWT(address string, track int) (string, time.Time, error) { jti, err := newJTI() if err != nil { return "", time.Time{}, err } expiresAt := time.Now().Add(tokenTTLFor(track)) claims := jwt.MapClaims{ "address": address, "track": track, "jti": jti, "exp": expiresAt.Unix(), "iat": time.Now().Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(w.jwtSecret) if err != nil { return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err) } return tokenString, expiresAt, nil } // ValidateJWT validates a JWT token and returns the address and track. // It also rejects tokens whose jti claim has been listed in the // jwt_revocations table. func (w *WalletAuth) ValidateJWT(tokenString string) (string, int, error) { address, track, _, _, err := w.parseJWT(tokenString) if err != nil { return "", 0, err } // If we have a database, enforce revocation and re-resolve the track // (an operator revoking a wallet's Track 4 approval should not wait // for the token to expire before losing the elevated permission). if w.db != nil { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() jti, _ := w.jtiFromToken(tokenString) if jti != "" { revoked, revErr := w.isJTIRevoked(ctx, jti) if revErr != nil && !errors.Is(revErr, ErrJWTRevocationStorageMissing) { return "", 0, fmt.Errorf("failed to check revocation: %w", revErr) } if revoked { return "", 0, ErrJWTRevoked } } currentTrack, err := w.getUserTrack(ctx, address) if err != nil { return "", 0, fmt.Errorf("failed to resolve current track: %w", err) } if currentTrack < track { track = currentTrack } } return address, track, nil } // parseJWT performs signature verification and claim extraction without // any database round-trip. Shared between ValidateJWT and RefreshJWT. 
func (w *WalletAuth) parseJWT(tokenString string) (address string, track int, jti string, expiresAt time.Time, err error) { token, perr := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return w.jwtSecret, nil }) if perr != nil { return "", 0, "", time.Time{}, fmt.Errorf("failed to parse token: %w", perr) } if !token.Valid { return "", 0, "", time.Time{}, fmt.Errorf("invalid token") } claims, ok := token.Claims.(jwt.MapClaims) if !ok { return "", 0, "", time.Time{}, fmt.Errorf("invalid token claims") } address, ok = claims["address"].(string) if !ok { return "", 0, "", time.Time{}, fmt.Errorf("address not found in token") } trackFloat, ok := claims["track"].(float64) if !ok { return "", 0, "", time.Time{}, fmt.Errorf("track not found in token") } track = int(trackFloat) if v, ok := claims["jti"].(string); ok { jti = v } if expFloat, ok := claims["exp"].(float64); ok { expiresAt = time.Unix(int64(expFloat), 0) } return address, track, jti, expiresAt, nil } // jtiFromToken parses the jti claim without doing a fresh signature check. // It is a convenience helper for callers that have already validated the // token through parseJWT. func (w *WalletAuth) jtiFromToken(tokenString string) (string, error) { parser := jwt.Parser{} token, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{}) if err != nil { return "", err } claims, ok := token.Claims.(jwt.MapClaims) if !ok { return "", fmt.Errorf("invalid claims") } v, _ := claims["jti"].(string) return v, nil } // isJTIRevoked checks whether the given jti appears in jwt_revocations. // Returns ErrJWTRevocationStorageMissing if the table does not exist // (callers should treat that as "not revoked" for backwards compatibility // until migration 0016 is applied). 
func (w *WalletAuth) isJTIRevoked(ctx context.Context, jti string) (bool, error) { var exists bool err := w.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM jwt_revocations WHERE jti = $1)`, jti, ).Scan(&exists) if err != nil { if isMissingJWTRevocationTableError(err) { return false, ErrJWTRevocationStorageMissing } return false, err } return exists, nil } // RevokeJWT records the token's jti in jwt_revocations. Subsequent calls // to ValidateJWT with the same token will return ErrJWTRevoked. Idempotent // on duplicate jti. func (w *WalletAuth) RevokeJWT(ctx context.Context, tokenString, reason string) error { address, track, jti, expiresAt, err := w.parseJWT(tokenString) if err != nil { return err } if jti == "" { // Legacy tokens issued before PR #8 don't carry a jti; there is // nothing to revoke server-side. Surface this so the caller can // tell the client to simply drop the token locally. return fmt.Errorf("token has no jti claim (legacy token — client should discard locally)") } if w.db == nil { return fmt.Errorf("wallet auth has no database; cannot revoke") } if strings.TrimSpace(reason) == "" { reason = "logout" } _, err = w.db.Exec(ctx, `INSERT INTO jwt_revocations (jti, address, track, token_expires_at, reason) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (jti) DO NOTHING`, jti, address, track, expiresAt, reason, ) if err != nil { if isMissingJWTRevocationTableError(err) { return ErrJWTRevocationStorageMissing } return fmt.Errorf("record revocation: %w", err) } return nil } // RefreshJWT issues a new token for the same address+track if the current // token is valid (signed, unexpired, not revoked) and revokes the current // token so it cannot be replayed. Returns the new token and its exp. func (w *WalletAuth) RefreshJWT(ctx context.Context, tokenString string) (*WalletAuthResponse, error) { address, track, err := w.ValidateJWT(tokenString) if err != nil { return nil, err } // Revoke the old token before issuing a new one. 
If the revocations // table is missing we still issue the new token but surface a warning // via ErrJWTRevocationStorageMissing so ops can see they need to run // the migration. var revokeErr error if w.db != nil { revokeErr = w.RevokeJWT(ctx, tokenString, "refresh") if revokeErr != nil && !errors.Is(revokeErr, ErrJWTRevocationStorageMissing) { return nil, revokeErr } } newToken, expiresAt, err := w.generateJWT(address, track) if err != nil { return nil, err } return &WalletAuthResponse{ Token: newToken, ExpiresAt: expiresAt, Track: track, Permissions: getPermissionsForTrack(track), }, revokeErr } func decodeWalletSignature(signature string) ([]byte, error) { if len(signature) < 2 || !strings.EqualFold(signature[:2], "0x") { return nil, fmt.Errorf("signature must start with 0x") } raw := signature[2:] if len(raw) != 130 { return nil, fmt.Errorf("invalid signature length") } sigBytes, err := hex.DecodeString(raw) if err != nil { return nil, err } if len(sigBytes) != 65 { return nil, fmt.Errorf("invalid signature length") } return sigBytes, nil } // getPermissionsForTrack returns permissions for a track level func getPermissionsForTrack(track int) []string { permissions := []string{ "explorer.read.blocks", "explorer.read.transactions", "explorer.read.address.basic", "explorer.read.bridge.status", "weth.wrap", "weth.unwrap", } if track >= 2 { permissions = append(permissions, "explorer.read.address.full", "explorer.read.tokens", "explorer.read.tx_history", "explorer.read.internal_txs", "explorer.search.enhanced", ) } if track >= 3 { permissions = append(permissions, "analytics.read.flows", "analytics.read.bridge", "analytics.read.token_distribution", "analytics.read.address_risk", ) } if track >= 4 { permissions = append(permissions, "operator.read.bridge_events", "operator.read.validators", "operator.read.contracts", "operator.read.protocol_state", "operator.write.bridge_control", ) } return permissions }