package storage import ( "context" "errors" "net/netip" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/alchemistkay/guestguard/internal/domain" ) // EmailVerificationRepo manages single-use email verification tokens. type EmailVerificationRepo struct { pool *pgxpool.Pool } func NewEmailVerificationRepo(db *DB) *EmailVerificationRepo { return &EmailVerificationRepo{pool: db.Pool} } func (r *EmailVerificationRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error { _, err := r.pool.Exec(ctx, ` INSERT INTO email_verification_tokens (token_hash, user_id, expires_at) VALUES ($1, $2, $3) `, hash, userID, expiresAt) return err } // Consume atomically marks the token as used and returns the owning user_id. // Returns ErrAuthTokenNotFound / ErrAuthTokenConsumed / ErrAuthTokenExpired. func (r *EmailVerificationRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) { const q = ` UPDATE email_verification_tokens SET consumed_at = now() WHERE token_hash = $1 AND consumed_at IS NULL AND expires_at > now() RETURNING user_id ` var uid uuid.UUID if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil { if errors.Is(err, pgx.ErrNoRows) { return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool, "SELECT consumed_at, expires_at FROM email_verification_tokens WHERE token_hash=$1", hash) } return uuid.Nil, err } return uid, nil } // PasswordResetRepo manages single-use password-reset tokens. type PasswordResetRepo struct { pool *pgxpool.Pool } func NewPasswordResetRepo(db *DB) *PasswordResetRepo { return &PasswordResetRepo{pool: db.Pool} } func (r *PasswordResetRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error { _, err := r.pool.Exec(ctx, ` INSERT INTO password_reset_tokens (token_hash, user_id, expires_at) VALUES ($1, $2, $3) `, hash, userID, expiresAt) return err } func (r *PasswordResetRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) { const q = ` UPDATE password_reset_tokens SET consumed_at = now() WHERE token_hash = $1 AND consumed_at IS NULL AND expires_at > now() RETURNING user_id ` var uid uuid.UUID if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil { if errors.Is(err, pgx.ErrNoRows) { return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool, "SELECT consumed_at, expires_at FROM password_reset_tokens WHERE token_hash=$1", hash) } return uuid.Nil, err } return uid, nil } // RefreshTokenRepo manages refresh-token rows. Refresh tokens are rotated: // every refresh issues a new token and revokes the old one, recording the // chain in `replaced_by` so we can detect replay (a revoked token being // presented again triggers a family-wide revocation). type RefreshTokenRepo struct { pool *pgxpool.Pool } func NewRefreshTokenRepo(db *DB) *RefreshTokenRepo { return &RefreshTokenRepo{pool: db.Pool} } type RefreshToken struct { Hash string UserID uuid.UUID ExpiresAt time.Time RevokedAt *time.Time ReplacedBy *string UserAgent string IPAddress *netip.Addr CreatedAt time.Time } type CreateRefreshTokenParams struct { Hash string UserID uuid.UUID ExpiresAt time.Time UserAgent string IPAddress string } func (r *RefreshTokenRepo) Create(ctx context.Context, p CreateRefreshTokenParams) error { ip := parseIP(p.IPAddress) _, err := r.pool.Exec(ctx, ` INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address) VALUES ($1, $2, $3, NULLIF($4, ''), $5) `, p.Hash, p.UserID, p.ExpiresAt, p.UserAgent, ip) return err } func (r *RefreshTokenRepo) Get(ctx context.Context, hash string) (*RefreshToken, error) { const q = ` SELECT token_hash, user_id, expires_at, revoked_at, replaced_by, COALESCE(user_agent, ''), host(ip_address), created_at FROM refresh_tokens WHERE token_hash = $1 ` var rt RefreshToken var ipText *string if err := r.pool.QueryRow(ctx, q, hash).Scan( &rt.Hash, &rt.UserID, &rt.ExpiresAt, &rt.RevokedAt, &rt.ReplacedBy, &rt.UserAgent, &ipText, &rt.CreatedAt, ); err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, domain.ErrAuthTokenNotFound } return nil, err } if ipText != nil && *ipText != "" { if addr, err := netip.ParseAddr(*ipText); err == nil { rt.IPAddress = &addr } } return &rt, nil } // Rotate atomically (in a transaction) marks the old token revoked and // inserts the new one with replaced_by set. Returns ErrAuthTokenNotFound or // ErrRefreshTokenRevoked if the old token is missing or already revoked. func (r *RefreshTokenRepo) Rotate(ctx context.Context, oldHash string, next CreateRefreshTokenParams) error { tx, err := r.pool.Begin(ctx) if err != nil { return err } defer tx.Rollback(ctx) var revokedAt *time.Time var userID uuid.UUID var expiresAt time.Time err = tx.QueryRow(ctx, ` SELECT user_id, expires_at, revoked_at FROM refresh_tokens WHERE token_hash = $1 FOR UPDATE `, oldHash).Scan(&userID, &expiresAt, &revokedAt) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return domain.ErrAuthTokenNotFound } return err } if revokedAt != nil { // Replay of a revoked refresh token — revoke the entire family. if _, err := tx.Exec(ctx, ` UPDATE refresh_tokens SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL `, userID); err != nil { return err } if err := tx.Commit(ctx); err != nil { return err } return domain.ErrRefreshTokenRevoked } if time.Now().After(expiresAt) { return domain.ErrAuthTokenExpired } if next.UserID != userID { return errors.New("refresh token user mismatch") } ip := parseIP(next.IPAddress) if _, err := tx.Exec(ctx, ` INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address) VALUES ($1, $2, $3, NULLIF($4, ''), $5) `, next.Hash, next.UserID, next.ExpiresAt, next.UserAgent, ip); err != nil { return err } if _, err := tx.Exec(ctx, ` UPDATE refresh_tokens SET revoked_at = now(), replaced_by = $2 WHERE token_hash = $1 `, oldHash, next.Hash); err != nil { return err } return tx.Commit(ctx) } func (r *RefreshTokenRepo) Revoke(ctx context.Context, hash string) error { tag, err := r.pool.Exec(ctx, ` UPDATE refresh_tokens SET revoked_at = now() WHERE token_hash = $1 AND revoked_at IS NULL `, hash) if err != nil { return err } if tag.RowsAffected() == 0 { return domain.ErrAuthTokenNotFound } return nil } func (r *RefreshTokenRepo) RevokeAllForUser(ctx context.Context, userID uuid.UUID) error { _, err := r.pool.Exec(ctx, ` UPDATE refresh_tokens SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL `, userID) return err } func parseIP(s string) any { if s == "" { return nil } addr, err := netip.ParseAddr(s) if err != nil { return nil } return addr.String() } func classifyAuthTokenLookup(ctx context.Context, pool *pgxpool.Pool, q, hash string) error { var consumedAt *time.Time var expiresAt time.Time if err := pool.QueryRow(ctx, q, hash).Scan(&consumedAt, &expiresAt); err != nil { if errors.Is(err, pgx.ErrNoRows) { return domain.ErrAuthTokenNotFound } return err } if consumedAt != nil { return domain.ErrAuthTokenConsumed } if time.Now().After(expiresAt) { return domain.ErrAuthTokenExpired } return domain.ErrAuthTokenNotFound }