Files

183 lines
4.7 KiB
Go

package auth
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// RoleManager provides operator role-based access control: track assignment
// and approval, per-operator IP whitelisting, and audit-event logging. All of
// its state lives in PostgreSQL and is reached through the pool held in db.
type RoleManager struct {
	db *pgxpool.Pool // connection pool used for every query issued by the methods
}
// NewRoleManager returns a RoleManager whose methods run their queries against
// db. The pool is stored as given; it is not pinged or validated here.
func NewRoleManager(db *pgxpool.Pool) *RoleManager {
	manager := RoleManager{db: db}
	return &manager
}
// UserRole is an operator's access record: the track level they are assigned
// to, any named roles, and whether, when and by whom that assignment was
// approved.
type UserRole struct {
	Address string   // operator address this record belongs to
	Track   int      // access track level (AssignTrack only accepts 1-4)
	Roles   []string // named roles read from the roles column

	Approved   bool      // whether the track assignment is currently approved
	ApprovedBy string    // identifier of whoever approved the assignment
	ApprovedAt time.Time // when the assignment was approved
}
// AssignTrack upserts the operator_roles row for address with the given track
// level (1-4) and marks it approved by approvedBy.
//
// A missing row is inserted. An existing row has its track_level, approved,
// approved_by and approved_at overwritten, and updated_at set to the
// database's NOW(). approved is always written as true, so this also
// re-approves a previously revoked user. approved_at comes from the
// application server's clock (time.Now), not the database's.
//
// An out-of-range track is rejected before any query is issued.
func (r *RoleManager) AssignTrack(ctx context.Context, address string, track int, approvedBy string) error {
	const minTrack, maxTrack = 1, 4
	if track < minTrack || track > maxTrack {
		return fmt.Errorf("invalid track level: %d (must be 1-4)", track)
	}

	const upsert = `
INSERT INTO operator_roles (address, track_level, approved, approved_by, approved_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (address) DO UPDATE SET
track_level = EXCLUDED.track_level,
approved = EXCLUDED.approved,
approved_by = EXCLUDED.approved_by,
approved_at = EXCLUDED.approved_at,
updated_at = NOW()
`
	approvedAt := time.Now()
	if _, err := r.db.Exec(ctx, upsert, address, track, true, approvedBy, approvedAt); err != nil {
		return fmt.Errorf("failed to assign track: %w", err)
	}
	return nil
}
// GetUserRole returns the role and track record stored for address.
//
// If operator_roles has no row for address, it returns a default, unapproved
// Track 1 role with empty Roles and a nil error. Any other failure (query
// error, connection loss, context cancellation, scan error) is returned to the
// caller rather than being masked as an unknown user.
//
// approved_by and approved_at are scanned through pointers because they may
// be NULL (RevokeUser sets approved_at to NULL). A NULL maps to the zero string
// or zero time.Time. A NULL roles array is normalized to an empty slice, to
// match the not-found default.
func (r *RoleManager) GetUserRole(ctx context.Context, address string) (*UserRole, error) {
	query := `
SELECT address, track_level, roles, approved, approved_by, approved_at
FROM operator_roles
WHERE address = $1
`
	// Query (rather than QueryRow) lets "no matching row" be told apart from a
	// genuine failure via rows.Next/rows.Err, without a driver sentinel error.
	rows, err := r.db.Query(ctx, query, address)
	if err != nil {
		return nil, fmt.Errorf("failed to query user role: %w", err)
	}
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both for "no rows" and for errors; Err tells which.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to query user role: %w", err)
		}
		// User not found: fall back to the lowest-privilege default.
		return &UserRole{
			Address:  address,
			Track:    1,
			Roles:    []string{},
			Approved: false,
		}, nil
	}

	var (
		role       UserRole
		approvedBy *string
		approvedAt *time.Time
	)
	if err := rows.Scan(
		&role.Address,
		&role.Track,
		&role.Roles,
		&role.Approved,
		&approvedBy,
		&approvedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to scan user role: %w", err)
	}

	if approvedBy != nil {
		role.ApprovedBy = *approvedBy
	}
	if approvedAt != nil {
		role.ApprovedAt = *approvedAt
	}
	if role.Roles == nil {
		role.Roles = []string{}
	}
	return &role, nil
}
// ApproveUser marks the existing operator_roles row for address as approved.
// It records approvedBy, and sets approved_at and updated_at to the database's
// NOW(). The track level is left unchanged.
//
// It returns an error wrapping the database failure if the update fails, or a
// "user not found" error if no row matched address.
func (r *RoleManager) ApproveUser(ctx context.Context, address string, approvedBy string) error {
	const approve = `
UPDATE operator_roles
SET approved = TRUE,
approved_by = $2,
approved_at = NOW(),
updated_at = NOW()
WHERE address = $1
`
	tag, err := r.db.Exec(ctx, approve, address, approvedBy)
	switch {
	case err != nil:
		return fmt.Errorf("failed to approve user: %w", err)
	case tag.RowsAffected() == 0:
		return fmt.Errorf("user not found")
	default:
		return nil
	}
}
// RevokeUser withdraws approval for address. It sets approved to FALSE,
// clears approved_at to NULL and bumps updated_at. approved_by and
// track_level are left as they were.
//
// It returns an error wrapping the database failure if the update fails, or a
// "user not found" error if no row matched address.
func (r *RoleManager) RevokeUser(ctx context.Context, address string) error {
	const revoke = `
UPDATE operator_roles
SET approved = FALSE,
approved_at = NULL,
updated_at = NOW()
WHERE address = $1
`
	tag, execErr := r.db.Exec(ctx, revoke, address)
	if execErr != nil {
		return fmt.Errorf("failed to revoke user: %w", execErr)
	}
	if affected := tag.RowsAffected(); affected == 0 {
		return fmt.Errorf("user not found")
	}
	return nil
}
// AddIPWhitelist records ipAddress as an allowed address for operatorAddress.
// Adding a pair that already exists does not create a duplicate row. Instead,
// it overwrites the stored description with the new one.
//
// ipAddress is passed to the database unvalidated.
// NOTE(review): format/normalization depends on the ip_address column type.
func (r *RoleManager) AddIPWhitelist(ctx context.Context, operatorAddress string, ipAddress string, description string) error {
	const upsert = `
INSERT INTO operator_ip_whitelist (operator_address, ip_address, description)
VALUES ($1, $2, $3)
ON CONFLICT (operator_address, ip_address) DO UPDATE SET
description = EXCLUDED.description
`
	if _, err := r.db.Exec(ctx, upsert, operatorAddress, ipAddress, description); err != nil {
		return fmt.Errorf("failed to add IP to whitelist: %w", err)
	}
	return nil
}
// IsIPWhitelisted reports whether operator_ip_whitelist holds at least one row
// pairing operatorAddress with ipAddress. The comparison is an exact equality
// match in SQL. On a query or scan failure it returns false and a wrapped
// error.
func (r *RoleManager) IsIPWhitelisted(ctx context.Context, operatorAddress string, ipAddress string) (bool, error) {
	const countMatches = `
SELECT COUNT(*)
FROM operator_ip_whitelist
WHERE operator_address = $1 AND ip_address = $2
`
	var matches int
	row := r.db.QueryRow(ctx, countMatches, operatorAddress, ipAddress)
	if err := row.Scan(&matches); err != nil {
		return false, fmt.Errorf("failed to check IP whitelist: %w", err)
	}
	whitelisted := matches > 0
	return whitelisted, nil
}
// LogOperatorEvent appends one row to operator_events for audit purposes.
//
// chainID may be nil, in which case pgx writes chain_id as NULL. details is
// handed to the driver unchanged; pgx encodes the map itself when the target
// column is json/jsonb, so no conversion step is needed here.
// NOTE(review): assumes details is a json/jsonb column; confirm against the schema.
//
// It returns a wrapped error if the insert fails.
func (r *RoleManager) LogOperatorEvent(ctx context.Context, eventType string, chainID *int, operatorAddress string, targetResource string, action string, details map[string]any, ipAddress string, userAgent string) error {
	query := `
INSERT INTO operator_events (event_type, chain_id, operator_address, target_resource, action, details, ip_address, user_agent)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
	// details is passed directly: the previous map[string]interface{}(details)
	// conversion was a no-op on a value that already had that type.
	_, err := r.db.Exec(ctx, query, eventType, chainID, operatorAddress, targetResource, action, details, ipAddress, userAgent)
	if err != nil {
		return fmt.Errorf("failed to log operator event: %w", err)
	}
	return nil
}