diff --git a/.gitignore b/.gitignore index 3fa7cc3..c827ef4 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ coverage.* .env .env.local +# Agent / per-developer config (launch.json with absolute paths, worktree state). +.claude/ + .DS_Store .idea/ .vscode/ diff --git a/cmd/api/main.go b/cmd/api/main.go index e2318a2..8d121b1 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -11,10 +11,15 @@ import ( "syscall" "time" + "github.com/redis/go-redis/v9" + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/auth" + "github.com/alchemistkay/guestguard/internal/billing" "github.com/alchemistkay/guestguard/internal/config" "github.com/alchemistkay/guestguard/internal/fraud" "github.com/alchemistkay/guestguard/internal/natspub" + "github.com/alchemistkay/guestguard/internal/notification" "github.com/alchemistkay/guestguard/internal/storage" ) @@ -56,6 +61,16 @@ func run() error { } defer natsClient.Close() + logger.Info("connecting to redis", "addr", cfg.RedisAddr) + rdb := redis.NewClient(&redis.Options{Addr: cfg.RedisAddr}) + if err := rdb.Ping(rootCtx).Err(); err != nil { + logger.Warn("redis ping failed — rate limits + lockout disabled", "err", err) + _ = rdb.Close() + rdb = nil + } else { + defer rdb.Close() + } + logger.Info("dialing fraud engine", "addr", cfg.FraudGRPCAddr) fraudClient, err := fraud.Dial(rootCtx, cfg.FraudGRPCAddr, cfg.FraudGRPCTimeout, logger) if err != nil { @@ -118,17 +133,93 @@ func run() error { } defer rsvpConsumeCtx.Stop() + // Notification senders. If SES creds are configured, route auth + + // guest emails through SES. Otherwise the log stub keeps the dev flow + // (verification link in API logs) intact. + tpls, err := notification.NewTemplates() + if err != nil { + return err + } + suppressions := notification.NewSuppressionRepo(db) + notifRepo := notification.NewRepo(db) + unsubSigner := notification.NewUnsubscribeSigner(cfg.UnsubscribeSecret) + + emailSenderCombined, backend, err := notification.PickEmailSender(rootCtx, notification.EmailSenderConfig{ + Resend: notification.ResendConfig{ + APIKey: cfg.ResendAPIKey, + FromEmail: cfg.ResendFromEmail, + FromName: cfg.ResendFromName, + }, + SMTP: notification.SMTPConfig{ + Host: cfg.SMTPHost, + Port: cfg.SMTPPort, + Username: cfg.SMTPUsername, + Password: cfg.SMTPPassword, + FromEmail: cfg.SMTPFromEmail, + FromName: cfg.SMTPFromName, + TLS: cfg.SMTPTLS, + }, + SES: notification.SESConfig{ + Region: cfg.SESRegion, + FromEmail: cfg.SESFromEmail, + FromName: cfg.SESFromName, + ConfigurationSet: cfg.SESConfigurationSet, + PublicBaseURL: cfg.PublicBaseURL, + }, + }, tpls, logger) + if err != nil { + return err + } + logger.Info("email backend selected", "backend", backend) + var emailSender auth.EmailSender = emailSenderCombined + + stripeClient, err := billing.NewClient(billing.Config{ + SecretKey: cfg.StripeSecretKey, + WebhookSecret: cfg.StripeWebhookSecret, + PriceProMonthly: cfg.StripePricePro, + PriceBusiness: cfg.StripePriceBusiness, + }) + if err != nil { + return err + } + if stripeClient != nil && stripeClient.Enabled() { + logger.Info("billing enabled via stripe") + } else { + logger.Info("billing disabled — free tier limits apply to all users") + } + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + Hub: hub, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + InvitationPublisher: natsClient, + FraudScorer: fraudClient, + TokenTTL: cfg.TokenTTL, + JWTSecret: cfg.JWTSecret, + JWTIssuer: cfg.JWTIssuer, + AccessTokenTTL: cfg.AccessTokenTTL, + RefreshTokenTTL: cfg.RefreshTokenTTL, + EmailVerificationTTL: cfg.EmailVerificationTTL, + PasswordResetTTL: cfg.PasswordResetTTL, + PublicBaseURL: cfg.PublicBaseURL, + RefreshCookieDomain: cfg.RefreshCookieDomain, + RefreshCookieSecure: cfg.RefreshCookieSecure, + Redis: rdb, + EmailSender: emailSender, + NotificationRepo: notifRepo, + SuppressionRepo: suppressions, + UnsubscribeSigner: unsubSigner, + StripeClient: stripeClient, + }) + if err != nil { + return err + } + srv := &http.Server{ - Addr: cfg.HTTPAddr, - Handler: api.NewServer(api.ServerDeps{ - Logger: logger, - DB: db, - Hub: hub, - AccessPublisher: natsClient, - RSVPPublisher: natsClient, - FraudScorer: fraudClient, - TokenTTL: cfg.TokenTTL, - }).Handler(), + Addr: cfg.HTTPAddr, + Handler: apiSrv.Handler(), ReadHeaderTimeout: 5 * time.Second, ReadTimeout: 30 * time.Second, WriteTimeout: 0, // 0 lets WS connections live; per-request handlers still bound by their own ctx diff --git a/cmd/notifier/main.go b/cmd/notifier/main.go index f095913..879acd6 100644 --- a/cmd/notifier/main.go +++ b/cmd/notifier/main.go @@ -48,7 +48,61 @@ func run() error { defer natsClient.Close() repo := notification.NewRepo(db) - sender := notification.LogSender{} + suppressions := notification.NewSuppressionRepo(db) + tpls, err := notification.NewTemplates() + if err != nil { + return err + } + + // Email dispatcher: Resend > SMTP > SES > log stub, same picker as cmd/api. + combinedEmail, backend, err := notification.PickEmailSender(rootCtx, notification.EmailSenderConfig{ + Resend: notification.ResendConfig{ + APIKey: cfg.ResendAPIKey, + FromEmail: cfg.ResendFromEmail, + FromName: cfg.ResendFromName, + }, + SMTP: notification.SMTPConfig{ + Host: cfg.SMTPHost, + Port: cfg.SMTPPort, + Username: cfg.SMTPUsername, + Password: cfg.SMTPPassword, + FromEmail: cfg.SMTPFromEmail, + FromName: cfg.SMTPFromName, + TLS: cfg.SMTPTLS, + }, + SES: notification.SESConfig{ + Region: cfg.SESRegion, + FromEmail: cfg.SESFromEmail, + FromName: cfg.SESFromName, + ConfigurationSet: cfg.SESConfigurationSet, + PublicBaseURL: cfg.PublicBaseURL, + }, + }, tpls, logger) + if err != nil { + return err + } + logger.Info("email backend selected", "backend", backend) + emailSender := notification.NewEmailSender(combinedEmail, suppressions) + + // SMS: Twilio when creds are set, otherwise no-op log sender. + var smsSender notification.Sender + if cfg.TwilioAccountSID != "" && cfg.TwilioAuthToken != "" && cfg.TwilioFromNumber != "" { + t, err := notification.NewTwilioSender(notification.TwilioConfig{ + AccountSID: cfg.TwilioAccountSID, + AuthToken: cfg.TwilioAuthToken, + FromNumber: cfg.TwilioFromNumber, + }) + if err != nil { + return err + } + smsSender = t + logger.Info("twilio sms sender configured", "from", cfg.TwilioFromNumber) + } else { + smsSender = notification.LogSender{} + logger.Info("twilio not configured — SMS will use the log stub") + } + + sender := notification.NewRouter(emailSender, smsSender) rsvpSub, err := natspub.NewRSVPConfirmedSubscriber( rootCtx, natsClient, "notifier-rsvp-confirmed", @@ -82,6 +136,22 @@ func run() error { } defer fraudCC.Stop() + invitationSub, err := natspub.NewInvitationSendSubscriber( + rootCtx, natsClient, "notifier-invitation-send", + func(ctx context.Context, evt natspub.InvitationSend) error { + return handleInvitationSend(ctx, logger, repo, sender, evt) + }, + logger, + ) + if err != nil { + return err + } + invitationCC, err := invitationSub.Start(rootCtx) + if err != nil { + return err + } + defer invitationCC.Stop() + logger.Info("notifier started") <-rootCtx.Done() logger.Info("notifier shutting down") @@ -201,6 +271,76 @@ func handleFraudScored( return nil } +// handleInvitationSend renders the invitation template + sends through +// the configured sender (Resend in prod, Mailpit in dev), then writes a +// notification row so the host can audit the delivery history. +func handleInvitationSend( + ctx context.Context, + logger *slog.Logger, + repo *notification.Repo, + sender notification.Sender, + evt natspub.InvitationSend, +) error { + if evt.GuestEmail == "" { + // Nothing to do — host-managed delivery. + return nil + } + + eventDate := "" + if !evt.EventDate.IsZero() { + eventDate = evt.EventDate.Format("Mon, 02 Jan 2006 · 15:04") + } + + msg := notification.OutboundMessage{ + GuestID: evt.GuestID, + Channel: notification.ChannelEmail, + Type: notification.TypeInvitation, + Subject: "You're invited — " + evt.EventName, + Metadata: map[string]any{ + "to": evt.GuestEmail, + "GuestName": evt.GuestName, + "HostName": evt.HostName, + "EventName": evt.EventName, + "Venue": evt.Venue, + "EventDate": eventDate, + "Link": evt.Link, + }, + } + + providerID, sendErr := sender.Send(ctx, msg) + status := notification.StatusSent + errStr := "" + if sendErr != nil { + status = notification.StatusFailed + errStr = sendErr.Error() + } + + id, err := repo.Record(ctx, notification.RecordParams{ + GuestID: evt.GuestID, + Channel: msg.Channel, + Type: msg.Type, + Status: status, + ProviderMessageID: providerID, + Error: errStr, + }) + if err != nil { + return err + } + + logger.Info("invitation dispatched", + "notification_id", id, + "guest_id", evt.GuestID, + "event_id", evt.EventID, + "to", evt.GuestEmail, + "status", status, + "provider_message_id", providerID, + ) + if sendErr != nil { + return sendErr + } + return nil +} + func levelFor(env string) slog.Level { if env == "development" { return slog.LevelDebug diff --git a/cmd/restore-verify/main.go b/cmd/restore-verify/main.go new file mode 100644 index 0000000..2dddaf4 --- /dev/null +++ b/cmd/restore-verify/main.go @@ -0,0 +1,267 @@ +// restore-verify is a post-restore sanity tool. Point it at a freshly +// restored Postgres instance and it asserts that the schema and data +// are coherent — no orphan rows, expected uniqueness invariants hold, +// every table is present. Exits non-zero on any failure so it slots +// into a "restore drill" runbook step. +// +// Usage: +// GG_DATABASE_URL=postgres://... ./restore-verify [--verbose] +// +// The intent is "would I bet my Sunday on this restore being usable?". +// Failing fast here keeps a bad restore from being promoted to traffic. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "os" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgxpool" +) + +func main() { + verbose := flag.Bool("verbose", false, "print every check's result, not just failures") + flag.Parse() + + dsn := os.Getenv("GG_DATABASE_URL") + if dsn == "" { + fmt.Fprintln(os.Stderr, "restore-verify: GG_DATABASE_URL is required") + os.Exit(2) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + fmt.Fprintf(os.Stderr, "restore-verify: connect: %v\n", err) + os.Exit(2) + } + defer pool.Close() + + if err := pool.Ping(ctx); err != nil { + fmt.Fprintf(os.Stderr, "restore-verify: ping: %v\n", err) + os.Exit(2) + } + + fmt.Println("restore-verify: checking", maskDSN(dsn)) + fmt.Println() + + checks := allChecks() + var failed []string + for _, c := range checks { + result, err := c.fn(ctx, pool) + if err != nil { + failed = append(failed, c.name) + fmt.Printf(" ✗ %-50s FAIL: %v\n", c.name, err) + continue + } + if *verbose { + fmt.Printf(" ✓ %-50s %s\n", c.name, result) + } + } + + fmt.Println() + if len(failed) > 0 { + fmt.Printf("FAILED: %d check%s — %s\n", + len(failed), pluralS(len(failed)), strings.Join(failed, ", ")) + os.Exit(1) + } + fmt.Printf("OK: all %d checks passed\n", len(checks)) +} + +type check struct { + name string + fn func(ctx context.Context, pool *pgxpool.Pool) (string, error) +} + +func allChecks() []check { + return []check{ + // --- schema presence --- + tableExists("users"), + tableExists("events"), + tableExists("guests"), + tableExists("tokens"), + tableExists("rsvps"), + tableExists("access_logs"), + tableExists("notifications"), + tableExists("schema_migrations"), + tableExists("email_verification_tokens"), + tableExists("password_reset_tokens"), + tableExists("refresh_tokens"), + tableExists("unsubscribes"), + tableExists("subscriptions"), + + // --- migrations applied --- + { + name: "schema_migrations: 5+ migrations applied", + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + var n int + if err := pool.QueryRow(ctx, + `SELECT count(*) FROM schema_migrations`).Scan(&n); err != nil { + return "", err + } + if n < 5 { + return "", fmt.Errorf("only %d migrations recorded — incomplete restore", n) + } + return fmt.Sprintf("%d migrations", n), nil + }, + }, + + // --- referential integrity (FK constraints catch most of this, + // but a bad logical dump or partial restore can sneak rows in) --- + noOrphans("events", "host_id", "users", "id"), + noOrphans("guests", "event_id", "events", "id"), + noOrphans("tokens", "guest_id", "guests", "id"), + noOrphans("rsvps", "guest_id", "guests", "id"), + noOrphans("access_logs", "guest_id", "guests", "id"), + noOrphans("notifications", "guest_id", "guests", "id"), + noOrphans("email_verification_tokens", "user_id", "users", "id"), + noOrphans("password_reset_tokens", "user_id", "users", "id"), + noOrphans("refresh_tokens", "user_id", "users", "id"), + noOrphans("subscriptions", "user_id", "users", "id"), + + // --- domain invariants --- + { + name: "users.email is unique (case-insensitive)", + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + var dupes int + err := pool.QueryRow(ctx, ` + SELECT count(*) FROM ( + SELECT lower(email) FROM users GROUP BY lower(email) HAVING count(*) > 1 + ) t + `).Scan(&dupes) + if err != nil { + return "", err + } + if dupes > 0 { + return "", fmt.Errorf("%d duplicate email(s) found", dupes) + } + return "no duplicate emails", nil + }, + }, + { + name: "guests with rsvp_response have an existing rsvp row", + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + var n int + err := pool.QueryRow(ctx, ` + SELECT count(*) FROM rsvps r + WHERE NOT EXISTS (SELECT 1 FROM guests g WHERE g.id = r.guest_id) + `).Scan(&n) + if err != nil { + return "", err + } + if n > 0 { + return "", fmt.Errorf("%d rsvp(s) reference a missing guest", n) + } + return "0 orphan rsvps", nil + }, + }, + { + name: "at most one granting subscription per user", + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + var n int + err := pool.QueryRow(ctx, ` + SELECT count(*) FROM ( + SELECT user_id FROM subscriptions + WHERE status IN ('active','past_due','trialing') + GROUP BY user_id HAVING count(*) > 1 + ) t + `).Scan(&n) + if err != nil { + return "", err + } + if n > 0 { + return "", fmt.Errorf("%d user(s) have multiple granting subscriptions", n) + } + return "all users single-active", nil + }, + }, + + // --- soft constraints worth noticing (not failures, but logged) --- + { + name: "row counts snapshot", + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + parts := []string{} + for _, t := range []string{ + "users", "events", "guests", "tokens", "rsvps", + "access_logs", "notifications", "subscriptions", + } { + var n int + if err := pool.QueryRow(ctx, + fmt.Sprintf("SELECT count(*) FROM %s", t)).Scan(&n); err != nil { + return "", err + } + parts = append(parts, fmt.Sprintf("%s=%d", t, n)) + } + return strings.Join(parts, " "), nil + }, + }, + } +} + +func tableExists(name string) check { + return check{ + name: fmt.Sprintf("table %q exists", name), + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + var exists bool + err := pool.QueryRow(ctx, + `SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema='public' AND table_name=$1)`, + name, + ).Scan(&exists) + if err != nil { + return "", err + } + if !exists { + return "", errors.New("missing") + } + return "ok", nil + }, + } +} + +func noOrphans(childTable, childFK, parentTable, parentPK string) check { + return check{ + name: fmt.Sprintf("no orphans: %s.%s -> %s.%s", childTable, childFK, parentTable, parentPK), + fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { + q := fmt.Sprintf(` + SELECT count(*) FROM %s c + WHERE c.%s IS NOT NULL + AND NOT EXISTS (SELECT 1 FROM %s p WHERE p.%s = c.%s) + `, childTable, childFK, parentTable, parentPK, childFK) + var n int + if err := pool.QueryRow(ctx, q).Scan(&n); err != nil { + return "", err + } + if n > 0 { + return "", fmt.Errorf("%d orphan row(s)", n) + } + return "clean", nil + }, + } +} + +func maskDSN(dsn string) string { + // Crude: redact password between '://user:' and '@'. + at := strings.LastIndex(dsn, "@") + scheme := strings.Index(dsn, "://") + if at < 0 || scheme < 0 || at <= scheme { + return dsn + } + userInfo := dsn[scheme+3 : at] + if colon := strings.Index(userInfo, ":"); colon >= 0 { + userInfo = userInfo[:colon] + ":****" + } + return dsn[:scheme+3] + userInfo + dsn[at:] +} + +func pluralS(n int) string { + if n == 1 { + return "" + } + return "s" +} diff --git a/docker-compose.yml b/docker-compose.yml index ad086cb..8a0d440 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,28 @@ services: timeout: 3s retries: 10 + mailpit: + image: axllent/mailpit:latest + ports: + - "1025:1025" # SMTP + - "8025:8025" # Web UI + healthcheck: + test: ["CMD", "wget", "-qO-", "http://localhost:8025/api/v1/info"] + interval: 5s + timeout: 3s + retries: 10 + + redis: + image: redis:7-alpine + command: ["redis-server", "--save", "", "--appendonly", "no"] + ports: + - "6379:6379" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 3s + retries: 10 + nats: image: nats:2.10-alpine command: @@ -45,7 +67,25 @@ services: GG_NATS_URL: nats://nats:4222 GG_FRAUD_GRPC_ADDR: fraud-engine:9091 GG_FRAUD_GRPC_TIMEOUT: 250ms + GG_REDIS_ADDR: redis:6379 + GG_SMTP_HOST: mailpit + GG_SMTP_PORT: "1025" + GG_SMTP_FROM_EMAIL: noreply@guestguard.local + GG_SMTP_FROM_NAME: GuestGuard (dev) + # Resend overrides SMTP when GG_RESEND_API_KEY is set in .env at the + # project root. Leave unset and Mailpit (above) handles delivery. + GG_RESEND_API_KEY: ${GG_RESEND_API_KEY:-} + GG_RESEND_FROM_EMAIL: ${GG_RESEND_FROM_EMAIL:-} + GG_RESEND_FROM_NAME: ${GG_RESEND_FROM_NAME:-GuestGuard} GG_TOKEN_SECRET: dev-only-insecure-secret-change-me + GG_JWT_SECRET: dev-only-insecure-jwt-secret-change-me-32+bytes + GG_PUBLIC_BASE_URL: http://localhost:3000 + # Stripe billing — empty values leave billing disabled. Set in + # .env at the project root to enable. + GG_STRIPE_SECRET_KEY: ${GG_STRIPE_SECRET_KEY:-} + GG_STRIPE_WEBHOOK_SECRET: ${GG_STRIPE_WEBHOOK_SECRET:-} + GG_STRIPE_PRICE_PRO: ${GG_STRIPE_PRICE_PRO:-} + GG_STRIPE_PRICE_BUSINESS: ${GG_STRIPE_PRICE_BUSINESS:-} ports: - "8080:8080" depends_on: @@ -53,6 +93,8 @@ services: condition: service_healthy nats: condition: service_healthy + redis: + condition: service_healthy restart: unless-stopped fraud-engine: @@ -80,6 +122,14 @@ services: GG_ENV: development GG_DATABASE_URL: postgres://guestguard:guestguard@postgres:5432/guestguard?sslmode=disable GG_NATS_URL: nats://nats:4222 + GG_PUBLIC_BASE_URL: http://localhost:3000 + GG_SMTP_HOST: mailpit + GG_SMTP_PORT: "1025" + GG_SMTP_FROM_EMAIL: noreply@guestguard.local + GG_SMTP_FROM_NAME: GuestGuard (dev) + GG_RESEND_API_KEY: ${GG_RESEND_API_KEY:-} + GG_RESEND_FROM_EMAIL: ${GG_RESEND_FROM_EMAIL:-} + GG_RESEND_FROM_NAME: ${GG_RESEND_FROM_NAME:-GuestGuard} depends_on: postgres: condition: service_healthy diff --git a/docs/RUNBOOK_RESTORE.md b/docs/RUNBOOK_RESTORE.md new file mode 100644 index 0000000..cc98b46 --- /dev/null +++ b/docs/RUNBOOK_RESTORE.md @@ -0,0 +1,251 @@ +# Runbook — Postgres restore + +This is the procedure to bring GuestGuard back from a Postgres backup +after data loss. It assumes the infra side of Block G (`pg_basebackup` + +WAL archiving to S3, daily logical dumps, cross-region replication) is +already in place — see the homelab repo for those. + +The application side — migration down-scripts, the [`restore-verify`](../cmd/restore-verify/main.go) +tool, and this document — lives here in the GuestGuard repo so it ships +in lockstep with the schema. + +--- + +## Targets + +| Metric | Target | +|---|---| +| RTO (recovery time objective) | ≤ 1 hour from "go" decision to traffic-serving | +| RPO (recovery point objective) | ≤ 5 minutes of data loss (WAL ships every 60s, S3 PUT every 5min) | + +If RTO is going to slip past 1 hour, escalate per the comms plan in `docs/INCIDENT_RESPONSE.md` (infra repo). + +## When to invoke this + +- Primary Postgres is unreachable AND the standby has also failed +- Logical corruption discovered (e.g., a bad migration deleted rows) +- Region-wide outage at the primary's location +- A "what if we restored last Tuesday" drill (see [Drill](#drill-procedure)) + +If only the primary is unreachable and the standby is healthy, promote +the standby (separate runbook). Don't use this procedure unnecessarily — +restores are expensive. + +## Prerequisites + +Before starting: + +- [ ] Decision authority has approved the restore (CTO or on-call lead) +- [ ] Read access to the S3 backup bucket: `s3://guestguard-pg-backups` +- [ ] `psql`, `pg_basebackup`, `wal-g` (or chosen WAL tool) installed +- [ ] Empty target Postgres instance provisioned (Kubernetes Statefulset, + RDS, or homelab box — same major version as the backup) +- [ ] `GG_DATABASE_URL` env var ready for the new instance +- [ ] Maintenance page deployed to the frontend (`/dashboard` returns 503) +- [ ] API + notifier pods scaled to 0 (`kubectl scale --replicas=0`) +- [ ] This document open in another tab + +## Steps + +### 1. Stop write traffic + +```bash +# k8s +kubectl scale deployment/guestguard-api --replicas=0 +kubectl scale deployment/guestguard-notifier --replicas=0 + +# Confirm no connections to the (broken) primary +kubectl exec -n postgres guestguard-pg-0 -- psql -U postgres -c \ + "SELECT count(*) FROM pg_stat_activity WHERE datname='guestguard'" +``` + +If using docker-compose locally: `docker compose stop api notifier`. + +### 2. Identify the recovery point + +Pick the latest backup that's known-good. For corruption scenarios, +this may mean going further back than the most recent dump. + +```bash +# List base backups (most recent first) +wal-g backup-list 2>/dev/null | tail -10 + +# Pick the timestamp (e.g. base_000000010000000000000007) and decide +# the LSN target if doing point-in-time recovery +``` + +For corruption: pick the latest backup created **before** the corrupting +event. For "ransomware / bad migration", probably 1–2 days back. + +### 3. Restore the base backup + +```bash +# Replace BACKUP_NAME with the chosen base +wal-g backup-fetch /var/lib/postgresql/data BACKUP_NAME + +# Configure recovery target (omit recovery_target_time for "latest") +cat >> /var/lib/postgresql/data/postgresql.conf < -const { host, clear } = useHost() +const auth = useAuth() const route = useRoute() // GitHub icon is a "marketing" affordance — only show it on the public landing // page. Inside the app it just clutters the chrome. const showGithub = computed(() => route.path === '/') -function logout() { - clear() +async function signOut() { + await auth.logout() navigateTo('/') } @@ -34,12 +34,17 @@ function logout() { - - + + + + @@ -48,9 +53,22 @@ function logout() { + + + + + +
-
- © 2025 GuestGuard — Hassle-free RSVPs for every occasion. +
+ © 2025 GuestGuard — Hassle-free RSVPs for every occasion. + + Privacy + Terms +
diff --git a/frontend/components/CsvImportCard.vue b/frontend/components/CsvImportCard.vue new file mode 100644 index 0000000..9955407 --- /dev/null +++ b/frontend/components/CsvImportCard.vue @@ -0,0 +1,217 @@ + + + diff --git a/frontend/components/PhoneInput.vue b/frontend/components/PhoneInput.vue new file mode 100644 index 0000000..4972584 --- /dev/null +++ b/frontend/components/PhoneInput.vue @@ -0,0 +1,240 @@ + + + diff --git a/frontend/components/TermsGateModal.vue b/frontend/components/TermsGateModal.vue new file mode 100644 index 0000000..5c46dd8 --- /dev/null +++ b/frontend/components/TermsGateModal.vue @@ -0,0 +1,80 @@ + + + diff --git a/frontend/components/UpgradeModal.vue b/frontend/components/UpgradeModal.vue new file mode 100644 index 0000000..5a0af5b --- /dev/null +++ b/frontend/components/UpgradeModal.vue @@ -0,0 +1,131 @@ + + + diff --git a/frontend/composables/useApi.ts b/frontend/composables/useApi.ts index 1be9f12..9cd800d 100644 --- a/frontend/composables/useApi.ts +++ b/frontend/composables/useApi.ts @@ -1,16 +1,57 @@ // Typed wrapper around $fetch with the configured API base. -// Usage: const events = await useApi('/events') +// +// Adds `Authorization: Bearer ` when the caller is signed in, +// and on a 401 transparently asks `/auth/refresh` for a new token and retries +// once. Failed refresh clears local auth state — pages can rely on the +// returned error to redirect to /login. + export async function useApi( path: string, opts: { method?: string; body?: unknown; query?: Record } = {}, ): Promise { const config = useRuntimeConfig() const base = config.public.apiBase as string - return await $fetch(path, { - baseURL: base, - method: (opts.method ?? 'GET') as any, - body: opts.body, - query: opts.query, - headers: { 'Content-Type': 'application/json' }, - }) + const auth = useAuth() + + const request = async (token: string | null): Promise => { + const headers: Record = {} + // Let the browser set Content-Type (with the multipart boundary) when + // the body is FormData / Blob; otherwise default to JSON. + const isMultipart = typeof FormData !== 'undefined' && opts.body instanceof FormData + if (!isMultipart) headers['Content-Type'] = 'application/json' + if (token) headers.Authorization = `Bearer ${token}` + return await $fetch(path, { + baseURL: base, + method: (opts.method ?? 'GET') as any, + body: opts.body as any, + query: opts.query, + headers, + credentials: 'include', + }) + } + + try { + return await request(auth.liveAccessToken()) + } catch (err: any) { + const status = err?.response?.status ?? err?.statusCode + + // 402 Payment Required — plan limit hit. Surface the backend's + // upgrade payload on a global state slot; the UpgradeModal in + // app.vue reads it and prompts the host to upgrade. We still + // rethrow so the caller can stop its own UI flow if it wants. + if (status === 402) { + const data = err?.data + if (data && data.upgrade_url) { + useBilling().showUpgradePrompt(data) + } + throw err + } + + if (status !== 401) throw err + // /auth/* endpoints set the cookie themselves — never retry-refresh them. + if (path.startsWith('/auth/')) throw err + const refreshed = await auth.refresh() + if (!refreshed) throw err + return await request(auth.liveAccessToken()) + } } diff --git a/frontend/composables/useAuth.ts b/frontend/composables/useAuth.ts new file mode 100644 index 0000000..77275c8 --- /dev/null +++ b/frontend/composables/useAuth.ts @@ -0,0 +1,147 @@ +// Auth state for the host-facing app. +// +// The access token lives only in memory (useState — Nuxt's SSR-safe wrapper). +// The refresh token lives in an HttpOnly cookie set by the API at +// `/auth/refresh` scope, so JavaScript here can never read it. On a hard +// reload we lose the access token but the cookie survives, so `bootstrap()` +// calls `/auth/refresh` to mint a fresh access token + reload the user. + +interface AuthUser { + id: string + email: string + name: string + email_verified: boolean +} + +interface AuthSuccess { + access_token: string + expires_at: string + user: AuthUser +} + +interface AuthState { + user: AuthUser | null + accessToken: string | null + expiresAt: number | null // unix ms + bootstrapped: boolean +} + +function emptyState(): AuthState { + return { user: null, accessToken: null, expiresAt: null, bootstrapped: false } +} + +function apiBase(): string { + return useRuntimeConfig().public.apiBase as string +} + +async function postJSON(path: string, body?: unknown): Promise { + return await $fetch(path, { + baseURL: apiBase(), + method: 'POST', + body, + credentials: 'include', + headers: { 'Content-Type': 'application/json' }, + }) +} + +export function useAuth() { + const state = useState('gg-auth', emptyState) + + function setSession(s: AuthSuccess) { + state.value = { + user: s.user, + accessToken: s.access_token, + expiresAt: Date.parse(s.expires_at) || (Date.now() + 14 * 60 * 1000), + bootstrapped: true, + } + } + + function clearSession() { + state.value = { ...emptyState(), bootstrapped: true } + } + + async function signup(email: string, name: string, password: string, acceptTerms = false) { + return await postJSON<{ status: string }>('/auth/signup', { + email, name, password, + accept_terms: acceptTerms, + }) + } + + async function login(email: string, password: string) { + const s = await postJSON('/auth/login', { email, password }) + setSession(s) + return s + } + + async function refresh(): Promise { + try { + const s = await postJSON('/auth/refresh') + setSession(s) + return true + } catch { + clearSession() + return false + } + } + + async function logout() { + try { + await postJSON('/auth/logout') + } catch { + // Best-effort — clear local state regardless. + } + clearSession() + } + + async function verifyEmail(token: string) { + return await postJSON<{ status: string }>('/auth/verify-email', { token }) + } + + async function forgotPassword(email: string) { + return await postJSON<{ status: string }>('/auth/forgot-password', { email }) + } + + async function resetPassword(token: string, newPassword: string) { + return await postJSON<{ status: string }>('/auth/reset-password', { + token, + new_password: newPassword, + }) + } + + // Call on app entry / route guards. Returns true if the caller has a valid + // session by the time it resolves. + async function bootstrap(): Promise { + if (!import.meta.client) return false + if (state.value.bootstrapped && state.value.user) return true + if (state.value.bootstrapped && !state.value.user) return false + return await refresh() + } + + // Hint to useApi: returns the current token if not yet expired. + function liveAccessToken(): string | null { + if (!state.value.accessToken || !state.value.expiresAt) return null + // 5s skew to avoid sending a just-expired token. + if (Date.now() + 5000 >= state.value.expiresAt) return null + return state.value.accessToken + } + + const isAuthenticated = computed(() => !!state.value.user) + const user = computed(() => state.value.user) + const bootstrapped = computed(() => state.value.bootstrapped) + + return { + user, + isAuthenticated, + bootstrapped, + signup, + login, + refresh, + logout, + verifyEmail, + forgotPassword, + resetPassword, + bootstrap, + liveAccessToken, + clearSession, + } +} diff --git a/frontend/composables/useBilling.ts b/frontend/composables/useBilling.ts new file mode 100644 index 0000000..5687c66 --- /dev/null +++ b/frontend/composables/useBilling.ts @@ -0,0 +1,151 @@ +// Billing client: fetches subscription status, kicks off checkout/portal +// flows, and owns the global "upgrade required" prompt shown by the +// 402 interceptor in useApi. State is shared across components via +// useState so the prompt can be triggered from any handler. + +export interface BillingStatus { + tier: 'free' | 'pro' | 'business' + status: string + current_period_end?: string + cancel_at_period_end: boolean + limits: { + events_per_month: number + guests_per_event: number + } + usage: { + events_this_month: number + } + portal_available: boolean +} + +// UpgradePrompt mirrors the 402 body the backend returns when a limit is +// hit. Shown globally as a modal until dismissed or acted upon. +export interface UpgradePrompt { + error: string + reason: string + tier: string + used: number + limit: number + upgrade_url: string +} + +const FREE_DEFAULT: BillingStatus = { + tier: 'free', + status: 'active', + cancel_at_period_end: false, + limits: { events_per_month: 1, guests_per_event: 50 }, + usage: { events_this_month: 0 }, + portal_available: false, +} + +export function useBilling() { + const status = useState('gg-billing-status', () => null) + const loading = useState('gg-billing-loading', () => false) + const prompt = useState('gg-upgrade-prompt', () => null) + + async function fetchStatus(): Promise { + loading.value = true + try { + const res = await useApi('/billing/status') + status.value = res + return res + } catch (e: any) { + // 401/refresh edge → caller redirected to /login by useApi. If the + // backend has billing wired but the response is malformed, fall + // back to free defaults so the page renders something usable. + status.value = FREE_DEFAULT + return FREE_DEFAULT + } finally { + loading.value = false + } + } + + async function startCheckout(tier: 'pro' | 'business'): Promise { + const res = await useApi<{ url: string }>('/billing/checkout-session', { + method: 'POST', + body: { tier }, + }) + if (import.meta.client) window.location.href = res.url + } + + async function openPortal(): Promise { + const res = await useApi<{ url: string }>('/billing/portal', { method: 'POST' }) + if (import.meta.client) window.location.href = res.url + } + + function showUpgradePrompt(info: UpgradePrompt) { + prompt.value = info + } + function dismissUpgradePrompt() { + prompt.value = null + } + + return { + status, + loading, + prompt, + fetchStatus, + startCheckout, + openPortal, + showUpgradePrompt, + dismissUpgradePrompt, + } +} + +// Static pricing copy. Keep in sync with internal/billing/tiers.go. +// One source of truth for the marketing-page-style cards in +// /dashboard/billing and the UpgradeModal. +export interface TierCard { + id: 'free' | 'pro' | 'business' + name: string + price: string + priceSubtitle: string + tagline: string + features: string[] + highlight?: boolean +} + +export const TIER_CARDS: TierCard[] = [ + { + id: 'free', + name: 'Free', + price: '$0', + priceSubtitle: 'forever', + tagline: 'Try GuestGuard with a single event.', + features: [ + '1 event per month', + 'Up to 50 guests per event', + 'Branded email invitations', + 'Real-time RSVP dashboard', + ], + }, + { + id: 'pro', + name: 'Pro', + price: '$49', + priceSubtitle: 'per month', + tagline: 'For active hosts running several events.', + features: [ + '10 events per month', + 'Up to 1,000 guests per event', + 'WhatsApp + email invitations', + 'CSV import + bulk send', + 'Priority email support', + ], + highlight: true, + }, + { + id: 'business', + name: 'Business', + price: '$199', + priceSubtitle: 'per month', + tagline: 'For agencies and corporate events teams.', + features: [ + 'Unlimited events', + 'Up to 5,000 guests per event', + 'Everything in Pro', + 'Signed DPA on request', + 'SLA with response targets', + ], + }, +] diff --git a/frontend/composables/useErrMessage.ts b/frontend/composables/useErrMessage.ts new file mode 100644 index 0000000..347634e --- /dev/null +++ b/frontend/composables/useErrMessage.ts @@ -0,0 +1,28 @@ +// Friendly error messages for API failures, with first-class handling of +// 429 (rate-limited) and 403 account-lockout responses. +export function useErrMessage(e: any, fallback = 'Something went wrong'): string { + const status: number | undefined = e?.response?.status ?? e?.statusCode + const data = e?.data + const serverMsg: string | undefined = data?.error + + if (status === 429) { + const retry: number | undefined = data?.retry_after + if (typeof retry === 'number' && retry > 0) { + return `You're going too fast — try again in ${formatSeconds(retry)}.` + } + return "You're going too fast — please try again in a moment." + } + + if (status === 403 && typeof serverMsg === 'string' && serverMsg.toLowerCase().includes('locked')) { + return 'Your account is locked after too many failed sign-in attempts. Reset your password to unlock it.' + } + + if (serverMsg) return serverMsg + return e?.message || fallback +} + +function formatSeconds(s: number): string { + if (s < 60) return `${s} second${s === 1 ? '' : 's'}` + const m = Math.ceil(s / 60) + return `${m} minute${m === 1 ? '' : 's'}` +} diff --git a/frontend/composables/useEventWS.ts b/frontend/composables/useEventWS.ts index 89461b0..731d044 100644 --- a/frontend/composables/useEventWS.ts +++ b/frontend/composables/useEventWS.ts @@ -1,7 +1,12 @@ // Subscribes to /ws/events/:id and emits per-message callbacks. // -// Auto-reconnects with exponential backoff up to 30s. Returns a cleanup -// fn the caller invokes (e.g. inside onUnmounted). +// Authenticates via short-lived ticket: before each connect we POST +// /auth/ws-ticket (bearer-authed) to mint a one-shot ticket, then pass it +// on the WS handshake as `?ticket=…`. Tickets expire ~60s after mint, so we +// always mint fresh — even on reconnects. +// +// Auto-reconnects with exponential backoff up to 30s. Returns a cleanup fn +// the caller invokes (e.g. inside onUnmounted). interface WSMessage { type: string @@ -10,21 +15,46 @@ interface WSMessage { timestamp: string } +interface WSTicket { + ticket: string + expires_at: string +} + export function useEventWS(eventId: string, onMessage: (msg: WSMessage) => void) { if (import.meta.server) return () => {} const config = useRuntimeConfig() - const base = (config.public.wsBase as string) || '' - const url = `${base}/ws/events/${eventId}` + const wsBase = (config.public.wsBase as string) || '' let ws: WebSocket | null = null let attempt = 0 let stopped = false let reconnectTimer: ReturnType | null = null - function connect() { + async function mintTicket(): Promise { + try { + const t = await useApi('/auth/ws-ticket', { + method: 'POST', + body: { event_id: eventId }, + }) + return t.ticket + } catch { + return null + } + } + + async function connect() { if (stopped) return - ws = new WebSocket(url) + const ticket = await mintTicket() + if (stopped) return + if (!ticket) { + // Couldn't get a ticket (likely 401 — session expired). Back off and + // retry; useApi will have already attempted refresh on its own. + const backoff = Math.min(30_000, 500 * Math.pow(2, attempt++)) + reconnectTimer = setTimeout(connect, backoff) + return + } + ws = new WebSocket(`${wsBase}/ws/events/${eventId}?ticket=${encodeURIComponent(ticket)}`) ws.onopen = () => { attempt = 0 diff --git a/frontend/composables/useHost.ts b/frontend/composables/useHost.ts deleted file mode 100644 index b725aea..0000000 --- a/frontend/composables/useHost.ts +++ /dev/null @@ -1,43 +0,0 @@ -// Demo-grade host bootstrap. Real auth would replace this entirely; for now -// we upsert by email and stash the host id in localStorage. -interface User { - id: string - email: string - name: string -} - -const STORAGE_KEY = 'gg.host' - -export function useHost() { - const host = useState('gg-host', () => null) - - if (import.meta.client && !host.value) { - const raw = window.localStorage.getItem(STORAGE_KEY) - if (raw) { - try { - host.value = JSON.parse(raw) - } catch { - window.localStorage.removeItem(STORAGE_KEY) - } - } - } - - async function bootstrap(email: string, name: string) { - const u = await useApi('/users', { - method: 'POST', - body: { email, name }, - }) - host.value = u - if (import.meta.client) { - window.localStorage.setItem(STORAGE_KEY, JSON.stringify(u)) - } - return u - } - - function clear() { - host.value = null - if (import.meta.client) window.localStorage.removeItem(STORAGE_KEY) - } - - return { host, bootstrap, clear } -} diff --git a/frontend/middleware/auth.ts b/frontend/middleware/auth.ts new file mode 100644 index 0000000..5ffe809 --- /dev/null +++ b/frontend/middleware/auth.ts @@ -0,0 +1,14 @@ +// Route guard: bootstraps the session (server -> client cookie roundtrip, +// then optional /auth/refresh), and redirects to /login if no session. +// +// Skip on SSR: auth state lives entirely in the browser. Dashboard pages +// render their own client-side loading state while we resolve. + +export default defineNuxtRouteMiddleware(async (to) => { + if (import.meta.server) return + const auth = useAuth() + const ok = await auth.bootstrap() + if (!ok) { + return navigateTo({ path: '/login', query: { redirect: to.fullPath } }) + } +}) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 843d521..33f2092 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -5092,12 +5092,14 @@ } }, "node_modules/commander": { - "version": "10.0.1", - "resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz", - "integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==", + "version": "13.1.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-13.1.0.tgz", + "integrity": "sha512-/rFeCpNJQbhSZjGVwO9RFV3xPqbnERS8MmIQzCtD/zl6gpJuV/bMLuN92oG3F7d8oDEHHRrujSXNUr8fpjntKw==", "license": "MIT", + "optional": true, + "peer": true, "engines": { - "node": ">=14" + "node": ">=18" } }, "node_modules/common-path-prefix": { @@ -8500,6 +8502,15 @@ "node": ">=8" } }, + "node_modules/lambda-local/node_modules/commander": { + "version": "10.0.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz", + "integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==", + "license": "MIT", + "engines": { + "node": ">=14" + } + }, "node_modules/lambda-local/node_modules/dotenv": { "version": "16.6.1", "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.6.1.tgz", diff --git a/frontend/pages/dashboard/billing.vue b/frontend/pages/dashboard/billing.vue new file mode 100644 index 0000000..4d96355 --- /dev/null +++ b/frontend/pages/dashboard/billing.vue @@ -0,0 +1,393 @@ + + + diff --git a/frontend/pages/dashboard/events/[id].vue b/frontend/pages/dashboard/events/[id].vue index d5ea989..0db499a 100644 --- a/frontend/pages/dashboard/events/[id].vue +++ b/frontend/pages/dashboard/events/[id].vue @@ -1,4 +1,6 @@ + + diff --git a/frontend/pages/login.vue b/frontend/pages/login.vue new file mode 100644 index 0000000..7878fd5 --- /dev/null +++ b/frontend/pages/login.vue @@ -0,0 +1,53 @@ + + + diff --git a/frontend/pages/privacy.vue b/frontend/pages/privacy.vue new file mode 100644 index 0000000..8871a41 --- /dev/null +++ b/frontend/pages/privacy.vue @@ -0,0 +1,51 @@ + + + diff --git a/frontend/pages/reset-password/[token].vue b/frontend/pages/reset-password/[token].vue new file mode 100644 index 0000000..b5a419a --- /dev/null +++ b/frontend/pages/reset-password/[token].vue @@ -0,0 +1,73 @@ + + + diff --git a/frontend/pages/signup.vue b/frontend/pages/signup.vue new file mode 100644 index 0000000..5ea4fa7 --- /dev/null +++ b/frontend/pages/signup.vue @@ -0,0 +1,90 @@ + + + diff --git a/frontend/pages/terms.vue b/frontend/pages/terms.vue new file mode 100644 index 0000000..6383d82 --- /dev/null +++ b/frontend/pages/terms.vue @@ -0,0 +1,50 @@ + + + diff --git a/frontend/pages/unsubscribe/[token].vue b/frontend/pages/unsubscribe/[token].vue new file mode 100644 index 0000000..7729067 --- /dev/null +++ b/frontend/pages/unsubscribe/[token].vue @@ -0,0 +1,69 @@ + + + diff --git a/frontend/pages/verify-email.vue b/frontend/pages/verify-email.vue new file mode 100644 index 0000000..5061156 --- /dev/null +++ b/frontend/pages/verify-email.vue @@ -0,0 +1,44 @@ + + + diff --git a/go.mod b/go.mod index 8e91857..ad1ae84 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,13 @@ go 1.26.2 require ( github.com/coder/websocket v1.8.14 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.9.2 github.com/nats-io/nats.go v1.52.0 github.com/testcontainers/testcontainers-go v0.42.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 + golang.org/x/crypto v0.49.0 google.golang.org/grpc v1.81.0 google.golang.org/protobuf v1.36.11 ) @@ -17,6 +19,22 @@ require ( dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/alicebob/miniredis/v2 v2.38.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect + github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect + github.com/aws/smithy-go v1.25.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/errdefs v1.0.0 // indirect @@ -33,6 +51,7 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/golang/mock v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect @@ -52,24 +71,29 @@ require ( github.com/nats-io/nuid v1.0.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect + github.com/redis/go-redis/v9 v9.19.0 // indirect github.com/shirou/gopsutil/v4 v4.26.3 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/stretchr/testify v1.11.1 // indirect + github.com/stripe/stripe-go/v82 v82.5.1 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/numcpus v0.11.0 // indirect + github.com/twilio/twilio-go v1.30.9 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect go.opentelemetry.io/otel v1.43.0 // indirect go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect - golang.org/x/crypto v0.49.0 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/net v0.51.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect - golang.org/x/text v0.35.0 // indirect + golang.org/x/text v0.37.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index de71ff3..7793495 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,38 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw= +github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8= +github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc= +github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU= +github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU= +github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA= +github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4 h1:X/PtmuX/EwPivJ9lHCf3Auo8AktdNc4a9ury4zmGPC4= +github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4/go.mod h1:l5cTwZSX9kzxDHz9IpgZC0XIJ/cc43JL6hZzCd0iTwI= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio= +github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI= +github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -44,6 +76,10 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -101,10 +137,14 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/shirou/gopsutil/v4 v4.26.3 h1:2ESdQt90yU3oXF/CdOlRCJxrP+Am1aBYubTMTfxJ1qc= @@ -118,6 +158,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/stripe/stripe-go/v82 v82.5.1 h1:05q6ZDKoe8PLMpQV072obF74HCgP4XJeJYoNuRSX2+8= +github.com/stripe/stripe-go/v82 v82.5.1/go.mod h1:majCQX6AfObAvJiHraPi/5udwHi4ojRvJnnxckvHrX8= github.com/testcontainers/testcontainers-go v0.42.0 h1:He3IhTzTZOygSXLJPMX7n44XtK+qhjat1nI9cneBbUY= github.com/testcontainers/testcontainers-go v0.42.0/go.mod h1:vZjdY1YmUA1qEForxOIOazfsrdyORJAbhi0bp8plN30= github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 h1:GCbb1ndrF7OTDiIvxXyItaDab4qkzTFJ48LKFdM7EIo= @@ -126,6 +168,11 @@ github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= +github.com/twilio/twilio-go v1.30.9 h1:4W4GEV2q0sLQ9xsr1N/97JQlt0c82hZ0ij4qTErstv8= +github.com/twilio/twilio-go v1.30.9/go.mod h1:QbitvbvtkV77Jn4BABAKVmxabYSjMyQG4tHey9gfPqg= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= @@ -142,22 +189,48 @@ go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfC go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= diff --git a/internal/api/activity.go b/internal/api/activity.go index c0dc732..13600e8 100644 --- a/internal/api/activity.go +++ b/internal/api/activity.go @@ -1,12 +1,10 @@ package api import ( - "errors" "net/http" "sort" "time" - "github.com/alchemistkay/guestguard/internal/domain" "github.com/alchemistkay/guestguard/internal/storage" ) @@ -42,16 +40,15 @@ type activityItem struct { // for an event, sorted newest first. Frontends use this on dashboard mount // to backfill the live monitor with history. func (h *activityHandler) list(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } eventID, ok := parseIDParam(w, r, "id") if !ok { return } - if _, err := h.events.Get(r.Context(), eventID); err != nil { - if errors.Is(err, domain.ErrEventNotFound) { - writeError(w, http.StatusNotFound, "event not found") - return - } - writeError(w, http.StatusInternalServerError, "failed to load event") + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { return } diff --git a/internal/api/auth.go b/internal/api/auth.go new file mode 100644 index 0000000..f6673c3 --- /dev/null +++ b/internal/api/auth.go @@ -0,0 +1,557 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "net/http" + "net/mail" + "net/url" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/auth" + "github.com/alchemistkay/guestguard/internal/domain" + "github.com/alchemistkay/guestguard/internal/ratelimit" + "github.com/alchemistkay/guestguard/internal/storage" +) + +const refreshCookieName = "gg_refresh" + +type authHandler struct { + logger *slog.Logger + users *storage.UserRepo + verifications *storage.EmailVerificationRepo + resets *storage.PasswordResetRepo + refreshes *storage.RefreshTokenRepo + hasher *auth.PasswordHasher + signer *auth.JWTSigner + emails auth.EmailSender + lockout *auth.LockoutTracker + limiter *ratelimit.Limiter + + publicBaseURL string + emailVerificationTTL time.Duration + passwordResetTTL time.Duration + refreshTTL time.Duration + cookieDomain string + cookieSecure bool +} + +type authHandlerDeps struct { + Logger *slog.Logger + Users *storage.UserRepo + Verifications *storage.EmailVerificationRepo + Resets *storage.PasswordResetRepo + Refreshes *storage.RefreshTokenRepo + Hasher *auth.PasswordHasher + Signer *auth.JWTSigner + Emails auth.EmailSender + Lockout *auth.LockoutTracker + Limiter *ratelimit.Limiter + + PublicBaseURL string + EmailVerificationTTL time.Duration + PasswordResetTTL time.Duration + RefreshTTL time.Duration + CookieDomain string + CookieSecure bool +} + +func newAuthHandler(d authHandlerDeps) *authHandler { + return &authHandler{ + logger: d.Logger, + users: d.Users, + verifications: d.Verifications, + resets: d.Resets, + refreshes: d.Refreshes, + hasher: d.Hasher, + signer: d.Signer, + emails: d.Emails, + lockout: d.Lockout, + limiter: d.Limiter, + publicBaseURL: strings.TrimRight(d.PublicBaseURL, "/"), + emailVerificationTTL: d.EmailVerificationTTL, + passwordResetTTL: d.PasswordResetTTL, + refreshTTL: d.RefreshTTL, + cookieDomain: d.CookieDomain, + cookieSecure: d.CookieSecure, + } +} + +// --- request/response types --- + +type signupRequest struct { + Email string `json:"email"` + Name string `json:"name"` + Password string `json:"password"` + AcceptTerms bool `json:"accept_terms"` +} + +type loginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +type verifyEmailRequest struct { + Token string `json:"token"` +} + +type forgotPasswordRequest struct { + Email string `json:"email"` +} + +type resetPasswordRequest struct { + Token string `json:"token"` + NewPassword string `json:"new_password"` +} + +type authSuccess struct { + AccessToken string `json:"access_token"` + ExpiresAt time.Time `json:"expires_at"` + User *domain.User `json:"user"` +} + +// --- handlers --- + +// POST /auth/signup +func (h *authHandler) signup(w http.ResponseWriter, r *http.Request) { + var req signupRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + if _, err := mail.ParseAddress(req.Email); err != nil { + writeError(w, http.StatusBadRequest, "email is invalid") + return + } + if strings.TrimSpace(req.Name) == "" { + writeError(w, http.StatusBadRequest, "name is required") + return + } + hash, err := h.hasher.Hash(req.Password) + if err != nil { + if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + h.logger.Error("hash password", "err", err) + writeError(w, http.StatusInternalServerError, "failed to create user") + return + } + + u, err := h.users.Create(r.Context(), storage.CreateUserParams{ + Email: req.Email, + Name: req.Name, + PasswordHash: hash, + AcceptTerms: req.AcceptTerms, + }) + if err != nil { + if errors.Is(err, domain.ErrEmailTaken) { + // Don't leak which addresses are registered. Still return 201 and + // trigger a "if-you-already-have-an-account" email asynchronously + // (skipped for the stub). On real auth this should send a "you + // tried to sign up again, here's a reset link" email. + h.logger.Info("signup attempted with existing email", "email", req.Email) + writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"}) + return + } + h.logger.Error("create user", "err", err) + writeError(w, http.StatusInternalServerError, "failed to create user") + return + } + + if err := h.sendVerificationEmail(r.Context(), u); err != nil { + h.logger.Error("send verification email", "err", err, "user_id", u.ID) + // Don't fail the signup — user can request a resend. + } + writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"}) +} + +// POST /auth/login +func (h *authHandler) login(w http.ResponseWriter, r *http.Request) { + var req loginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + if req.Email == "" || req.Password == "" { + writeError(w, http.StatusBadRequest, "email and password required") + return + } + + // Per-(IP + email) sliding-window — 10 per 5 minutes per the plan. + if !h.checkRate(w, r, "login", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)), + 10, 5*time.Minute) { + return + } + + u, err := h.users.GetByEmail(r.Context(), req.Email) + if err != nil || u.PasswordHash == "" { + _, _ = h.lockout.RecordFailure(r.Context(), req.Email, nil) + writeError(w, http.StatusUnauthorized, "invalid email or password") + return + } + + // If the account is already locked, reject before doing a bcrypt compare. + locked, _ := h.lockout.IsLocked(r.Context(), u.ID) + if locked { + writeError(w, http.StatusForbidden, "account locked — reset your password to unlock") + return + } + + if err := h.hasher.Verify(u.PasswordHash, req.Password); err != nil { + locked, _ := h.lockout.RecordFailure(r.Context(), req.Email, &u.ID) + if locked { + writeError(w, http.StatusForbidden, "account locked — reset your password to unlock") + return + } + writeError(w, http.StatusUnauthorized, "invalid email or password") + return + } + if !u.EmailVerified { + writeError(w, http.StatusForbidden, "email not verified") + return + } + + h.lockout.ClearOnSuccess(r.Context(), req.Email) + if err := h.issueSession(w, r, u); err != nil { + h.logger.Error("issue session", "err", err, "user_id", u.ID) + writeError(w, http.StatusInternalServerError, "failed to start session") + return + } +} + +// checkRate consults the limiter (when one is configured) and writes a 429 +// response if the budget is exhausted. Returns false if the caller should +// stop handling the request. +func (h *authHandler) checkRate(w http.ResponseWriter, r *http.Request, name, key string, limit int, window time.Duration) bool { + if h.limiter == nil || key == "" { + return true + } + res, err := h.limiter.Allow(r.Context(), name, key, limit, window) + if err != nil { + h.logger.Warn("ratelimit error (failing open)", "rule", name, "err", err) + return true + } + if !res.Allowed { + retry := int(res.RetryAfter.Round(time.Second).Seconds()) + if retry < 1 { + retry = 1 + } + w.Header().Set("Retry-After", strconv.Itoa(retry)) + writeJSON(w, http.StatusTooManyRequests, map[string]any{ + "error": "rate limit exceeded", + "retry_after": retry, + }) + return false + } + return true +} + +// POST /auth/refresh +func (h *authHandler) refresh(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(refreshCookieName) + if err != nil || cookie.Value == "" { + writeError(w, http.StatusUnauthorized, "missing refresh token") + return + } + oldHash := auth.HashOpaque(cookie.Value) + + existing, err := h.refreshes.Get(r.Context(), oldHash) + if err != nil { + if errors.Is(err, domain.ErrAuthTokenNotFound) { + h.clearRefreshCookie(w) + writeError(w, http.StatusUnauthorized, "invalid refresh token") + return + } + h.logger.Error("lookup refresh", "err", err) + writeError(w, http.StatusInternalServerError, "refresh failed") + return + } + if existing.RevokedAt != nil { + // Replay of a revoked token. Revoke the family. + _ = h.refreshes.RevokeAllForUser(r.Context(), existing.UserID) + h.clearRefreshCookie(w) + writeError(w, http.StatusUnauthorized, "refresh token reused") + return + } + if time.Now().After(existing.ExpiresAt) { + h.clearRefreshCookie(w) + writeError(w, http.StatusUnauthorized, "refresh token expired") + return + } + + u, err := h.users.GetByID(r.Context(), existing.UserID) + if err != nil { + h.clearRefreshCookie(w) + writeError(w, http.StatusUnauthorized, "user not found") + return + } + + newRaw, newHash, err := auth.NewOpaqueToken() + if err != nil { + h.logger.Error("mint refresh", "err", err) + writeError(w, http.StatusInternalServerError, "refresh failed") + return + } + exp := time.Now().Add(h.refreshTTL) + if err := h.refreshes.Rotate(r.Context(), oldHash, storage.CreateRefreshTokenParams{ + Hash: newHash, + UserID: u.ID, + ExpiresAt: exp, + UserAgent: r.UserAgent(), + IPAddress: clientIP(r), + }); err != nil { + if errors.Is(err, domain.ErrRefreshTokenRevoked) { + h.clearRefreshCookie(w) + writeError(w, http.StatusUnauthorized, "refresh token reused") + return + } + h.logger.Error("rotate refresh", "err", err) + writeError(w, http.StatusInternalServerError, "refresh failed") + return + } + + access, accessExp, err := h.signer.Issue(u.ID, time.Now()) + if err != nil { + h.logger.Error("sign access", "err", err) + writeError(w, http.StatusInternalServerError, "refresh failed") + return + } + h.setRefreshCookie(w, newRaw, exp) + writeJSON(w, http.StatusOK, authSuccess{ + AccessToken: access, + ExpiresAt: accessExp, + User: u, + }) +} + +// POST /auth/logout +func (h *authHandler) logout(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(refreshCookieName) + if err == nil && cookie.Value != "" { + _ = h.refreshes.Revoke(r.Context(), auth.HashOpaque(cookie.Value)) + } + h.clearRefreshCookie(w) + w.WriteHeader(http.StatusNoContent) +} + +// POST /auth/verify-email +func (h *authHandler) verifyEmail(w http.ResponseWriter, r *http.Request) { + var req verifyEmailRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" { + writeError(w, http.StatusBadRequest, "token required") + return + } + uid, err := h.verifications.Consume(r.Context(), auth.HashOpaque(req.Token)) + if err != nil { + switch { + case errors.Is(err, domain.ErrAuthTokenNotFound): + writeError(w, http.StatusBadRequest, "invalid token") + case errors.Is(err, domain.ErrAuthTokenConsumed): + writeError(w, http.StatusBadRequest, "token already used") + case errors.Is(err, domain.ErrAuthTokenExpired): + writeError(w, http.StatusBadRequest, "token expired") + default: + h.logger.Error("consume verification", "err", err) + writeError(w, http.StatusInternalServerError, "verification failed") + } + return + } + if err := h.users.MarkEmailVerified(r.Context(), uid); err != nil { + h.logger.Error("mark verified", "err", err, "user_id", uid) + writeError(w, http.StatusInternalServerError, "verification failed") + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "verified"}) +} + +// POST /auth/forgot-password +func (h *authHandler) forgotPassword(w http.ResponseWriter, r *http.Request) { + var req forgotPasswordRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + if !h.checkRate(w, r, "forgot_password", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)), + 3, time.Hour) { + return + } + // Always respond 202 to avoid leaking whether the email exists. + defer func() { writeJSON(w, http.StatusAccepted, map[string]string{"status": "if_known_email_sent"}) }() + + u, err := h.users.GetByEmail(r.Context(), req.Email) + if err != nil { + return + } + raw, hash, err := auth.NewOpaqueToken() + if err != nil { + h.logger.Error("mint reset", "err", err) + return + } + exp := time.Now().Add(h.passwordResetTTL) + if err := h.resets.Create(r.Context(), u.ID, hash, exp); err != nil { + h.logger.Error("persist reset", "err", err) + return + } + link := h.publicBaseURL + "/reset-password/" + url.PathEscape(raw) + if err := h.emails.SendPasswordReset(r.Context(), u.Email, u.Name, link); err != nil { + h.logger.Error("send reset email", "err", err) + } +} + +// POST /auth/reset-password +func (h *authHandler) resetPassword(w http.ResponseWriter, r *http.Request) { + var req resetPasswordRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" { + writeError(w, http.StatusBadRequest, "token and new_password required") + return + } + newHash, err := h.hasher.Hash(req.NewPassword) + if err != nil { + if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + h.logger.Error("hash password", "err", err) + writeError(w, http.StatusInternalServerError, "reset failed") + return + } + uid, err := h.resets.Consume(r.Context(), auth.HashOpaque(req.Token)) + if err != nil { + switch { + case errors.Is(err, domain.ErrAuthTokenNotFound): + writeError(w, http.StatusBadRequest, "invalid token") + case errors.Is(err, domain.ErrAuthTokenConsumed): + writeError(w, http.StatusBadRequest, "token already used") + case errors.Is(err, domain.ErrAuthTokenExpired): + writeError(w, http.StatusBadRequest, "token expired") + default: + h.logger.Error("consume reset", "err", err) + writeError(w, http.StatusInternalServerError, "reset failed") + } + return + } + if err := h.users.UpdatePasswordHash(r.Context(), uid, newHash); err != nil { + h.logger.Error("update password", "err", err, "user_id", uid) + writeError(w, http.StatusInternalServerError, "reset failed") + return + } + // Invalidate all existing sessions. + _ = h.refreshes.RevokeAllForUser(r.Context(), uid) + // Resetting the password is the canonical "unlock" path for the + // account lockout that triggers after repeated bad-credential attempts. + if u, err := h.users.GetByID(r.Context(), uid); err == nil { + _ = h.lockout.ClearForUser(r.Context(), uid, u.Email) + } + writeJSON(w, http.StatusOK, map[string]string{"status": "password_reset"}) +} + +// --- helpers --- + +func (h *authHandler) sendVerificationEmail(ctx context.Context, u *domain.User) error { + raw, hash, err := auth.NewOpaqueToken() + if err != nil { + return err + } + if err := h.verifications.Create(ctx, u.ID, hash, time.Now().Add(h.emailVerificationTTL)); err != nil { + return err + } + link := h.publicBaseURL + "/verify-email?token=" + url.QueryEscape(raw) + return h.emails.SendVerification(ctx, u.Email, u.Name, link) +} + +func (h *authHandler) issueSession(w http.ResponseWriter, r *http.Request, u *domain.User) error { + access, accessExp, err := h.signer.Issue(u.ID, time.Now()) + if err != nil { + return err + } + raw, hash, err := auth.NewOpaqueToken() + if err != nil { + return err + } + refreshExp := time.Now().Add(h.refreshTTL) + if err := h.refreshes.Create(r.Context(), storage.CreateRefreshTokenParams{ + Hash: hash, + UserID: u.ID, + ExpiresAt: refreshExp, + UserAgent: r.UserAgent(), + IPAddress: clientIP(r), + }); err != nil { + return err + } + h.setRefreshCookie(w, raw, refreshExp) + writeJSON(w, http.StatusOK, authSuccess{ + AccessToken: access, + ExpiresAt: accessExp, + User: u, + }) + return nil +} + +func (h *authHandler) setRefreshCookie(w http.ResponseWriter, value string, expires time.Time) { + http.SetCookie(w, &http.Cookie{ + Name: refreshCookieName, + Value: value, + Path: "/auth", + Domain: h.cookieDomain, + Expires: expires, + MaxAge: int(time.Until(expires).Seconds()), + HttpOnly: true, + Secure: h.cookieSecure, + SameSite: http.SameSiteLaxMode, + }) +} + +func (h *authHandler) clearRefreshCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: refreshCookieName, + Value: "", + Path: "/auth", + Domain: h.cookieDomain, + MaxAge: -1, + HttpOnly: true, + Secure: h.cookieSecure, + SameSite: http.SameSiteLaxMode, + }) +} + +// --- requireAuth middleware --- + +type ctxKey int + +const userIDCtxKey ctxKey = iota + +func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) { + v, ok := ctx.Value(userIDCtxKey).(uuid.UUID) + return v, ok +} + +func requireAuth(signer *auth.JWTSigner) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := r.Header.Get("Authorization") + if !strings.HasPrefix(h, "Bearer ") { + writeError(w, http.StatusUnauthorized, "missing bearer token") + return + } + raw := strings.TrimSpace(strings.TrimPrefix(h, "Bearer ")) + claims, err := signer.Parse(raw) + if err != nil { + if errors.Is(err, auth.ErrExpiredJWT) { + writeError(w, http.StatusUnauthorized, "token expired") + return + } + writeError(w, http.StatusUnauthorized, "invalid token") + return + } + ctx := context.WithValue(r.Context(), userIDCtxKey, claims.UserID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/api/authz.go b/internal/api/authz.go new file mode 100644 index 0000000..36ee0e5 --- /dev/null +++ b/internal/api/authz.go @@ -0,0 +1,44 @@ +package api + +import ( + "errors" + "net/http" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/domain" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// hostFromContext returns the authed user's id, or writes 401 and returns +// false. Used by host-facing handlers as the first line in the function. +func hostFromContext(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { + uid, ok := UserIDFromContext(r.Context()) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthenticated") + return uuid.Nil, false + } + return uid, true +} + +// requireEventOwner fetches the event and confirms the authed user owns it. +// On mismatch (or missing event) it returns 404 — never 403 — so a cross- +// tenant probe cannot tell the difference between "event doesn't exist" and +// "exists but belongs to someone else". +func requireEventOwner( + w http.ResponseWriter, + r *http.Request, + events *storage.EventRepo, + eventID, hostID uuid.UUID, +) (*domain.Event, bool) { + ev, err := events.GetForHost(r.Context(), eventID, hostID) + if err != nil { + if errors.Is(err, domain.ErrEventNotFound) { + writeError(w, http.StatusNotFound, "event not found") + return nil, false + } + writeError(w, http.StatusInternalServerError, "failed to load event") + return nil, false + } + return ev, true +} diff --git a/internal/api/billing.go b/internal/api/billing.go new file mode 100644 index 0000000..b725dc9 --- /dev/null +++ b/internal/api/billing.go @@ -0,0 +1,206 @@ +package api + +import ( + "encoding/json" + "errors" + "log/slog" + "net/http" + "strings" + + "github.com/alchemistkay/guestguard/internal/billing" + "github.com/alchemistkay/guestguard/internal/storage" +) + +type billingHandler struct { + logger *slog.Logger + stripe *billing.Client + users *storage.UserRepo + subscriptions *storage.SubscriptionRepo + publicBaseURL string +} + +type checkoutSessionRequest struct { + Tier string `json:"tier"` +} + +type checkoutSessionResponse struct { + URL string `json:"url"` +} + +// POST /billing/checkout-session — returns the Stripe Checkout URL the +// frontend redirects the host to. Mints a Stripe customer on first use +// and persists it so repeat calls reuse the same customer. +func (h *billingHandler) checkoutSession(w http.ResponseWriter, r *http.Request) { + if !h.stripeEnabled(w) { + return + } + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + var req checkoutSessionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + tier := billing.Tier(strings.ToLower(req.Tier)) + if tier != billing.TierPro && tier != billing.TierBusiness { + writeError(w, http.StatusBadRequest, "tier must be 'pro' or 'business'") + return + } + price, err := h.stripe.PriceFor(tier) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "this tier is not configured yet — contact support") + return + } + + user, err := h.users.GetByID(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load user") + return + } + + existingCustomerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load billing record") + return + } + customerID, err := h.stripe.CreateOrGetCustomer(hostID.String(), user.Email, user.Name, existingCustomerID) + if err != nil { + h.logger.Error("stripe customer", "err", err) + writeError(w, http.StatusBadGateway, "stripe customer error") + return + } + if existingCustomerID == "" { + // First time — write a placeholder row so the customer id sticks. + if _, err := h.subscriptions.Upsert(r.Context(), storage.UpsertParams{ + UserID: hostID, + StripeCustomerID: customerID, + }); err != nil { + h.logger.Error("upsert sub placeholder", "err", err) + } + } + + base := strings.TrimRight(h.publicBaseURL, "/") + url, err := h.stripe.CreateCheckoutSession(billing.CheckoutSessionParams{ + CustomerID: customerID, + PriceID: price, + SuccessURL: base + "/dashboard?billing=success", + CancelURL: base + "/dashboard?billing=cancelled", + }) + if err != nil { + h.logger.Error("stripe checkout session", "err", err) + writeError(w, http.StatusBadGateway, "stripe checkout error") + return + } + writeJSON(w, http.StatusOK, checkoutSessionResponse{URL: url}) +} + +type portalSessionResponse struct { + URL string `json:"url"` +} + +// POST /billing/portal — returns the customer portal URL so the user +// can manage their payment method, view invoices, or cancel. 404 when +// the user has no Stripe customer yet (they're still on free). +func (h *billingHandler) portalSession(w http.ResponseWriter, r *http.Request) { + if !h.stripeEnabled(w) { + return + } + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + customerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load billing record") + return + } + if customerID == "" { + writeError(w, http.StatusNotFound, "no billing account yet — subscribe first") + return + } + url, err := h.stripe.CreatePortalSession(customerID, strings.TrimRight(h.publicBaseURL, "/")+"/dashboard") + if err != nil { + h.logger.Error("stripe portal", "err", err) + writeError(w, http.StatusBadGateway, "stripe portal error") + return + } + writeJSON(w, http.StatusOK, portalSessionResponse{URL: url}) +} + +type subscriptionStatusResponse struct { + Tier string `json:"tier"` + Status string `json:"status"` + CurrentPeriodEnd string `json:"current_period_end,omitempty"` + CancelAtPeriodEnd bool `json:"cancel_at_period_end"` + Limits struct { + EventsPerMonth int `json:"events_per_month"` + GuestsPerEvent int `json:"guests_per_event"` + } `json:"limits"` + Usage struct { + EventsThisMonth int `json:"events_this_month"` + } `json:"usage"` + PortalAvailable bool `json:"portal_available"` +} + +// GET /billing/status — returns the host's current tier + limits + +// usage. The frontend uses this to render the billing page and the +// 402-modal copy ("you used X of Y events this month"). +func (h *billingHandler) status(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + + tier := billing.TierFree + status := "active" + var periodEnd string + cancelAtPeriodEnd := false + portalAvailable := false + + if h.subscriptions != nil { + sub, err := h.subscriptions.GetActiveByUser(r.Context(), hostID) + switch { + case err == nil: + tier = billing.Tier(sub.Tier) + status = sub.Status + cancelAtPeriodEnd = sub.CancelAtPeriodEnd + if sub.CurrentPeriodEnd != nil { + periodEnd = sub.CurrentPeriodEnd.Format("2006-01-02T15:04:05Z") + } + portalAvailable = sub.StripeCustomerID != "" + case errors.Is(err, storage.ErrSubscriptionNotFound): + // Free tier — leave defaults. + default: + writeError(w, http.StatusInternalServerError, "failed to load subscription") + return + } + } + + limits := billing.LimitsFor(tier) + events, _ := h.subscriptions.CountEventsInCurrentMonth(r.Context(), hostID) + + resp := subscriptionStatusResponse{ + Tier: string(tier), + Status: status, + CurrentPeriodEnd: periodEnd, + CancelAtPeriodEnd: cancelAtPeriodEnd, + PortalAvailable: portalAvailable, + } + resp.Limits.EventsPerMonth = limits.EventsPerMonth + resp.Limits.GuestsPerEvent = limits.GuestsPerEvent + resp.Usage.EventsThisMonth = events + writeJSON(w, http.StatusOK, resp) +} + +// stripeEnabled returns true if the billing client is configured, else +// writes 503 and returns false. The /billing/status path skips this so +// the frontend can render a "free tier" page in dev environments. +func (h *billingHandler) stripeEnabled(w http.ResponseWriter) bool { + if h.stripe == nil || !h.stripe.Enabled() { + writeError(w, http.StatusServiceUnavailable, "billing is not configured on this instance") + return false + } + return true +} diff --git a/internal/api/billing_enforce.go b/internal/api/billing_enforce.go new file mode 100644 index 0000000..25a956e --- /dev/null +++ b/internal/api/billing_enforce.go @@ -0,0 +1,157 @@ +package api + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/billing" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// tierEnforcer wraps the SubscriptionRepo with policy decisions. Lives +// here (not in storage) because the policy is HTTP-shaped: we map the +// outcome to 402 + an upgrade URL. +type tierEnforcer struct { + subs *storage.SubscriptionRepo + publicBaseURL string +} + +func newTierEnforcer(subs *storage.SubscriptionRepo, publicBaseURL string) *tierEnforcer { + return &tierEnforcer{subs: subs, publicBaseURL: publicBaseURL} +} + +// currentTier returns the host's effective tier. ErrSubscriptionNotFound +// means "no granting subscription on file" → free. Other DB errors +// bubble up. +func (e *tierEnforcer) currentTier(ctx context.Context, hostID uuid.UUID) (billing.Tier, error) { + if e == nil || e.subs == nil { + return billing.TierFree, nil + } + sub, err := e.subs.GetActiveByUser(ctx, hostID) + if err != nil { + if errors.Is(err, storage.ErrSubscriptionNotFound) { + return billing.TierFree, nil + } + return "", err + } + if !billing.StatusGrantsAccess(sub.Status) { + return billing.TierFree, nil + } + tier := billing.Tier(sub.Tier) + if !tier.Valid() { + return billing.TierFree, nil + } + return tier, nil +} + +// allowEventCreate verifies the host's monthly event budget. Returns +// true when the request may proceed. On denial it writes a 402 with the +// upgrade hint and returns false. +func (e *tierEnforcer) allowEventCreate(w http.ResponseWriter, r *http.Request, hostID uuid.UUID) bool { + if e == nil || e.subs == nil { + return true + } + tier, err := e.currentTier(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to check plan") + return false + } + limit := billing.LimitsFor(tier).EventsPerMonth + if limit < 0 { + return true + } + used, err := e.subs.CountEventsInCurrentMonth(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to count events") + return false + } + if used >= limit { + e.writePaymentRequired(w, "events_per_month", tier, used, limit, + "You've reached your monthly event limit on the "+strings.ToUpper(string(tier))+" plan.") + return false + } + return true +} + +// allowGuestCreate verifies the per-event guest cap. Same shape as +// allowEventCreate. +func (e *tierEnforcer) allowGuestCreate(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID) bool { + if e == nil || e.subs == nil { + return true + } + tier, err := e.currentTier(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to check plan") + return false + } + limit := billing.LimitsFor(tier).GuestsPerEvent + if limit < 0 { + return true + } + used, err := e.subs.CountGuestsByEvent(r.Context(), eventID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to count guests") + return false + } + if used >= limit { + e.writePaymentRequired(w, "guests_per_event", tier, used, limit, + "This event has reached the guest limit on the "+strings.ToUpper(string(tier))+" plan.") + return false + } + return true +} + +// allowGuestImport is the CSV-import variant: check the cap against +// existing + incoming row count up-front, before we start the +// transaction. Dedup may shrink the actual insert count later — that's +// OK, we just stay on the safe side. +func (e *tierEnforcer) allowGuestImport(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID, incoming int) bool { + if e == nil || e.subs == nil || incoming == 0 { + return true + } + tier, err := e.currentTier(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to check plan") + return false + } + limit := billing.LimitsFor(tier).GuestsPerEvent + if limit < 0 { + return true + } + used, err := e.subs.CountGuestsByEvent(r.Context(), eventID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to count guests") + return false + } + if used+incoming > limit { + e.writePaymentRequired(w, "guests_per_event", tier, used, limit, + "This import would exceed the guest limit on the "+strings.ToUpper(string(tier))+" plan.") + return false + } + return true +} + +type paymentRequiredBody struct { + Error string `json:"error"` + Reason string `json:"reason"` + Tier string `json:"tier"` + Used int `json:"used"` + Limit int `json:"limit"` + UpgradeURL string `json:"upgrade_url"` +} + +func (e *tierEnforcer) writePaymentRequired(w http.ResponseWriter, reason string, tier billing.Tier, used, limit int, msg string) { + body := paymentRequiredBody{ + Error: msg, + Reason: reason, + Tier: string(tier), + Used: used, + Limit: limit, + UpgradeURL: strings.TrimRight(e.publicBaseURL, "/") + "/dashboard/billing", + } + writeJSON(w, http.StatusPaymentRequired, body) +} diff --git a/internal/api/csv_import.go b/internal/api/csv_import.go new file mode 100644 index 0000000..7933bb4 --- /dev/null +++ b/internal/api/csv_import.go @@ -0,0 +1,177 @@ +package api + +import ( + "errors" + "io" + "net/http" + + "github.com/alchemistkay/guestguard/internal/csvimport" + "github.com/alchemistkay/guestguard/internal/storage" +) + +const ( + // 1 MB cap on uploads. With ~200 bytes per row that's ~5,000 guests — + // matches the row cap in csvimport.DefaultMaxRows. + csvMaxBytes = 1 << 20 +) + +type csvImportHandler struct { + guests *storage.GuestRepo + events *storage.EventRepo + enforcer *tierEnforcer +} + +type importResponse struct { + Added int `json:"added"` + Skipped int `json:"skipped"` + SkippedEmails []string `json:"skipped_emails,omitempty"` + Errors []csvimport.RowError `json:"errors,omitempty"` + TotalCount int `json:"total_count"` +} + +type previewResponse struct { + Rows []csvimport.Row `json:"rows"` + Errors []csvimport.RowError `json:"errors,omitempty"` + TotalCount int `json:"total_count"` +} + +// POST /events/{id}/guests/import/preview — parse + validate but don't write. +// Used by the frontend to show a "is this what you meant?" table before commit. +func (h *csvImportHandler) preview(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + + body, ok := readCSVUpload(w, r) + if !ok { + return + } + defer body.Close() + + res, err := csvimport.Parse(body, csvimport.Options{}) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + writeJSON(w, http.StatusOK, previewResponse{ + Rows: res.Rows, + Errors: res.Errors, + TotalCount: res.TotalCount, + }) +} + +// POST /events/{id}/guests/import — parse, validate, and commit valid rows +// in a single transaction. Rows with row-level errors are reported back +// but don't prevent the rest from importing. +func (h *csvImportHandler) commit(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + + body, ok := readCSVUpload(w, r) + if !ok { + return + } + defer body.Close() + + parsed, err := csvimport.Parse(body, csvimport.Options{}) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + + // Plan enforcement: prevent an import that, even if perfectly + // dedup-free, would exceed the per-event guest cap. We check upfront + // (current count + parsed-row count) and reject with 402 if it'd + // overflow. False positives on dedup-heavy CSVs are acceptable — + // host can dedupe and re-upload. + if !h.enforcer.allowGuestImport(w, r, hostID, eventID, len(parsed.Rows)) { + return + } + + rows := make([]storage.BulkImportRow, 0, len(parsed.Rows)) + for _, r := range parsed.Rows { + rows = append(rows, storage.BulkImportRow{ + Name: r.Name, + Email: r.Email, + Phone: r.Phone, + PlusOnes: r.PlusOnes, + }) + } + + res, err := h.guests.BulkImportGuests(r.Context(), eventID, rows) + if err != nil { + writeError(w, http.StatusInternalServerError, "import failed") + return + } + writeJSON(w, http.StatusOK, importResponse{ + Added: res.Added, + Skipped: res.Skipped, + SkippedEmails: res.SkippedEmails, + Errors: parsed.Errors, + TotalCount: parsed.TotalCount, + }) +} + +// GET /events/{id}/guests/import/template — download a sample CSV. Auth is +// applied at the route level; ownership is verified so an attacker can't +// probe the existence of an event by hitting this endpoint. +func (h *csvImportHandler) template(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + w.Header().Set("Content-Type", "text/csv; charset=utf-8") + w.Header().Set("Content-Disposition", `attachment; filename="guestguard-import-template.csv"`) + _, _ = w.Write([]byte(csvimport.TemplateCSV)) +} + +// readCSVUpload returns the multipart file body (capped at csvMaxBytes) or +// writes an error and returns (nil, false). Accepted shapes: +// +// - multipart/form-data with a "file" field (the drag-drop UI uses this) +// - any other Content-Type — the raw body is treated as the CSV (curl-friendly) +func readCSVUpload(w http.ResponseWriter, r *http.Request) (io.ReadCloser, bool) { + r.Body = http.MaxBytesReader(w, r.Body, csvMaxBytes) + if err := r.ParseMultipartForm(csvMaxBytes); err == nil && r.MultipartForm != nil { + files := r.MultipartForm.File["file"] + if len(files) == 0 { + writeError(w, http.StatusBadRequest, `form field "file" is required`) + return nil, false + } + f, err := files[0].Open() + if err != nil { + writeError(w, http.StatusBadRequest, "cannot read uploaded file") + return nil, false + } + return f, true + } else if err != nil && !errors.Is(err, http.ErrNotMultipart) { + writeError(w, http.StatusBadRequest, "invalid multipart body") + return nil, false + } + // Fall through: raw body as CSV. + return r.Body, true +} diff --git a/internal/api/events.go b/internal/api/events.go index 75af6a8..6121a7f 100644 --- a/internal/api/events.go +++ b/internal/api/events.go @@ -15,11 +15,11 @@ import ( ) type eventHandler struct { - repo *storage.EventRepo + repo *storage.EventRepo + enforcer *tierEnforcer } type createEventRequest struct { - HostID string `json:"host_id"` Name string `json:"name"` Slug string `json:"slug"` EventDate time.Time `json:"event_date"` @@ -32,6 +32,14 @@ type createEventRequest struct { var slugRe = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`) func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + if !h.enforcer.allowEventCreate(w, r, hostID) { + return + } + var req createEventRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "invalid json") @@ -51,12 +59,6 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) { return } - hostID, err := uuid.Parse(req.HostID) - if err != nil { - writeError(w, http.StatusBadRequest, "host_id must be a valid uuid") - return - } - status := domain.EventStatus(req.Status) if status == "" { status = domain.EventStatusDraft @@ -89,37 +91,31 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) { } func (h *eventHandler) get(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } id, ok := parseIDParam(w, r, "id") if !ok { return } - ev, err := h.repo.Get(r.Context(), id) - if err != nil { - if errors.Is(err, domain.ErrEventNotFound) { - writeError(w, http.StatusNotFound, "event not found") - return - } - writeError(w, http.StatusInternalServerError, "failed to load event") + ev, ok := requireEventOwner(w, r, h.repo, id, hostID) + if !ok { return } writeJSON(w, http.StatusOK, ev) } func (h *eventHandler) list(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + q := r.URL.Query() limit := atoiOr(q.Get("limit"), 50) offset := atoiOr(q.Get("offset"), 0) - var hostID uuid.UUID - if v := q.Get("host_id"); v != "" { - parsed, err := uuid.Parse(v) - if err != nil { - writeError(w, http.StatusBadRequest, "host_id must be a valid uuid") - return - } - hostID = parsed - } - events, err := h.repo.List(r.Context(), hostID, limit, offset) if err != nil { writeError(w, http.StatusInternalServerError, "failed to list events") @@ -146,6 +142,10 @@ type updateEventRequest struct { } func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } id, ok := parseIDParam(w, r, "id") if !ok { return @@ -180,7 +180,7 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) { params.Status = &s } - ev, err := h.repo.Update(r.Context(), id, params) + ev, err := h.repo.Update(r.Context(), id, hostID, params) if err != nil { switch { case errors.Is(err, domain.ErrEventNotFound): @@ -196,11 +196,15 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) { } func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } id, ok := parseIDParam(w, r, "id") if !ok { return } - if err := h.repo.Delete(r.Context(), id); err != nil { + if err := h.repo.Delete(r.Context(), id, hostID); err != nil { if errors.Is(err, domain.ErrEventNotFound) { writeError(w, http.StatusNotFound, "event not found") return @@ -212,7 +216,14 @@ func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) { } func parseIDParam(w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, bool) { - raw := r.PathValue(name) + return parseRawUUID(w, name, r.PathValue(name)) +} + +func parseRawUUID(w http.ResponseWriter, name, raw string) (uuid.UUID, bool) { + if raw == "" { + writeError(w, http.StatusBadRequest, name+" is required") + return uuid.Nil, false + } id, err := uuid.Parse(raw) if err != nil { writeError(w, http.StatusBadRequest, name+" must be a valid uuid") diff --git a/internal/api/guests.go b/internal/api/guests.go index 0a24095..3665bc6 100644 --- a/internal/api/guests.go +++ b/internal/api/guests.go @@ -10,8 +10,9 @@ import ( ) type guestHandler struct { - guests *storage.GuestRepo - events *storage.EventRepo + guests *storage.GuestRepo + events *storage.EventRepo + enforcer *tierEnforcer } type createGuestRequest struct { @@ -24,16 +25,18 @@ type createGuestRequest struct { } func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } eventID, ok := parseIDParam(w, r, "id") if !ok { return } - if _, err := h.events.Get(r.Context(), eventID); err != nil { - if errors.Is(err, domain.ErrEventNotFound) { - writeError(w, http.StatusNotFound, "event not found") - return - } - writeError(w, http.StatusInternalServerError, "failed to load event") + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + if !h.enforcer.allowGuestCreate(w, r, hostID, eventID) { return } @@ -67,11 +70,106 @@ func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusCreated, g) } -func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) { +type updateGuestRequest struct { + Name *string `json:"name"` + Email *string `json:"email"` + Phone *string `json:"phone"` + PlusOnes *int `json:"plus_ones"` +} + +// PATCH /events/{id}/guests/{guest_id} — patch a guest's contact info. +// Fields omitted from the body are left untouched. Empty strings for +// email/phone clear those columns to NULL. +func (h *guestHandler) update(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } eventID, ok := parseIDParam(w, r, "id") if !ok { return } + guestID, ok := parseIDParam(w, r, "guest_id") + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + + var req updateGuestRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + if req.Name != nil && *req.Name == "" { + writeError(w, http.StatusBadRequest, "name cannot be empty") + return + } + if req.PlusOnes != nil && *req.PlusOnes < 0 { + writeError(w, http.StatusBadRequest, "plus_ones must be >= 0") + return + } + + g, err := h.guests.Update(r.Context(), eventID, guestID, storage.UpdateGuestParams{ + Name: req.Name, + Email: req.Email, + Phone: req.Phone, + PlusOnes: req.PlusOnes, + }) + if err != nil { + if errors.Is(err, domain.ErrGuestNotFound) { + writeError(w, http.StatusNotFound, "guest not found") + return + } + writeError(w, http.StatusInternalServerError, "failed to update guest") + return + } + writeJSON(w, http.StatusOK, g) +} + +// DELETE /events/{id}/guests/{guest_id} — remove a guest from an event. +// Cascade-deletes their token, rsvp, access logs, notifications. +func (h *guestHandler) delete(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + guestID, ok := parseIDParam(w, r, "guest_id") + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + if err := h.guests.Delete(r.Context(), eventID, guestID); err != nil { + if errors.Is(err, domain.ErrGuestNotFound) { + writeError(w, http.StatusNotFound, "guest not found") + return + } + writeError(w, http.StatusInternalServerError, "failed to delete guest") + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + q := r.URL.Query() limit := atoiOr(q.Get("limit"), 100) offset := atoiOr(q.Get("offset"), 0) diff --git a/internal/api/me.go b/internal/api/me.go new file mode 100644 index 0000000..4b1e517 --- /dev/null +++ b/internal/api/me.go @@ -0,0 +1,32 @@ +package api + +import ( + "net/http" + + "github.com/alchemistkay/guestguard/internal/domain" + "github.com/alchemistkay/guestguard/internal/storage" +) + +type meHandler struct { + users *storage.UserRepo +} + +// GET /me — returns the authenticated user. Used by the frontend to bootstrap +// after a page reload (with a fresh access token from /auth/refresh). +func (h *meHandler) get(w http.ResponseWriter, r *http.Request) { + uid, ok := UserIDFromContext(r.Context()) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthenticated") + return + } + u, err := h.users.GetByID(r.Context(), uid) + if err != nil { + if err == domain.ErrUserNotFound { + writeError(w, http.StatusUnauthorized, "user not found") + return + } + writeError(w, http.StatusInternalServerError, "failed to load user") + return + } + writeJSON(w, http.StatusOK, u) +} diff --git a/internal/api/privacy.go b/internal/api/privacy.go new file mode 100644 index 0000000..a12592e --- /dev/null +++ b/internal/api/privacy.go @@ -0,0 +1,255 @@ +package api + +import ( + "context" + "errors" + "log/slog" + "net/http" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/domain" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// privacyHandler holds the GDPR-style "your data, your choice" endpoints: +// data export, account deletion, and terms-acceptance recording. +type privacyHandler struct { + logger *slog.Logger + users *storage.UserRepo + events *storage.EventRepo + guests *storage.GuestRepo + tokens *storage.TokenRepo + rsvps *storage.RSVPRepo + access *storage.AccessLogRepo + notifs *storage.DB // raw pool access for the export queries + refresh *storage.RefreshTokenRepo +} + +// DataExport is the shape of the JSON the host downloads from +// GET /me/data-export. We don't paginate or stream — for the scale +// GuestGuard hosts have, a single response is reasonable. If a host +// ever has 100k+ access logs we'll switch to async + email-a-link. +type DataExport struct { + ExportedAt time.Time `json:"exported_at"` + Format string `json:"format"` + User *domain.User `json:"user"` + Events []*domain.Event `json:"events"` + Guests []*domain.Guest `json:"guests"` + Tokens []exportedToken `json:"tokens"` + RSVPs []exportedRSVP `json:"rsvps"` + AccessLogs []exportedAccess `json:"access_logs"` + Notifs []exportedNotif `json:"notifications"` +} + +type exportedToken struct { + ID uuid.UUID `json:"id"` + GuestID uuid.UUID `json:"guest_id"` + ExpiresAt time.Time `json:"expires_at"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` +} +type exportedRSVP struct { + ID uuid.UUID `json:"id"` + GuestID uuid.UUID `json:"guest_id"` + Response string `json:"response"` + PlusOnes int `json:"plus_ones"` + SubmittedAt time.Time `json:"submitted_at"` +} +type exportedAccess struct { + ID uuid.UUID `json:"id"` + GuestID uuid.UUID `json:"guest_id"` + RiskScore *int `json:"risk_score,omitempty"` + Flagged bool `json:"flagged"` + CreatedAt time.Time `json:"created_at"` +} +type exportedNotif struct { + ID uuid.UUID `json:"id"` + GuestID uuid.UUID `json:"guest_id"` + Channel string `json:"channel"` + Type string `json:"type"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` +} + +// GET /me/data-export — returns every record the system holds about the +// authenticated user. The Content-Disposition header makes browsers +// offer a download rather than rendering inline. +func (h *privacyHandler) dataExport(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + + user, err := h.users.GetByID(r.Context(), hostID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load user") + return + } + + export := DataExport{ + ExportedAt: time.Now().UTC(), + Format: "guestguard.v1", + User: user, + } + + // Events the user hosts. + events, err := h.events.List(r.Context(), hostID, 1000, 0) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load events") + return + } + export.Events = events + + // For each event, pull guests + tokens + rsvps + access_logs + notifications. + for _, ev := range events { + guests, err := h.guests.ListByEvent(r.Context(), ev.ID, 5000, 0) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load guests") + return + } + export.Guests = append(export.Guests, guests...) + + for _, g := range guests { + // Token (at most one per guest, but query as a list for symmetry). + if err := h.appendTokens(r.Context(), g.ID, &export); err != nil { + h.logger.Warn("export: tokens", "err", err) + } + if err := h.appendRSVPs(r.Context(), g.ID, &export); err != nil { + h.logger.Warn("export: rsvps", "err", err) + } + if err := h.appendAccess(r.Context(), g.ID, &export); err != nil { + h.logger.Warn("export: access", "err", err) + } + if err := h.appendNotifs(r.Context(), g.ID, &export); err != nil { + h.logger.Warn("export: notifs", "err", err) + } + } + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Content-Disposition", `attachment; filename="guestguard-data-export.json"`) + writeJSON(w, http.StatusOK, export) +} + +func (h *privacyHandler) appendTokens(ctx context.Context, guestID uuid.UUID, out *DataExport) error { + rows, err := h.notifs.Pool.Query(ctx, ` + SELECT id, guest_id, expires_at, status, created_at + FROM tokens WHERE guest_id = $1 + `, guestID) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var t exportedToken + if err := rows.Scan(&t.ID, &t.GuestID, &t.ExpiresAt, &t.Status, &t.CreatedAt); err != nil { + return err + } + out.Tokens = append(out.Tokens, t) + } + return rows.Err() +} +func (h *privacyHandler) appendRSVPs(ctx context.Context, guestID uuid.UUID, out *DataExport) error { + rows, err := h.notifs.Pool.Query(ctx, ` + SELECT id, guest_id, response::text, plus_ones, submitted_at + FROM rsvps WHERE guest_id = $1 + `, guestID) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var r exportedRSVP + if err := rows.Scan(&r.ID, &r.GuestID, &r.Response, &r.PlusOnes, &r.SubmittedAt); err != nil { + return err + } + out.RSVPs = append(out.RSVPs, r) + } + return rows.Err() +} +func (h *privacyHandler) appendAccess(ctx context.Context, guestID uuid.UUID, out *DataExport) error { + rows, err := h.notifs.Pool.Query(ctx, ` + SELECT id, guest_id, risk_score, flagged, created_at + FROM access_logs WHERE guest_id = $1 + `, guestID) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var a exportedAccess + var rs *int + if err := rows.Scan(&a.ID, &a.GuestID, &rs, &a.Flagged, &a.CreatedAt); err != nil { + return err + } + a.RiskScore = rs + out.AccessLogs = append(out.AccessLogs, a) + } + return rows.Err() +} +func (h *privacyHandler) appendNotifs(ctx context.Context, guestID uuid.UUID, out *DataExport) error { + rows, err := h.notifs.Pool.Query(ctx, ` + SELECT id, guest_id, channel::text, type::text, status::text, created_at + FROM notifications WHERE guest_id = $1 + `, guestID) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var n exportedNotif + if err := rows.Scan(&n.ID, &n.GuestID, &n.Channel, &n.Type, &n.Status, &n.CreatedAt); err != nil { + return err + } + out.Notifs = append(out.Notifs, n) + } + return rows.Err() +} + +// DELETE /me — soft-deletes the host's account. All sessions are +// revoked immediately. A hard delete happens via a separate cron 30 +// days later (TBD ops work). The user is logged out from all devices +// as a side effect of revoking the refresh tokens. +func (h *privacyHandler) deleteMe(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + if err := h.users.SoftDelete(r.Context(), hostID); err != nil { + if errors.Is(err, domain.ErrUserNotFound) { + writeError(w, http.StatusNotFound, "user not found") + return + } + writeError(w, http.StatusInternalServerError, "failed to delete account") + return + } + // Best-effort: revoke refresh tokens so other sessions log out too. + // Failure here is logged but doesn't roll back the soft-delete — the + // access tokens (JWT) will still expire on their own ~15 minute TTL. + if err := h.refresh.RevokeAllForUser(r.Context(), hostID); err != nil { + h.logger.Warn("delete-me: revoke refresh tokens", "err", err, "user_id", hostID) + } + w.WriteHeader(http.StatusNoContent) +} + +// POST /me/accept-terms — records that the authenticated user accepts +// the current ToS + privacy policy. Idempotent. Used by both the +// onboarding gate (existing accounts created before T&C were enforced) +// and any future "we updated our terms" re-acceptance flow. +func (h *privacyHandler) acceptTerms(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + if err := h.users.AcceptTerms(r.Context(), hostID); err != nil { + if errors.Is(err, domain.ErrUserNotFound) { + writeError(w, http.StatusNotFound, "user not found") + return + } + writeError(w, http.StatusInternalServerError, "failed to record acceptance") + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "accepted"}) +} diff --git a/internal/api/ratelimit_keys.go b/internal/api/ratelimit_keys.go new file mode 100644 index 0000000..991760f --- /dev/null +++ b/internal/api/ratelimit_keys.go @@ -0,0 +1,37 @@ +package api + +import ( + "net/http" +) + +// ipKey is the rate-limit key for endpoints scoped by source IP only +// (e.g. POST /auth/signup). XFF/X-Real-IP are honoured because in the +// homelab the API sits behind Traefik. +func ipKey(r *http.Request) string { + return clientIP(r) +} + +// pathKey returns a path-parameter as the rate-limit key — used for the +// token-scoped endpoints so an attacker brute-forcing a single token is +// limited regardless of the IPs they rotate through. +func pathKey(name string) KeyFunc { + return func(r *http.Request) string { + return r.PathValue(name) + } +} + +// userIDKey extracts the authenticated user id from the request context. +// Returns "" when the route isn't behind requireAuth, in which case the +// middleware bypasses (fail-open) — the route's own auth layer handles +// rejection. +func userIDKey(r *http.Request) string { + uid, ok := UserIDFromContext(r.Context()) + if !ok { + return "" + } + return uid.String() +} + +// KeyFunc mirrors ratelimit.KeyFunc so call sites don't have to import the +// inner package. +type KeyFunc = func(r *http.Request) string diff --git a/internal/api/server.go b/internal/api/server.go index 3df6634..6b37e41 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -5,63 +5,167 @@ import ( "net/http" "time" + "github.com/redis/go-redis/v9" + "github.com/alchemistkay/guestguard/internal/auth" + "github.com/alchemistkay/guestguard/internal/billing" + "github.com/alchemistkay/guestguard/internal/notification" + "github.com/alchemistkay/guestguard/internal/ratelimit" "github.com/alchemistkay/guestguard/internal/storage" ) type Server struct { - logger *slog.Logger - db *storage.DB - hub *Hub - users *userHandler - events *eventHandler - guests *guestHandler - tokens *tokenHandler - rsvps *rsvpHandler - activity *activityHandler - ws *wsHandler - health *healthHandler + logger *slog.Logger + db *storage.DB + hub *Hub + authH *authHandler + me *meHandler + events *eventHandler + guests *guestHandler + tokens *tokenHandler + rsvps *rsvpHandler + activity *activityHandler + ws *wsHandler + wsTicket *wsTicketHandler + health *healthHandler + signer *auth.JWTSigner + limiter *ratelimit.Limiter + unsub *unsubscribeHandler + webhooks *webhookHandler + csv *csvImportHandler + billing *billingHandler + stripeWH *stripeWebhookHandler + privacy *privacyHandler } type ServerDeps struct { Logger *slog.Logger DB *storage.DB Hub *Hub - AccessPublisher accessPublisher - RSVPPublisher rsvpPublisher - FraudScorer fraudScorer - TokenTTL time.Duration + AccessPublisher accessPublisher + RSVPPublisher rsvpPublisher + InvitationPublisher invitationPublisher + FraudScorer fraudScorer + TokenTTL time.Duration + + // Auth + JWTSecret string + JWTIssuer string + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + EmailVerificationTTL time.Duration + PasswordResetTTL time.Duration + PublicBaseURL string + RefreshCookieDomain string + RefreshCookieSecure bool + EmailSender auth.EmailSender + WSTicketTTL time.Duration + + // Rate limiting / abuse controls + Redis *redis.Client + LoginLockoutMax int // failed attempts before account lockout (default 5) + LoginFailWindow time.Duration // counter TTL (default 15 min) + + // Notifications / unsubscribe + NotificationRepo *notification.Repo + SuppressionRepo *notification.SuppressionRepo + UnsubscribeSigner *notification.UnsubscribeSigner + + // Billing (Block F). Nil StripeClient leaves billing disabled — the + // system still boots and runs, all users sit on the free tier with + // its limits enforced; /billing/* returns 503. + StripeClient *billing.Client } -func NewServer(deps ServerDeps) *Server { +func NewServer(deps ServerDeps) (*Server, error) { eventRepo := storage.NewEventRepo(deps.DB) guestRepo := storage.NewGuestRepo(deps.DB) tokenRepo := storage.NewTokenRepo(deps.DB) rsvpRepo := storage.NewRSVPRepo(deps.DB) accessRepo := storage.NewAccessLogRepo(deps.DB) userRepo := storage.NewUserRepo(deps.DB) + verifRepo := storage.NewEmailVerificationRepo(deps.DB) + resetRepo := storage.NewPasswordResetRepo(deps.DB) + refreshRepo := storage.NewRefreshTokenRepo(deps.DB) + subRepo := storage.NewSubscriptionRepo(deps.DB) + enforcer := newTierEnforcer(subRepo, deps.PublicBaseURL) + + signer, err := auth.NewJWTSigner(deps.JWTSecret, deps.AccessTokenTTL, deps.JWTIssuer) + if err != nil { + return nil, err + } + hasher := auth.NewPasswordHasher() + + emails := deps.EmailSender + if emails == nil { + emails = auth.LogEmailSender{Logger: deps.Logger} + } hub := deps.Hub if hub == nil { hub = NewHub(deps.Logger) } + wsTicketTTL := deps.WSTicketTTL + if wsTicketTTL <= 0 { + wsTicketTTL = 60 * time.Second + } + wsTickets := newWSTicketStore(wsTicketTTL) + + var limiter *ratelimit.Limiter + var lockout *auth.LockoutTracker + if deps.Redis != nil { + limiter = ratelimit.New(deps.Redis, "gg:rl") + lockoutMax := deps.LoginLockoutMax + if lockoutMax <= 0 { + lockoutMax = 5 + } + failWindow := deps.LoginFailWindow + if failWindow <= 0 { + failWindow = 15 * time.Minute + } + lockout = auth.NewLockoutTracker(deps.Redis, lockoutMax, failWindow) + } + + authH := newAuthHandler(authHandlerDeps{ + Logger: deps.Logger, + Users: userRepo, + Verifications: verifRepo, + Resets: resetRepo, + Refreshes: refreshRepo, + Hasher: hasher, + Signer: signer, + Emails: emails, + Lockout: lockout, + Limiter: limiter, + PublicBaseURL: deps.PublicBaseURL, + EmailVerificationTTL: deps.EmailVerificationTTL, + PasswordResetTTL: deps.PasswordResetTTL, + RefreshTTL: deps.RefreshTokenTTL, + CookieDomain: deps.RefreshCookieDomain, + CookieSecure: deps.RefreshCookieSecure, + }) + return &Server{ logger: deps.Logger, db: deps.DB, hub: hub, - users: &userHandler{repo: userRepo}, - events: &eventHandler{repo: eventRepo}, - guests: &guestHandler{guests: guestRepo, events: eventRepo}, + authH: authH, + me: &meHandler{users: userRepo}, + events: &eventHandler{repo: eventRepo, enforcer: enforcer}, + guests: &guestHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer}, tokens: &tokenHandler{ - logger: deps.Logger, - guests: guestRepo, - tokens: tokenRepo, - events: eventRepo, - accessLogs: accessRepo, - gen: auth.NewGenerator(), - ttl: deps.TokenTTL, - pub: deps.AccessPublisher, + logger: deps.Logger, + guests: guestRepo, + tokens: tokenRepo, + events: eventRepo, + users: userRepo, + accessLogs: accessRepo, + gen: auth.NewGenerator(), + ttl: deps.TokenTTL, + pub: deps.AccessPublisher, + invitations: deps.InvitationPublisher, + publicBaseURL: deps.PublicBaseURL, }, rsvps: &rsvpHandler{ logger: deps.Logger, @@ -78,9 +182,46 @@ func NewServer(deps ServerDeps) *Server { rsvps: rsvpRepo, accessLogs: accessRepo, }, - ws: &wsHandler{logger: deps.Logger, hub: hub}, - health: &healthHandler{pool: deps.DB.Pool}, - } + ws: &wsHandler{logger: deps.Logger, hub: hub, tickets: wsTickets}, + wsTicket: &wsTicketHandler{tickets: wsTickets, events: eventRepo}, + health: &healthHandler{pool: deps.DB.Pool}, + signer: signer, + limiter: limiter, + unsub: &unsubscribeHandler{ + logger: deps.Logger, + signer: deps.UnsubscribeSigner, + suppress: deps.SuppressionRepo, + }, + webhooks: &webhookHandler{ + logger: deps.Logger, + notifs: deps.NotificationRepo, + suppress: deps.SuppressionRepo, + }, + csv: &csvImportHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer}, + billing: &billingHandler{ + logger: deps.Logger, + stripe: deps.StripeClient, + users: userRepo, + subscriptions: subRepo, + publicBaseURL: deps.PublicBaseURL, + }, + stripeWH: &stripeWebhookHandler{ + logger: deps.Logger, + stripe: deps.StripeClient, + subs: subRepo, + }, + privacy: &privacyHandler{ + logger: deps.Logger, + users: userRepo, + events: eventRepo, + guests: guestRepo, + tokens: tokenRepo, + rsvps: rsvpRepo, + access: accessRepo, + notifs: deps.DB, + refresh: refreshRepo, + }, + }, nil } func (s *Server) Hub() *Hub { return s.hub } @@ -91,25 +232,104 @@ func (s *Server) Handler() http.Handler { mux.HandleFunc("GET /health", s.health.live) mux.HandleFunc("GET /health/ready", s.health.ready) - mux.HandleFunc("POST /users", s.users.upsert) + // Per-route rate limiters (no-op when Redis isn't wired). + authed := requireAuth(s.signer) + rl := func(name string, limit int, window time.Duration, keyFn KeyFunc, h http.Handler) http.Handler { + if s.limiter == nil { + return h + } + return s.limiter.Middleware( + ratelimit.Rule{Name: name, Limit: limit, Window: window}, + keyFn, + s.logger, + )(h) + } - mux.HandleFunc("POST /events", s.events.create) - mux.HandleFunc("GET /events", s.events.list) - mux.HandleFunc("GET /events/{id}", s.events.get) - mux.HandleFunc("PATCH /events/{id}", s.events.update) - mux.HandleFunc("DELETE /events/{id}", s.events.delete) + // Anonymous auth endpoints — POST /auth/login + /auth/forgot-password + // rate-limit inside the handler (key includes the email body field). + mux.Handle("POST /auth/signup", + rl("auth_signup", 5, time.Hour, ipKey, http.HandlerFunc(s.authH.signup))) + mux.HandleFunc("POST /auth/login", s.authH.login) + mux.HandleFunc("POST /auth/refresh", s.authH.refresh) + mux.HandleFunc("POST /auth/logout", s.authH.logout) + mux.HandleFunc("POST /auth/verify-email", s.authH.verifyEmail) + mux.HandleFunc("POST /auth/forgot-password", s.authH.forgotPassword) + mux.HandleFunc("POST /auth/reset-password", s.authH.resetPassword) - mux.HandleFunc("POST /events/{id}/guests", s.guests.create) - mux.HandleFunc("GET /events/{id}/guests", s.guests.list) + mux.Handle("GET /me", authed(http.HandlerFunc(s.me.get))) + mux.Handle("POST /auth/ws-ticket", authed(http.HandlerFunc(s.wsTicket.issue))) - mux.HandleFunc("GET /events/{id}/activity", s.activity.list) + // Privacy / GDPR-style endpoints — host can export their data, + // delete their account, and record terms acceptance from the + // onboarding gate. + mux.Handle("GET /me/data-export", authed(http.HandlerFunc(s.privacy.dataExport))) + mux.Handle("DELETE /me", authed(http.HandlerFunc(s.privacy.deleteMe))) + mux.Handle("POST /me/accept-terms", authed(http.HandlerFunc(s.privacy.acceptTerms))) - mux.HandleFunc("POST /events/{id}/guests/{guest_id}/tokens", s.tokens.issue) - mux.HandleFunc("GET /access/{token}", s.tokens.access) - mux.HandleFunc("POST /rsvp/{token}", s.rsvps.submit) + // Host-facing event/guest/token writes are limited by user_id. + mux.Handle("POST /events", + authed(rl("events_create", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.events.create)))) + mux.Handle("GET /events", authed(http.HandlerFunc(s.events.list))) + mux.Handle("GET /events/{id}", authed(http.HandlerFunc(s.events.get))) + mux.Handle("PATCH /events/{id}", authed(http.HandlerFunc(s.events.update))) + mux.Handle("DELETE /events/{id}", authed(http.HandlerFunc(s.events.delete))) + mux.Handle("POST /events/{id}/guests", + authed(rl("guests_create", 1000, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.create)))) + mux.Handle("GET /events/{id}/guests", authed(http.HandlerFunc(s.guests.list))) + mux.Handle("PATCH /events/{id}/guests/{guest_id}", + authed(rl("guests_update", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.update)))) + mux.Handle("DELETE /events/{id}/guests/{guest_id}", + authed(rl("guests_delete", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.delete)))) + + // CSV import (Block E). Preview is cheap (no DB writes), so we keep + // its budget separate from commit's daily-row-add limit. + mux.Handle("POST /events/{id}/guests/import/preview", + authed(rl("guests_import_preview", 30, time.Hour, userIDKey, http.HandlerFunc(s.csv.preview)))) + mux.Handle("POST /events/{id}/guests/import", + authed(rl("guests_import_commit", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.csv.commit)))) + mux.Handle("GET /events/{id}/guests/import/template", authed(http.HandlerFunc(s.csv.template))) + + mux.Handle("GET /events/{id}/activity", authed(http.HandlerFunc(s.activity.list))) + + mux.Handle("POST /events/{id}/guests/{guest_id}/tokens", + authed(rl("tokens_issue", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.issue)))) + mux.Handle("POST /events/{id}/guests/{guest_id}/tokens/rotate", + authed(rl("tokens_rotate", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.rotate)))) + mux.Handle("POST /events/{id}/guests/invitations/bulk", + authed(rl("tokens_bulk", 10, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.bulkIssue)))) + + // Guest-facing endpoints — rate-limited by the access token in the URL + // path so an attacker hammering a single invitation is slowed regardless + // of their source IP. + mux.Handle("GET /access/{token}", + rl("access", 60, time.Hour, pathKey("token"), http.HandlerFunc(s.tokens.access))) + mux.Handle("POST /rsvp/{token}", + rl("rsvp", 10, time.Hour, pathKey("token"), http.HandlerFunc(s.rsvps.submit))) + + // WebSocket endpoint authenticates via single-use ticket on the query + // string (see POST /auth/ws-ticket). mux.HandleFunc("GET /ws/events/{id}", s.ws.handle) + // Unsubscribe (signed token, no auth required — links live in emails). + mux.HandleFunc("GET /unsubscribe/{token}", s.unsub.preview) + mux.HandleFunc("POST /unsubscribe/{token}", s.unsub.confirm) + + // Provider webhooks. Signature verification is enforced in the handler + // once GG_TWILIO_AUTH_TOKEN / GG_SES_WEBHOOK_SECRET are set. + mux.HandleFunc("POST /webhooks/twilio/status", s.webhooks.twilio) + mux.HandleFunc("POST /webhooks/ses/notifications", s.webhooks.ses) + + // Billing (Block F). /billing/status is safe for everyone — returns + // free tier defaults when Stripe is unconfigured or the user has no + // subscription, so the frontend's plan page always has something to + // render. The action endpoints (checkout, portal) return 503 in dev + // without Stripe credentials. + mux.Handle("GET /billing/status", authed(http.HandlerFunc(s.billing.status))) + mux.Handle("POST /billing/checkout-session", authed(http.HandlerFunc(s.billing.checkoutSession))) + mux.Handle("POST /billing/portal", authed(http.HandlerFunc(s.billing.portalSession))) + mux.HandleFunc("POST /webhooks/stripe", s.stripeWH.handle) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusNotFound, "not found") }) @@ -128,9 +348,10 @@ func corsMiddleware(next http.Handler) http.Handler { origin := r.Header.Get("Origin") if origin != "" { w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Vary", "Origin") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Device-Fingerprint") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Device-Fingerprint") } if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) diff --git a/internal/api/stripe_webhook.go b/internal/api/stripe_webhook.go new file mode 100644 index 0000000..fc0b4d5 --- /dev/null +++ b/internal/api/stripe_webhook.go @@ -0,0 +1,163 @@ +package api + +import ( + "encoding/json" + "io" + "log/slog" + "net/http" + "time" + + "github.com/stripe/stripe-go/v82" + + "github.com/alchemistkay/guestguard/internal/billing" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// stripeWebhookHandler accepts and verifies Stripe events, then +// projects subscription lifecycle changes onto the subscriptions table. +// We track only what middleware needs to decide access — tier + status + +// period bounds. Invoice events (payment failed / succeeded) are logged +// for observability; dunning automation lands in Block F3. +type stripeWebhookHandler struct { + logger *slog.Logger + stripe *billing.Client + subs *storage.SubscriptionRepo +} + +// POST /webhooks/stripe — signature-verified Stripe event sink. +func (h *stripeWebhookHandler) handle(w http.ResponseWriter, r *http.Request) { + if h.stripe == nil || !h.stripe.Enabled() { + // Not configured on this instance — reject so a misrouted event + // isn't silently swallowed. Stripe will retry which is harmless. + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + writeError(w, http.StatusBadRequest, "read body") + return + } + defer r.Body.Close() + + event, err := h.stripe.VerifyWebhook(body, r.Header.Get("Stripe-Signature")) + if err != nil { + h.logger.Warn("stripe webhook signature failed", "err", err) + writeError(w, http.StatusBadRequest, "invalid signature") + return + } + + switch event.Type { + case "customer.subscription.created", "customer.subscription.updated": + h.applySubscription(r, event) + case "customer.subscription.deleted": + h.applySubscriptionDeleted(r, event) + case "invoice.payment_succeeded": + // Clear past_due if Stripe says payment caught up. Most flows are + // already covered by the subscription.updated event Stripe also + // fires — this is belt-and-braces. + h.applySubscription(r, event) + case "invoice.payment_failed": + h.logger.Warn("stripe invoice payment failed", "event_id", event.ID) + // Subscription.status will flip to past_due via the + // subscription.updated event Stripe fires alongside. + default: + h.logger.Debug("stripe event ignored", "type", event.Type) + } + + w.WriteHeader(http.StatusOK) +} + +// applySubscription patches the subscriptions row keyed by Stripe +// customer id. Best-effort — failures here are logged but don't NACK +// the webhook (Stripe would retry forever and the row would never +// converge). +func (h *stripeWebhookHandler) applySubscription(r *http.Request, event stripe.Event) { + var sub stripe.Subscription + if err := json.Unmarshal(event.Data.Raw, &sub); err != nil { + // Some invoice events carry an Invoice payload — try to extract + // the subscription id from there and short-circuit on status. + h.logger.Debug("stripe webhook: not a subscription payload", "type", event.Type, "err", err) + return + } + if sub.Customer == nil || sub.Customer.ID == "" { + h.logger.Warn("stripe webhook: subscription has no customer", "subscription", sub.ID) + return + } + + tier := tierFromSubscription(&sub) + status := string(sub.Status) + cancelAtPeriodEnd := sub.CancelAtPeriodEnd + + // As of API 2024-10-28, current_period_end lives on the subscription + // item, not the subscription. We pick the earliest item's end — for + // single-item subscriptions (our case) that's the canonical one. + var periodEnd *time.Time + for _, item := range sub.Items.Data { + if item.CurrentPeriodEnd > 0 { + t := time.Unix(item.CurrentPeriodEnd, 0).UTC() + if periodEnd == nil || t.Before(*periodEnd) { + periodEnd = &t + } + } + } + subID := sub.ID + if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{ + StripeSubscriptionID: &subID, + Tier: stringPtr(string(tier)), + Status: &status, + CurrentPeriodEnd: periodEnd, + CancelAtPeriodEnd: &cancelAtPeriodEnd, + }); err != nil { + h.logger.Error("stripe webhook: update subscription failed", "err", err) + } +} + +func (h *stripeWebhookHandler) applySubscriptionDeleted(r *http.Request, event stripe.Event) { + var sub stripe.Subscription + if err := json.Unmarshal(event.Data.Raw, &sub); err != nil { + h.logger.Warn("stripe webhook: bad deleted payload", "err", err) + return + } + if sub.Customer == nil { + return + } + status := "canceled" + if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{ + Status: &status, + }); err != nil { + h.logger.Error("stripe webhook: mark canceled failed", "err", err) + } +} + +// tierFromSubscription inspects the Stripe price metadata to figure out +// which GuestGuard tier this subscription corresponds to. We read a +// price-level metadata key `gg_tier` (set in the Stripe dashboard when +// you create the Price). Fallback: free. +func tierFromSubscription(sub *stripe.Subscription) billing.Tier { + if sub == nil || len(sub.Items.Data) == 0 { + return billing.TierFree + } + for _, item := range sub.Items.Data { + if item.Price == nil { + continue + } + if v, ok := item.Price.Metadata["gg_tier"]; ok { + t := billing.Tier(v) + if t.Valid() { + return t + } + } + // Heuristic fallback for tests / unconfigured prices: look at the + // recurring interval and amount tier. + if item.Price.Recurring != nil && item.Price.UnitAmount >= 19900 { + return billing.TierBusiness + } + if item.Price.Recurring != nil && item.Price.UnitAmount >= 4900 { + return billing.TierPro + } + } + return billing.TierFree +} + +func stringPtr(s string) *string { return &s } diff --git a/internal/api/tokens.go b/internal/api/tokens.go index 94403dc..6e2134d 100644 --- a/internal/api/tokens.go +++ b/internal/api/tokens.go @@ -2,6 +2,7 @@ package api import ( "context" + "encoding/json" "errors" "log/slog" "net/http" @@ -20,25 +21,38 @@ type accessPublisher interface { PublishAccessAttempted(ctx context.Context, evt natspub.AccessAttempted) error } +type invitationPublisher interface { + PublishInvitationSend(ctx context.Context, evt natspub.InvitationSend) error +} + type tokenHandler struct { - logger *slog.Logger - guests *storage.GuestRepo - tokens *storage.TokenRepo - events *storage.EventRepo - accessLogs *storage.AccessLogRepo - gen *auth.Generator - ttl time.Duration - pub accessPublisher + logger *slog.Logger + guests *storage.GuestRepo + tokens *storage.TokenRepo + events *storage.EventRepo + users *storage.UserRepo + accessLogs *storage.AccessLogRepo + gen *auth.Generator + ttl time.Duration + pub accessPublisher + invitations invitationPublisher + publicBaseURL string } type issueTokenResponse struct { - Token string `json:"token"` - TokenID uuid.UUID `json:"token_id"` - Meta *domain.Token `json:"meta"` + Token string `json:"token"` + TokenID uuid.UUID `json:"token_id"` + Meta *domain.Token `json:"meta"` + InvitationQueued bool `json:"invitation_queued"` + InvitationLink string `json:"invitation_link"` } // POST /events/{id}/guests/{guest_id}/tokens — issue a token for the guest. func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } eventID, ok := parseIDParam(w, r, "id") if !ok { return @@ -47,6 +61,10 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) { if !ok { return } + event, ok := requireEventOwner(w, r, h.events, eventID, hostID) + if !ok { + return + } guest, err := h.guests.Get(r.Context(), guestID) if err != nil { @@ -78,10 +96,307 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) { return } + link := h.invitationLink(raw) + invitationQueued := h.queueInvitation(r.Context(), event, guest, tk, hostID, raw) + writeJSON(w, http.StatusCreated, issueTokenResponse{ - Token: raw, - TokenID: tk.ID, - Meta: tk, + Token: raw, + TokenID: tk.ID, + Meta: tk, + InvitationQueued: invitationQueued, + InvitationLink: link, + }) +} + +// queueInvitation publishes an invitation.send event so the notifier can +// dispatch a branded email. Best-effort: if any step fails we log and +// return false rather than failing the whole token-issue request — the +// host still has the raw URL in the response and can re-trigger sending. +func (h *tokenHandler) queueInvitation( + ctx context.Context, + event *domain.Event, + guest *domain.Guest, + tk *domain.Token, + hostID uuid.UUID, + rawToken string, +) bool { + if h.invitations == nil { + return false + } + if guest.Email == nil || *guest.Email == "" { + // Phone-only / nameless guests get no email — host shares the link + // manually. Show that on the UI so it's not a silent surprise. + return false + } + hostName := "" + if h.users != nil { + if host, err := h.users.GetByID(ctx, hostID); err == nil && host != nil { + hostName = host.Name + } + } + evt := natspub.InvitationSend{ + EventID: event.ID, + GuestID: guest.ID, + TokenID: tk.ID, + GuestName: guest.Name, + GuestEmail: *guest.Email, + HostName: hostName, + EventName: event.Name, + Venue: event.Venue, + EventDate: event.EventDate, + Link: h.invitationLink(rawToken), + IssuedAt: time.Now().UTC(), + } + if err := h.invitations.PublishInvitationSend(ctx, evt); err != nil { + h.logger.Warn("publish invitation.send (continuing)", "err", err, "guest_id", guest.ID) + return false + } + return true +} + +// invitationLink renders the public RSVP URL the guest clicks from their +// inbox. publicBaseURL is the externally-reachable host (set via +// GG_PUBLIC_BASE_URL); access via /rsvp/ is intentional — the +// frontend page rsvp/[token].vue catches the raw token. +func (h *tokenHandler) invitationLink(raw string) string { + base := h.publicBaseURL + if base == "" { + base = "http://localhost:3000" + } + return base + "/rsvp/" + raw +} + +// bulkIssueRequest is the optional JSON body for the bulk-invite call. +// An empty body (or missing GuestIDs) means "every guest on the event +// who doesn't already have a token". +type bulkIssueRequest struct { + GuestIDs []string `json:"guest_ids"` +} + +type bulkIssueItemError struct { + GuestID string `json:"guest_id"` + Reason string `json:"reason"` +} + +// bulkIssueToken is one minted invitation. The raw token is returned so +// the host's UI can offer a "copy link" affordance after a bulk send +// (especially for guests with no email on file) without making another +// round-trip. Same data the per-guest issue endpoint already exposes, +// scoped to the host who owns the event. +type bulkIssueToken struct { + GuestID string `json:"guest_id"` + Token string `json:"token"` + InvitationQueued bool `json:"invitation_queued"` + InvitationLink string `json:"invitation_link"` +} + +type bulkIssueResponse struct { + Issued int `json:"issued"` + Queued int `json:"queued"` + SkippedExisting int `json:"skipped_existing"` + SkippedNoEmail int `json:"skipped_no_email"` + Tokens []bulkIssueToken `json:"tokens,omitempty"` + Errors []bulkIssueItemError `json:"errors,omitempty"` +} + +// POST /events/{id}/guests/invitations/bulk — generate tokens for every +// eligible guest (or the explicit subset) on the event and queue an +// invitation email for those with an address. Best-effort: any per-guest +// error is reported in the response and doesn't abort the rest. +func (h *tokenHandler) bulkIssue(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + event, ok := requireEventOwner(w, r, h.events, eventID, hostID) + if !ok { + return + } + + var req bulkIssueRequest + if r.ContentLength > 0 { + // Body is optional; only decode when something was actually sent. + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + } + + var onlyIDs []uuid.UUID + if len(req.GuestIDs) > 0 { + onlyIDs = make([]uuid.UUID, 0, len(req.GuestIDs)) + for _, raw := range req.GuestIDs { + id, err := uuid.Parse(raw) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid guest id: "+raw) + return + } + onlyIDs = append(onlyIDs, id) + } + } + + guests, err := h.guests.ListGuestsForInvitation(r.Context(), eventID, onlyIDs) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to load guests") + return + } + + hostName := "" + if h.users != nil { + if host, err := h.users.GetByID(r.Context(), hostID); err == nil && host != nil { + hostName = host.Name + } + } + + resp := bulkIssueResponse{} + for _, g := range guests { + if g.HasToken { + resp.SkippedExisting++ + continue + } + raw, hash, err := h.gen.Generate() + if err != nil { + resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "mint token failed"}) + continue + } + tk, err := h.tokens.Create(r.Context(), storage.CreateTokenParams{ + GuestID: g.ID, + TokenHash: hash, + ExpiresAt: time.Now().UTC().Add(h.ttl), + }) + if err != nil { + // Likely a race against the unique constraint (someone else + // issued in parallel) — surface but don't fail the batch. + resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: err.Error()}) + continue + } + resp.Issued++ + link := h.invitationLink(raw) + tokenInfo := bulkIssueToken{ + GuestID: g.ID.String(), + Token: raw, + InvitationLink: link, + } + + if g.Email == "" { + resp.SkippedNoEmail++ + resp.Tokens = append(resp.Tokens, tokenInfo) + continue + } + evt := natspub.InvitationSend{ + EventID: event.ID, + GuestID: g.ID, + TokenID: tk.ID, + GuestName: g.Name, + GuestEmail: g.Email, + HostName: hostName, + EventName: event.Name, + Venue: event.Venue, + EventDate: event.EventDate, + Link: h.invitationLink(raw), + IssuedAt: time.Now().UTC(), + } + if h.invitations == nil { + resp.Tokens = append(resp.Tokens, tokenInfo) + continue + } + if err := h.invitations.PublishInvitationSend(r.Context(), evt); err != nil { + h.logger.Warn("publish invitation.send (bulk, continuing)", "err", err, "guest_id", g.ID) + resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "publish failed"}) + resp.Tokens = append(resp.Tokens, tokenInfo) + continue + } + resp.Queued++ + tokenInfo.InvitationQueued = true + resp.Tokens = append(resp.Tokens, tokenInfo) + } + + writeJSON(w, http.StatusOK, resp) +} + +type rotateTokenRequest struct { + // SendEmail asks the notifier to re-deliver the invitation. False + // means "just give me a fresh link" — typical for phone-only guests + // where the host shares the new URL via SMS. + SendEmail bool `json:"send_email"` +} + +// POST /events/{id}/guests/{guest_id}/tokens/rotate — invalidate the +// guest's existing invitation link and mint a fresh one. Optionally +// re-publishes invitation.send so the notifier re-delivers via email. +// The old URL stops working immediately. +func (h *tokenHandler) rotate(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + eventID, ok := parseIDParam(w, r, "id") + if !ok { + return + } + guestID, ok := parseIDParam(w, r, "guest_id") + if !ok { + return + } + event, ok := requireEventOwner(w, r, h.events, eventID, hostID) + if !ok { + return + } + + guest, err := h.guests.Get(r.Context(), guestID) + if err != nil { + if errors.Is(err, domain.ErrGuestNotFound) { + writeError(w, http.StatusNotFound, "guest not found") + return + } + writeError(w, http.StatusInternalServerError, "failed to load guest") + return + } + if guest.EventID != eventID { + writeError(w, http.StatusNotFound, "guest not found in event") + return + } + + var req rotateTokenRequest + if r.ContentLength > 0 { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + } + + raw, hash, err := h.gen.Generate() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to generate token") + return + } + tk, err := h.tokens.RotateForGuest(r.Context(), storage.CreateTokenParams{ + GuestID: guestID, + TokenHash: hash, + ExpiresAt: time.Now().UTC().Add(h.ttl), + }) + if err != nil { + h.logger.Error("rotate token", "err", err, "guest_id", guestID) + writeError(w, http.StatusInternalServerError, "failed to rotate token") + return + } + + link := h.invitationLink(raw) + invitationQueued := false + if req.SendEmail { + invitationQueued = h.queueInvitation(r.Context(), event, guest, tk, hostID, raw) + } + + writeJSON(w, http.StatusOK, issueTokenResponse{ + Token: raw, + TokenID: tk.ID, + Meta: tk, + InvitationQueued: invitationQueued, + InvitationLink: link, }) } diff --git a/internal/api/unsubscribe.go b/internal/api/unsubscribe.go new file mode 100644 index 0000000..c932c20 --- /dev/null +++ b/internal/api/unsubscribe.go @@ -0,0 +1,50 @@ +package api + +import ( + "log/slog" + "net/http" + + "github.com/alchemistkay/guestguard/internal/notification" +) + +type unsubscribeHandler struct { + logger *slog.Logger + signer *notification.UnsubscribeSigner + suppress *notification.SuppressionRepo +} + +// GET /unsubscribe/{token} — surface the email address that the token +// belongs to so the frontend can show a confirmation page. Honoured even +// before the user clicks "Confirm" so they see what's being unsubscribed. +func (h *unsubscribeHandler) preview(w http.ResponseWriter, r *http.Request) { + if h.signer == nil { + writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured") + return + } + email, err := h.signer.Verify(r.PathValue("token")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid unsubscribe link") + return + } + writeJSON(w, http.StatusOK, map[string]string{"email": email}) +} + +// POST /unsubscribe/{token} — add the email to the suppression list. +// Idempotent: clicking the link twice keeps the existing entry. +func (h *unsubscribeHandler) confirm(w http.ResponseWriter, r *http.Request) { + if h.signer == nil || h.suppress == nil { + writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured") + return + } + email, err := h.signer.Verify(r.PathValue("token")) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid unsubscribe link") + return + } + if err := h.suppress.Add(r.Context(), email, "user clicked unsubscribe", notification.SuppressionUser); err != nil { + h.logger.Error("add suppression", "err", err, "email", email) + writeError(w, http.StatusInternalServerError, "failed to unsubscribe") + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "unsubscribed", "email": email}) +} diff --git a/internal/api/users.go b/internal/api/users.go deleted file mode 100644 index ffea926..0000000 --- a/internal/api/users.go +++ /dev/null @@ -1,55 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "net/http" - "net/mail" - - "github.com/alchemistkay/guestguard/internal/domain" - "github.com/alchemistkay/guestguard/internal/storage" -) - -type userHandler struct { - repo *storage.UserRepo -} - -type upsertUserRequest struct { - Email string `json:"email"` - Name string `json:"name"` -} - -// POST /users — idempotent: returns the existing user if the email already -// exists, creates one otherwise. This keeps the demo flow simple without -// requiring real auth. -func (h *userHandler) upsert(w http.ResponseWriter, r *http.Request) { - var req upsertUserRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(w, http.StatusBadRequest, "invalid json") - return - } - if _, err := mail.ParseAddress(req.Email); err != nil { - writeError(w, http.StatusBadRequest, "email is invalid") - return - } - if req.Name == "" { - writeError(w, http.StatusBadRequest, "name is required") - return - } - - u, err := h.repo.Create(r.Context(), req.Email, req.Name) - if err == nil { - writeJSON(w, http.StatusCreated, u) - return - } - if errors.Is(err, domain.ErrEmailTaken) { - existing, getErr := h.repo.GetByEmail(r.Context(), req.Email) - if getErr != nil { - writeError(w, http.StatusInternalServerError, "failed to load user") - return - } - writeJSON(w, http.StatusOK, existing) - return - } - writeError(w, http.StatusInternalServerError, "failed to create user") -} diff --git a/internal/api/webhooks.go b/internal/api/webhooks.go new file mode 100644 index 0000000..1bcf931 --- /dev/null +++ b/internal/api/webhooks.go @@ -0,0 +1,145 @@ +package api + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + + "github.com/alchemistkay/guestguard/internal/notification" +) + +// webhookHandler accepts provider status notifications and reflects them +// onto the notifications table + suppression list. +// +// Signature verification is intentionally a TODO until the user provisions +// real Twilio + SES creds — verifying against test fixtures alone would +// give a false sense of security. The endpoint is therefore *not* exposed +// publicly until the deployment is ready. +type webhookHandler struct { + logger *slog.Logger + notifs *notification.Repo + suppress *notification.SuppressionRepo +} + +// POST /webhooks/twilio/status — Twilio status callback (form-encoded). +// Fields we care about: MessageSid, MessageStatus (sent|delivered| +// undelivered|failed), ErrorCode, To. +func (h *webhookHandler) twilio(w http.ResponseWriter, r *http.Request) { + if h.notifs == nil { + w.WriteHeader(http.StatusNoContent) + return + } + // TODO(blockD2): verify X-Twilio-Signature with GG_TWILIO_AUTH_TOKEN. + if err := r.ParseForm(); err != nil { + writeError(w, http.StatusBadRequest, "invalid form") + return + } + sid := r.PostForm.Get("MessageSid") + status := r.PostForm.Get("MessageStatus") + if sid == "" || status == "" { + writeError(w, http.StatusBadRequest, "missing MessageSid / MessageStatus") + return + } + + ctx := r.Context() + switch status { + case "delivered": + _ = h.notifs.MarkDelivered(ctx, sid) + case "undelivered", "failed": + _ = h.notifs.MarkBounce(ctx, sid, "permanent") + } + h.logger.Info("twilio status callback", "sid", sid, "status", status) + w.WriteHeader(http.StatusNoContent) +} + +// POST /webhooks/ses/notifications — SNS-delivered SES notification (JSON). +// Handles the two shapes SES uses: bounce + complaint events. Each event +// carries the messageId we stored in provider_message_id and an array of +// affected recipients. +func (h *webhookHandler) ses(w http.ResponseWriter, r *http.Request) { + if h.notifs == nil { + w.WriteHeader(http.StatusNoContent) + return + } + // TODO(blockD2): verify SNS signature using the cert URL field. + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + writeError(w, http.StatusBadRequest, "read body") + return + } + defer r.Body.Close() + + var envelope struct { + Type string `json:"Type"` // "Notification" | "SubscriptionConfirmation" + Message string `json:"Message"` // stringified JSON for Notification + } + if err := json.Unmarshal(body, &envelope); err != nil { + writeError(w, http.StatusBadRequest, "invalid json envelope") + return + } + if envelope.Type == "SubscriptionConfirmation" { + // Confirmed by visiting SubscribeURL — manual op-side step. + h.logger.Info("ses subscription confirmation received (manual confirm required)") + w.WriteHeader(http.StatusNoContent) + return + } + if envelope.Message == "" { + w.WriteHeader(http.StatusNoContent) + return + } + + var inner struct { + NotificationType string `json:"notificationType"` // "Bounce" | "Complaint" | "Delivery" + Mail struct { + MessageID string `json:"messageId"` + } `json:"mail"` + Bounce struct { + BounceType string `json:"bounceType"` // "Permanent" | "Transient" + BouncedRecipients []struct { + EmailAddress string `json:"emailAddress"` + } `json:"bouncedRecipients"` + } `json:"bounce"` + Complaint struct { + ComplainedRecipients []struct { + EmailAddress string `json:"emailAddress"` + } `json:"complainedRecipients"` + } `json:"complaint"` + } + if err := json.Unmarshal([]byte(envelope.Message), &inner); err != nil { + writeError(w, http.StatusBadRequest, "invalid inner json") + return + } + + ctx := r.Context() + switch inner.NotificationType { + case "Bounce": + bt := "transient" + if inner.Bounce.BounceType == "Permanent" { + bt = "permanent" + } + _ = h.notifs.MarkBounce(ctx, inner.Mail.MessageID, bt) + if h.suppress != nil && bt == "permanent" { + for _, rcp := range inner.Bounce.BouncedRecipients { + _ = h.suppress.Add(ctx, rcp.EmailAddress, "ses permanent bounce", notification.SuppressionBounce) + } + } + case "Complaint": + _ = h.notifs.MarkComplaint(ctx, inner.Mail.MessageID) + if h.suppress != nil { + for _, rcp := range inner.Complaint.ComplainedRecipients { + _ = h.suppress.Add(ctx, rcp.EmailAddress, "ses complaint", notification.SuppressionComplaint) + } + } + case "Delivery": + _ = h.notifs.MarkDelivered(ctx, inner.Mail.MessageID) + } + h.logger.Info("ses notification", "type", inner.NotificationType, "message_id", inner.Mail.MessageID) + w.WriteHeader(http.StatusNoContent) +} + +// Compile-time check that ctx is unused in package — silences linter on +// some Go versions when the file would otherwise import context only for +// the handler signatures. +var _ = context.Background diff --git a/internal/api/ws_auth.go b/internal/api/ws_auth.go new file mode 100644 index 0000000..f070ed8 --- /dev/null +++ b/internal/api/ws_auth.go @@ -0,0 +1,51 @@ +package api + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/alchemistkay/guestguard/internal/storage" +) + +type wsTicketHandler struct { + tickets *wsTicketStore + events *storage.EventRepo +} + +type wsTicketResponse struct { + Ticket string `json:"ticket"` + ExpiresAt time.Time `json:"expires_at"` +} + +// POST /auth/ws-ticket — requireAuth-protected; body { "event_id": "" }. +// Returns a single-use ticket valid for ~60 seconds. The frontend appends it +// as `?ticket=…` on the WebSocket URL. +func (h *wsTicketHandler) issue(w http.ResponseWriter, r *http.Request) { + hostID, ok := hostFromContext(w, r) + if !ok { + return + } + + var req struct { + EventID string `json:"event_id"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid json") + return + } + eventID, ok := parseRawUUID(w, "event_id", req.EventID) + if !ok { + return + } + if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok { + return + } + + tok, exp, err := h.tickets.Mint(hostID, eventID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to mint ticket") + return + } + writeJSON(w, http.StatusOK, wsTicketResponse{Ticket: tok, ExpiresAt: exp}) +} diff --git a/internal/api/ws_tickets.go b/internal/api/ws_tickets.go new file mode 100644 index 0000000..c6dbd2a --- /dev/null +++ b/internal/api/ws_tickets.go @@ -0,0 +1,81 @@ +package api + +import ( + "crypto/rand" + "encoding/base64" + "sync" + "time" + + "github.com/google/uuid" +) + +// wsTicketStore mints short-lived single-use tickets that authorise a +// WebSocket handshake. The plan calls this option 3 in Block B: cookies +// don't reach the WS handshake on cross-origin setups and a JWT in the URL +// would leak to logs; a one-shot ticket sidesteps both. +// +// Block B keeps this in-process. When the API runs more than one replica +// this needs to move to Redis (Block C territory). +type wsTicketStore struct { + mu sync.Mutex + entries map[string]wsTicketEntry + ttl time.Duration + now func() time.Time +} + +type wsTicketEntry struct { + userID uuid.UUID + eventID uuid.UUID + expiresAt time.Time +} + +func newWSTicketStore(ttl time.Duration) *wsTicketStore { + return &wsTicketStore{ + entries: make(map[string]wsTicketEntry), + ttl: ttl, + now: time.Now, + } +} + +// Mint returns a fresh URL-safe ticket bound to userID + eventID. +func (s *wsTicketStore) Mint(userID, eventID uuid.UUID) (string, time.Time, error) { + buf := make([]byte, 24) + if _, err := rand.Read(buf); err != nil { + return "", time.Time{}, err + } + tok := base64.RawURLEncoding.EncodeToString(buf) + exp := s.now().Add(s.ttl) + + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked() + s.entries[tok] = wsTicketEntry{userID: userID, eventID: eventID, expiresAt: exp} + return tok, exp, nil +} + +// Consume removes the ticket and returns the bound (userID, eventID) if it +// was valid. A ticket is single-use — replaying it fails. +func (s *wsTicketStore) Consume(token string) (uuid.UUID, uuid.UUID, bool) { + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.entries[token] + if !ok { + return uuid.Nil, uuid.Nil, false + } + delete(s.entries, token) + if s.now().After(entry.expiresAt) { + return uuid.Nil, uuid.Nil, false + } + return entry.userID, entry.eventID, true +} + +// sweepLocked drops expired entries opportunistically. Cheap because we +// usually only hold dozens of tickets at a time. +func (s *wsTicketStore) sweepLocked() { + now := s.now() + for k, v := range s.entries { + if now.After(v.expiresAt) { + delete(s.entries, k) + } + } +} diff --git a/internal/api/wshub.go b/internal/api/wshub.go index 0e2bc18..499708e 100644 --- a/internal/api/wshub.go +++ b/internal/api/wshub.go @@ -89,16 +89,36 @@ func (h *Hub) remove(eventID uuid.UUID, s *subscriber) { } type wsHandler struct { - logger *slog.Logger - hub *Hub + logger *slog.Logger + hub *Hub + tickets *wsTicketStore } -// GET /ws/events/{id} — dashboard live feed for one event. +// GET /ws/events/{id}?ticket=... — dashboard live feed for one event. +// +// The handshake is authorised by a single-use ticket minted via +// POST /auth/ws-ticket (option 3 from the Block B plan). The ticket binds +// the connecting user to a specific event_id; we reject if either is +// missing or doesn't match the URL path. func (h *wsHandler) handle(w http.ResponseWriter, r *http.Request) { eventID, ok := parseIDParam(w, r, "id") if !ok { return } + rawTicket := r.URL.Query().Get("ticket") + if rawTicket == "" { + writeError(w, http.StatusUnauthorized, "missing ticket") + return + } + _, ticketEventID, valid := h.tickets.Consume(rawTicket) + if !valid { + writeError(w, http.StatusUnauthorized, "invalid or expired ticket") + return + } + if ticketEventID != eventID { + writeError(w, http.StatusForbidden, "ticket does not match event") + return + } conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ // In dev the frontend runs on a different origin (localhost:3000 → localhost:8080). diff --git a/internal/auth/email.go b/internal/auth/email.go new file mode 100644 index 0000000..27e8102 --- /dev/null +++ b/internal/auth/email.go @@ -0,0 +1,32 @@ +package auth + +import ( + "context" + "log/slog" +) + +// EmailSender delivers transactional auth emails (verification, reset). +// Block A ships LogSender so dev environments work without Twilio/SES. +// Block D replaces this with a real SES-backed sender. +type EmailSender interface { + SendVerification(ctx context.Context, to, name, link string) error + SendPasswordReset(ctx context.Context, to, name, link string) error +} + +type LogEmailSender struct { + Logger *slog.Logger +} + +func (l LogEmailSender) SendVerification(_ context.Context, to, name, link string) error { + l.Logger.Info("auth email (stub): verification", + "to", to, "name", name, "link", link, + ) + return nil +} + +func (l LogEmailSender) SendPasswordReset(_ context.Context, to, name, link string) error { + l.Logger.Info("auth email (stub): password reset", + "to", to, "name", name, "link", link, + ) + return nil +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..8d12453 --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,95 @@ +package auth + +import ( + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +var ( + ErrInvalidJWT = errors.New("invalid token") + ErrExpiredJWT = errors.New("token expired") +) + +type AccessClaims struct { + UserID uuid.UUID `json:"sub_uuid"` + jwt.RegisteredClaims +} + +type JWTSigner struct { + secret []byte + ttl time.Duration + issuer string + parser *jwt.Parser +} + +func NewJWTSigner(secret string, ttl time.Duration, issuer string) (*JWTSigner, error) { + if len(secret) < 32 { + return nil, fmt.Errorf("jwt secret must be at least 32 bytes") + } + if ttl <= 0 { + return nil, fmt.Errorf("jwt ttl must be positive") + } + return &JWTSigner{ + secret: []byte(secret), + ttl: ttl, + issuer: issuer, + parser: jwt.NewParser( + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}), + jwt.WithIssuer(issuer), + jwt.WithExpirationRequired(), + ), + }, nil +} + +func (s *JWTSigner) Issue(userID uuid.UUID, now time.Time) (string, time.Time, error) { + exp := now.Add(s.ttl) + claims := AccessClaims{ + UserID: userID, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: s.issuer, + Subject: userID.String(), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)), + ExpiresAt: jwt.NewNumericDate(exp), + ID: uuid.NewString(), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := tok.SignedString(s.secret) + if err != nil { + return "", time.Time{}, err + } + return signed, exp, nil +} + +func (s *JWTSigner) Parse(raw string) (*AccessClaims, error) { + claims := &AccessClaims{} + tok, err := s.parser.ParseWithClaims(raw, claims, func(t *jwt.Token) (any, error) { + return s.secret, nil + }) + if err != nil { + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, ErrExpiredJWT + } + return nil, ErrInvalidJWT + } + if !tok.Valid { + return nil, ErrInvalidJWT + } + if claims.UserID == uuid.Nil { + // Fallback for tokens that only carry Subject (defensive — we always + // set UserID on issue). + parsed, perr := uuid.Parse(claims.Subject) + if perr != nil { + return nil, ErrInvalidJWT + } + claims.UserID = parsed + } + return claims, nil +} + +func (s *JWTSigner) TTL() time.Duration { return s.ttl } diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go new file mode 100644 index 0000000..cb4f5f0 --- /dev/null +++ b/internal/auth/jwt_test.go @@ -0,0 +1,82 @@ +package auth + +import ( + "errors" + "testing" + "time" + + "github.com/google/uuid" +) + +const testSecret = "test-secret-must-be-at-least-32-bytes-long-xx" + +func TestJWTRoundTrip(t *testing.T) { + s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test") + if err != nil { + t.Fatalf("signer: %v", err) + } + uid := uuid.New() + tok, exp, err := s.Issue(uid, time.Now()) + if err != nil { + t.Fatalf("issue: %v", err) + } + if tok == "" { + t.Fatal("empty token") + } + if time.Until(exp) <= 0 { + t.Fatalf("expiry in past: %v", exp) + } + claims, err := s.Parse(tok) + if err != nil { + t.Fatalf("parse: %v", err) + } + if claims.UserID != uid { + t.Fatalf("user mismatch: got %s want %s", claims.UserID, uid) + } +} + +func TestJWTExpired(t *testing.T) { + s, err := NewJWTSigner(testSecret, 1*time.Second, "guestguard-test") + if err != nil { + t.Fatalf("signer: %v", err) + } + tok, _, err := s.Issue(uuid.New(), time.Now().Add(-1*time.Hour)) + if err != nil { + t.Fatalf("issue: %v", err) + } + if _, err := s.Parse(tok); !errors.Is(err, ErrExpiredJWT) { + t.Fatalf("expected ErrExpiredJWT, got %v", err) + } +} + +func TestJWTTamper(t *testing.T) { + s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test") + if err != nil { + t.Fatalf("signer: %v", err) + } + tok, _, _ := s.Issue(uuid.New(), time.Now()) + // Flip a character in the signature segment. + tampered := tok[:len(tok)-1] + "a" + if tampered == tok { + tampered = tok[:len(tok)-1] + "b" + } + if _, err := s.Parse(tampered); !errors.Is(err, ErrInvalidJWT) { + t.Fatalf("expected ErrInvalidJWT, got %v", err) + } +} + +func TestJWTSecretTooShort(t *testing.T) { + if _, err := NewJWTSigner("short", time.Minute, "x"); err == nil { + t.Fatal("expected error for short secret") + } +} + +func TestOpaqueTokenHashStable(t *testing.T) { + raw, hash, err := NewOpaqueToken() + if err != nil { + t.Fatalf("mint: %v", err) + } + if got := HashOpaque(raw); got != hash { + t.Fatalf("hash mismatch: got %s want %s", got, hash) + } +} diff --git a/internal/auth/lockout.go b/internal/auth/lockout.go new file mode 100644 index 0000000..a66b49a --- /dev/null +++ b/internal/auth/lockout.go @@ -0,0 +1,107 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +// LockoutTracker counts consecutive failed logins per email and trips a +// lockout flag after a threshold. The lock is keyed by user_id (once we +// know it from the email) so that resetting the password — which we do via +// /auth/reset-password — can clear it cleanly. +// +// Why two keys? The failure counter must work even when the email maps to +// no user (otherwise an attacker probing addresses just gets unlimited +// tries). The lock flag only exists once we've matched an actual account. +type LockoutTracker struct { + client *redis.Client + threshold int + window time.Duration // how long failures linger before counters reset + prefix string +} + +func NewLockoutTracker(client *redis.Client, threshold int, window time.Duration) *LockoutTracker { + return &LockoutTracker{ + client: client, + threshold: threshold, + window: window, + prefix: "auth", + } +} + +func (t *LockoutTracker) failKey(email string) string { + h := sha256.Sum256([]byte(strings.ToLower(strings.TrimSpace(email)))) + return fmt.Sprintf("%s:login_fail:%s", t.prefix, hex.EncodeToString(h[:])) +} + +func (t *LockoutTracker) lockKey(uid uuid.UUID) string { + return fmt.Sprintf("%s:locked:%s", t.prefix, uid.String()) +} + +// IsLocked reports whether the given user's account is currently locked. +func (t *LockoutTracker) IsLocked(ctx context.Context, uid uuid.UUID) (bool, error) { + if t == nil || t.client == nil { + return false, nil + } + v, err := t.client.Exists(ctx, t.lockKey(uid)).Result() + if err != nil { + return false, err + } + return v > 0, nil +} + +// RecordFailure increments the failure counter for the email and, if it +// crosses the threshold, sets the lock flag for the given user id. +// Returns (locked, error). A nil userID is fine — the counter still ticks +// up so probing nonexistent accounts is also rate-limited. +func (t *LockoutTracker) RecordFailure(ctx context.Context, email string, userID *uuid.UUID) (bool, error) { + if t == nil || t.client == nil { + return false, nil + } + key := t.failKey(email) + n, err := t.client.Incr(ctx, key).Result() + if err != nil { + return false, err + } + if n == 1 { + _ = t.client.Expire(ctx, key, t.window).Err() + } + if int(n) >= t.threshold && userID != nil { + // Keep the lock until password reset clears it. 7-day fallback TTL + // so a permanently abandoned account doesn't pile up forever. + if err := t.client.Set(ctx, t.lockKey(*userID), "1", 7*24*time.Hour).Err(); err != nil { + return false, err + } + return true, nil + } + return false, nil +} + +// ClearForUser drops both the lock flag and any in-flight failure counter +// for the user's email. Called from /auth/reset-password. +func (t *LockoutTracker) ClearForUser(ctx context.Context, uid uuid.UUID, email string) error { + if t == nil || t.client == nil { + return nil + } + pipe := t.client.Pipeline() + pipe.Del(ctx, t.lockKey(uid)) + pipe.Del(ctx, t.failKey(email)) + _, err := pipe.Exec(ctx) + return err +} + +// ClearOnSuccess drops only the failure counter — used after a successful +// login to forgive prior typos. +func (t *LockoutTracker) ClearOnSuccess(ctx context.Context, email string) { + if t == nil || t.client == nil { + return + } + _ = t.client.Del(ctx, t.failKey(email)).Err() +} diff --git a/internal/auth/password.go b/internal/auth/password.go new file mode 100644 index 0000000..1cdd051 --- /dev/null +++ b/internal/auth/password.go @@ -0,0 +1,58 @@ +package auth + +import ( + "errors" + + "golang.org/x/crypto/bcrypt" +) + +const bcryptCost = 12 + +var ( + ErrPasswordMismatch = errors.New("password mismatch") + ErrPasswordTooShort = errors.New("password must be at least 8 characters") + ErrPasswordTooLong = errors.New("password must be at most 72 characters") +) + +type PasswordHasher struct { + cost int +} + +func NewPasswordHasher() *PasswordHasher { + return &PasswordHasher{cost: bcryptCost} +} + +func (h *PasswordHasher) Hash(raw string) (string, error) { + if err := ValidatePassword(raw); err != nil { + return "", err + } + b, err := bcrypt.GenerateFromPassword([]byte(raw), h.cost) + if err != nil { + return "", err + } + return string(b), nil +} + +func (h *PasswordHasher) Verify(hash, raw string) error { + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(raw)); err != nil { + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return ErrPasswordMismatch + } + return err + } + return nil +} + +// ValidatePassword enforces a minimum length and the bcrypt-imposed maximum. +// bcrypt silently truncates inputs over 72 bytes, which would let a user set +// a 100-character password and successfully log in with the first 72; reject +// at the boundary instead. +func ValidatePassword(raw string) error { + if len(raw) < 8 { + return ErrPasswordTooShort + } + if len(raw) > 72 { + return ErrPasswordTooLong + } + return nil +} diff --git a/internal/auth/password_test.go b/internal/auth/password_test.go new file mode 100644 index 0000000..cbac34e --- /dev/null +++ b/internal/auth/password_test.go @@ -0,0 +1,44 @@ +package auth + +import ( + "errors" + "strings" + "testing" +) + +func TestPasswordHasherRoundTrip(t *testing.T) { + h := NewPasswordHasher() + hash, err := h.Hash("correct-horse-battery-staple") + if err != nil { + t.Fatalf("hash: %v", err) + } + if hash == "" { + t.Fatal("empty hash") + } + if err := h.Verify(hash, "correct-horse-battery-staple"); err != nil { + t.Fatalf("verify correct: %v", err) + } + if err := h.Verify(hash, "wrong"); !errors.Is(err, ErrPasswordMismatch) { + t.Fatalf("verify wrong: got %v want ErrPasswordMismatch", err) + } +} + +func TestPasswordValidation(t *testing.T) { + tests := []struct { + name string + pw string + want error + }{ + {"too short", "1234567", ErrPasswordTooShort}, + {"min length ok", "12345678", nil}, + {"too long", strings.Repeat("a", 73), ErrPasswordTooLong}, + {"max length ok", strings.Repeat("a", 72), nil}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ValidatePassword(tc.pw); !errors.Is(got, tc.want) { + t.Fatalf("got %v want %v", got, tc.want) + } + }) + } +} diff --git a/internal/auth/secret.go b/internal/auth/secret.go new file mode 100644 index 0000000..23a5038 --- /dev/null +++ b/internal/auth/secret.go @@ -0,0 +1,26 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" +) + +// NewOpaqueToken returns a 32-byte URL-safe random token plus its SHA-256 hex +// digest. The raw value is shown once (in a link); only the digest is stored. +func NewOpaqueToken() (raw, hash string, err error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", "", err + } + raw = base64.RawURLEncoding.EncodeToString(buf) + sum := sha256.Sum256([]byte(raw)) + hash = hex.EncodeToString(sum[:]) + return raw, hash, nil +} + +func HashOpaque(raw string) string { + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/billing/stripe.go b/internal/billing/stripe.go new file mode 100644 index 0000000..ce44a4b --- /dev/null +++ b/internal/billing/stripe.go @@ -0,0 +1,157 @@ +package billing + +import ( + "errors" + "fmt" + + "github.com/stripe/stripe-go/v82" + "github.com/stripe/stripe-go/v82/billingportal/session" + csession "github.com/stripe/stripe-go/v82/checkout/session" + "github.com/stripe/stripe-go/v82/customer" + "github.com/stripe/stripe-go/v82/webhook" +) + +// Client wraps the Stripe SDK with the subset of calls the API needs. +// Concentrates env-var reads + price-ID lookups in one place so handlers +// stay focused on HTTP, not Stripe plumbing. +type Client struct { + secretKey string + webhookSecret string + prices map[Tier]string +} + +// Config is the env-derived configuration the Client needs. +type Config struct { + SecretKey string + WebhookSecret string + PriceProMonthly string + PriceBusiness string +} + +// NewClient validates required fields and returns a configured client. +// Returns (nil, nil) when SecretKey is empty — callers treat that as +// "billing disabled" and degrade gracefully (free tier for everyone, no +// /billing/* endpoints exposed). +func NewClient(cfg Config) (*Client, error) { + if cfg.SecretKey == "" { + return nil, nil + } + stripe.Key = cfg.SecretKey + + c := &Client{ + secretKey: cfg.SecretKey, + webhookSecret: cfg.WebhookSecret, + prices: map[Tier]string{ + TierPro: cfg.PriceProMonthly, + TierBusiness: cfg.PriceBusiness, + }, + } + return c, nil +} + +// Enabled reports whether the client was constructed with a Stripe key. +func (c *Client) Enabled() bool { return c != nil && c.secretKey != "" } + +// PriceFor returns the Stripe Price ID for a tier or an error if it +// hasn't been configured — checkout will fail loudly rather than +// silently send the customer to an empty checkout page. +func (c *Client) PriceFor(tier Tier) (string, error) { + id, ok := c.prices[tier] + if !ok || id == "" { + return "", fmt.Errorf("billing: no Stripe price configured for tier %q", tier) + } + return id, nil +} + +// CreateOrGetCustomer returns the Stripe customer id for the given +// (user_id, email). If `existingID` is non-empty we trust it; otherwise +// we create a new Stripe customer and let the caller persist the id. +func (c *Client) CreateOrGetCustomer(userID, email, name, existingID string) (string, error) { + if !c.Enabled() { + return "", errors.New("billing: disabled") + } + if existingID != "" { + return existingID, nil + } + params := &stripe.CustomerParams{ + Email: stripe.String(email), + Name: stripe.String(name), + Metadata: map[string]string{"gg_user_id": userID}, + } + cust, err := customer.New(params) + if err != nil { + return "", fmt.Errorf("stripe customer create: %w", err) + } + return cust.ID, nil +} + +// CheckoutSessionParams collects the inputs CreateCheckoutSession needs. +// Keeping them in a struct so future fields (coupon codes, trial periods, +// referral metadata) drop in without breaking callers. +type CheckoutSessionParams struct { + CustomerID string + PriceID string + SuccessURL string + CancelURL string +} + +// CreateCheckoutSession returns the URL the frontend redirects the user +// to. Subscription mode — recurring billing for Pro/Business. +func (c *Client) CreateCheckoutSession(p CheckoutSessionParams) (string, error) { + if !c.Enabled() { + return "", errors.New("billing: disabled") + } + params := &stripe.CheckoutSessionParams{ + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + Customer: stripe.String(p.CustomerID), + SuccessURL: stripe.String(p.SuccessURL), + CancelURL: stripe.String(p.CancelURL), + LineItems: []*stripe.CheckoutSessionLineItemParams{{ + Price: stripe.String(p.PriceID), + Quantity: stripe.Int64(1), + }}, + AllowPromotionCodes: stripe.Bool(true), + } + sess, err := csession.New(params) + if err != nil { + return "", fmt.Errorf("stripe checkout session: %w", err) + } + return sess.URL, nil +} + +// CreatePortalSession returns a URL to the Stripe-hosted customer +// portal so the user can manage payment methods, cancel, view invoices. +func (c *Client) CreatePortalSession(customerID, returnURL string) (string, error) { + if !c.Enabled() { + return "", errors.New("billing: disabled") + } + params := &stripe.BillingPortalSessionParams{ + Customer: stripe.String(customerID), + ReturnURL: stripe.String(returnURL), + } + sess, err := session.New(params) + if err != nil { + return "", fmt.Errorf("stripe portal session: %w", err) + } + return sess.URL, nil +} + +// VerifyWebhook validates the Stripe signature header and returns the +// parsed event. Refuses to verify when no webhook secret is configured — +// no shared secret means anyone can POST forged events, so the route +// should reject everything in that case (the caller does that check). +// +// We pass IgnoreAPIVersionMismatch because Stripe accounts can be on a +// newer API version than the SDK we're built against. Event payloads +// are designed to be forward-compatible — the SDK warns about the skew +// but the deserialised event is still safe to use. Strict matching +// would mean we'd have to upgrade the SDK in lockstep with whatever +// Stripe rolls out, defeating the point of having an SDK. +func (c *Client) VerifyWebhook(body []byte, sigHeader string) (stripe.Event, error) { + if c.webhookSecret == "" { + return stripe.Event{}, errors.New("billing: no webhook secret configured") + } + return webhook.ConstructEventWithOptions(body, sigHeader, c.webhookSecret, webhook.ConstructEventOptions{ + IgnoreAPIVersionMismatch: true, + }) +} diff --git a/internal/billing/tiers.go b/internal/billing/tiers.go new file mode 100644 index 0000000..66ca05c --- /dev/null +++ b/internal/billing/tiers.go @@ -0,0 +1,73 @@ +// Package billing models GuestGuard's subscription tiers and Stripe +// integration. Plan limits live here so handler + middleware layers don't +// hard-code numbers — change a value, restart the API, and the cap moves. +package billing + +import "fmt" + +// Tier is the user's current subscription tier. Stored as text in the +// subscriptions table. +type Tier string + +const ( + TierFree Tier = "free" + TierPro Tier = "pro" + TierBusiness Tier = "business" +) + +func (t Tier) Valid() bool { + switch t { + case TierFree, TierPro, TierBusiness: + return true + } + return false +} + +// Limits enforces what a tier may do. -1 means unlimited. +type Limits struct { + EventsPerMonth int + GuestsPerEvent int +} + +// TierLimits is the canonical plan-limits table. Matches docs/TIER1_PLAN.md +// Block F pricing (placeholder until market validation). +var TierLimits = map[Tier]Limits{ + TierFree: {EventsPerMonth: 1, GuestsPerEvent: 50}, + TierPro: {EventsPerMonth: 10, GuestsPerEvent: 1000}, + TierBusiness: {EventsPerMonth: -1, GuestsPerEvent: 5000}, +} + +// LimitsFor returns the limits for a tier, defaulting to Free for unknown +// strings — defensive, so a typo or future-tier in the DB never grants +// unlimited access. +func LimitsFor(t Tier) Limits { + if l, ok := TierLimits[t]; ok { + return l + } + return TierLimits[TierFree] +} + +// StatusGrantsAccess returns true if a subscription with this status +// should be treated as paid. Mirrors the unique-index predicate in +// 0005_billing.up.sql. +func StatusGrantsAccess(status string) bool { + switch status { + case "active", "past_due", "trialing": + return true + } + return false +} + +// LimitError describes a denied-by-policy outcome. The handler layer +// turns this into a 402 with a JSON body the frontend uses to render +// the upgrade modal. +type LimitError struct { + Reason string + Tier Tier + Limit int + Used int +} + +func (e *LimitError) Error() string { + return fmt.Sprintf("billing: %s (tier=%s used=%d limit=%d)", e.Reason, e.Tier, e.Used, e.Limit) +} diff --git a/internal/config/config.go b/internal/config/config.go index c8a4c91..abeb3cc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,11 +12,56 @@ type Config struct { HTTPAddr string DatabaseURL string NATSURL string + RedisAddr string FraudGRPCAddr string FraudGRPCTimeout time.Duration ShutdownTimeout time.Duration TokenSecret string TokenTTL time.Duration + + // Auth + JWTSecret string + JWTIssuer string + AccessTokenTTL time.Duration + RefreshTokenTTL time.Duration + EmailVerificationTTL time.Duration + PasswordResetTTL time.Duration + PublicBaseURL string + RefreshCookieDomain string + RefreshCookieSecure bool + + // Notifications (Block D). Empty values leave the log-stub adapter in + // place so the service still boots without AWS / Twilio creds. + SESRegion string + SESFromEmail string + SESFromName string + SESConfigurationSet string + + SMTPHost string + SMTPPort int + SMTPUsername string + SMTPPassword string + SMTPFromEmail string + SMTPFromName string + SMTPTLS string // "starttls" | "implicit" | "none" (default starttls) + + ResendAPIKey string + ResendFromEmail string + ResendFromName string + + TwilioAccountSID string + TwilioAuthToken string + TwilioFromNumber string + + // Billing (Block F). Empty StripeSecretKey leaves billing disabled — + // all users get free-tier limits, /billing/* returns 503. Lets the + // service boot in dev without Stripe creds. + StripeSecretKey string + StripeWebhookSecret string + StripePricePro string // Stripe Price ID for Pro monthly + StripePriceBusiness string // Stripe Price ID for Business monthly + + UnsubscribeSecret string // HMAC key for signing unsubscribe links } func Load() (*Config, error) { @@ -25,11 +70,50 @@ func Load() (*Config, error) { HTTPAddr: getenv("GG_HTTP_ADDR", ":8080"), DatabaseURL: getenv("GG_DATABASE_URL", "postgres://guestguard:guestguard@localhost:5432/guestguard?sslmode=disable"), NATSURL: getenv("GG_NATS_URL", "nats://localhost:4222"), + RedisAddr: getenv("GG_REDIS_ADDR", "localhost:6379"), FraudGRPCAddr: getenv("GG_FRAUD_GRPC_ADDR", "fraud-engine:9091"), FraudGRPCTimeout: getenvDuration("GG_FRAUD_GRPC_TIMEOUT", 250*time.Millisecond), ShutdownTimeout: getenvDuration("GG_SHUTDOWN_TIMEOUT", 15*time.Second), TokenSecret: os.Getenv("GG_TOKEN_SECRET"), TokenTTL: getenvDuration("GG_TOKEN_TTL", 30*24*time.Hour), + + JWTSecret: os.Getenv("GG_JWT_SECRET"), + JWTIssuer: getenv("GG_JWT_ISSUER", "guestguard"), + AccessTokenTTL: getenvDuration("GG_ACCESS_TOKEN_TTL", 15*time.Minute), + RefreshTokenTTL: getenvDuration("GG_REFRESH_TOKEN_TTL", 30*24*time.Hour), + EmailVerificationTTL: getenvDuration("GG_EMAIL_VERIFICATION_TTL", 24*time.Hour), + PasswordResetTTL: getenvDuration("GG_PASSWORD_RESET_TTL", 1*time.Hour), + PublicBaseURL: getenv("GG_PUBLIC_BASE_URL", "http://localhost:3000"), + RefreshCookieDomain: os.Getenv("GG_REFRESH_COOKIE_DOMAIN"), + RefreshCookieSecure: getenvBool("GG_REFRESH_COOKIE_SECURE", false), + + SESRegion: getenv("GG_SES_REGION", "us-east-1"), + SESFromEmail: os.Getenv("GG_SES_FROM_EMAIL"), + SESFromName: getenv("GG_SES_FROM_NAME", "GuestGuard"), + SESConfigurationSet: os.Getenv("GG_SES_CONFIGURATION_SET"), + + SMTPHost: os.Getenv("GG_SMTP_HOST"), + SMTPPort: getenvInt("GG_SMTP_PORT", 587), + SMTPUsername: os.Getenv("GG_SMTP_USERNAME"), + SMTPPassword: os.Getenv("GG_SMTP_PASSWORD"), + SMTPFromEmail: os.Getenv("GG_SMTP_FROM_EMAIL"), + SMTPFromName: getenv("GG_SMTP_FROM_NAME", "GuestGuard"), + SMTPTLS: getenv("GG_SMTP_TLS", "starttls"), + + ResendAPIKey: os.Getenv("GG_RESEND_API_KEY"), + ResendFromEmail: os.Getenv("GG_RESEND_FROM_EMAIL"), + ResendFromName: getenv("GG_RESEND_FROM_NAME", "GuestGuard"), + + TwilioAccountSID: os.Getenv("GG_TWILIO_ACCOUNT_SID"), + TwilioAuthToken: os.Getenv("GG_TWILIO_AUTH_TOKEN"), + TwilioFromNumber: os.Getenv("GG_TWILIO_FROM_NUMBER"), + + StripeSecretKey: os.Getenv("GG_STRIPE_SECRET_KEY"), + StripeWebhookSecret: os.Getenv("GG_STRIPE_WEBHOOK_SECRET"), + StripePricePro: os.Getenv("GG_STRIPE_PRICE_PRO"), + StripePriceBusiness: os.Getenv("GG_STRIPE_PRICE_BUSINESS"), + + UnsubscribeSecret: os.Getenv("GG_UNSUBSCRIBE_SECRET"), } if cfg.Env == "production" && cfg.TokenSecret == "" { @@ -39,9 +123,52 @@ func Load() (*Config, error) { cfg.TokenSecret = "dev-only-insecure-secret-change-me" } + if cfg.Env == "production" && cfg.JWTSecret == "" { + return nil, fmt.Errorf("GG_JWT_SECRET is required in production") + } + if cfg.JWTSecret == "" { + cfg.JWTSecret = "dev-only-insecure-jwt-secret-change-me-32+bytes" + } + if len(cfg.JWTSecret) < 32 { + return nil, fmt.Errorf("GG_JWT_SECRET must be at least 32 bytes") + } + + if cfg.UnsubscribeSecret == "" { + // Same dev fallback shape as the other secrets — production refuses + // to boot without it. + if cfg.Env == "production" { + return nil, fmt.Errorf("GG_UNSUBSCRIBE_SECRET is required in production") + } + cfg.UnsubscribeSecret = "dev-only-insecure-unsubscribe-secret-change-me" + } + return cfg, nil } +func getenvInt(key string, fallback int) int { + v, ok := os.LookupEnv(key) + if !ok || v == "" { + return fallback + } + n, err := strconv.Atoi(v) + if err != nil { + return fallback + } + return n +} + +func getenvBool(key string, fallback bool) bool { + v, ok := os.LookupEnv(key) + if !ok || v == "" { + return fallback + } + b, err := strconv.ParseBool(v) + if err != nil { + return fallback + } + return b +} + func getenv(key, fallback string) string { if v, ok := os.LookupEnv(key); ok && v != "" { return v diff --git a/internal/csvimport/csvimport.go b/internal/csvimport/csvimport.go new file mode 100644 index 0000000..2ea7c26 --- /dev/null +++ b/internal/csvimport/csvimport.go @@ -0,0 +1,250 @@ +// Package csvimport parses a guest-list CSV into structured rows, with +// tolerant header detection (Excel, Numbers, Google Sheets variants) and +// per-row validation. Streaming-friendly so a 5,000-row import doesn't +// load the entire file into a slice before we know if column 1 is junk. +package csvimport + +import ( + "bufio" + "encoding/csv" + "errors" + "fmt" + "io" + "net/mail" + "regexp" + "strconv" + "strings" + + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +// Row is a single validated guest. Empty Email / Phone are allowed (a +// phone-only or name-only guest is valid per the plan). +type Row struct { + Name string + Email string + Phone string + PlusOnes int +} + +// RowError flags one row with the human-readable reason it can't be +// imported. The line number is 1-based and matches the source CSV +// (header counts as line 1, first data row is line 2) so the frontend +// can highlight the offending row. +type RowError struct { + Row int `json:"row"` + Reason string `json:"reason"` +} + +// Result is the outcome of one parse pass. +type Result struct { + Rows []Row `json:"rows,omitempty"` + Errors []RowError `json:"errors,omitempty"` + TotalCount int `json:"total_count"` // total data rows seen (excluding header) +} + +// Options tune limits + behaviour. +type Options struct { + MaxRows int // hard cap; rows beyond MaxRows return an error instead of being silently dropped +} + +const DefaultMaxRows = 5000 + +// Strict E.164: optional leading +, then a non-zero leading digit (country +// codes never start with 0), followed by 6–14 more digits — total 7–15 +// significant digits. Spaces / dashes / parens are tolerated by stripping +// before validation, but local-format numbers like "0244…" or "07700…" +// are rejected here so the host fixes them at upload time rather than at +// WhatsApp-send time. +var phoneRe = regexp.MustCompile(`^\+?[1-9][0-9]{6,14}$`) + +// Parse reads a CSV from r and returns the parsed result. Encoding is +// auto-detected: UTF-8 with or without BOM, plus UTF-16 LE/BE BOMs +// (commonly produced by Mac Numbers exports). +func Parse(r io.Reader, opt Options) (*Result, error) { + max := opt.MaxRows + if max <= 0 { + max = DefaultMaxRows + } + + rd, err := decodingReader(r) + if err != nil { + return nil, err + } + + csvr := csv.NewReader(rd) + csvr.FieldsPerRecord = -1 // tolerate ragged rows; we re-validate column count ourselves + csvr.TrimLeadingSpace = true + + header, err := csvr.Read() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, errors.New("csv is empty") + } + return nil, fmt.Errorf("read header: %w", err) + } + cols, err := detectColumns(header) + if err != nil { + return nil, err + } + + out := &Result{Rows: make([]Row, 0, 64)} + lineNo := 1 // header was line 1 + for { + rec, err := csvr.Read() + if err == io.EOF { + break + } + lineNo++ + if err != nil { + out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: fmt.Sprintf("malformed csv: %v", err)}) + continue + } + out.TotalCount++ + if out.TotalCount > max { + return nil, fmt.Errorf("import exceeds maximum of %d rows", max) + } + + // Skip fully-empty rows silently — these appear at the end of + // Excel exports a lot. + if rowEmpty(rec) { + out.TotalCount-- // don't count it + continue + } + + row, rerr := buildRow(rec, cols) + if rerr != "" { + out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: rerr}) + continue + } + out.Rows = append(out.Rows, row) + } + return out, nil +} + +func rowEmpty(rec []string) bool { + for _, v := range rec { + if strings.TrimSpace(v) != "" { + return false + } + } + return true +} + +// decodingReader strips a UTF-8 BOM and decodes UTF-16 LE/BE when their +// BOM is present, returning a UTF-8 reader. Other byte orders fall through +// as raw UTF-8. +func decodingReader(r io.Reader) (*bufio.Reader, error) { + br := bufio.NewReader(r) + bom, err := br.Peek(3) + if err != nil && !errors.Is(err, io.EOF) { + return nil, err + } + switch { + case len(bom) >= 3 && bom[0] == 0xEF && bom[1] == 0xBB && bom[2] == 0xBF: + _, _ = br.Discard(3) + return br, nil + case len(bom) >= 2 && bom[0] == 0xFF && bom[1] == 0xFE: + _, _ = br.Discard(2) + dec := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder() + return bufio.NewReader(transform.NewReader(br, dec)), nil + case len(bom) >= 2 && bom[0] == 0xFE && bom[1] == 0xFF: + _, _ = br.Discard(2) + dec := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder() + return bufio.NewReader(transform.NewReader(br, dec)), nil + } + return br, nil +} + +// columnSet records which column index each known field lives in. -1 means +// the column was not supplied; only Name is mandatory. +type columnSet struct { + name, email, phone, plusOnes int +} + +func detectColumns(header []string) (columnSet, error) { + cs := columnSet{name: -1, email: -1, phone: -1, plusOnes: -1} + for i, raw := range header { + key := normaliseHeader(raw) + switch key { + case "name", "guestname", "fullname": + cs.name = i + case "email", "emailaddress", "e-mail": + cs.email = i + case "phone", "telephone", "mobile", "phonenumber": + cs.phone = i + case "plusones", "plus1", "plus-one", "plus-ones", "+1", "guests", "additionalguests": + cs.plusOnes = i + } + } + if cs.name < 0 { + return cs, fmt.Errorf("required column 'name' not found in header: %v", header) + } + return cs, nil +} + +func normaliseHeader(s string) string { + s = strings.ToLower(strings.TrimSpace(s)) + // Drop spaces + underscores. Keep `+`, `-` so "+1" / "plus-one" still + // match exactly. + return strings.NewReplacer(" ", "", "_", "").Replace(s) +} + +func buildRow(rec []string, cs columnSet) (Row, string) { + get := func(i int) string { + if i < 0 || i >= len(rec) { + return "" + } + return strings.TrimSpace(rec[i]) + } + row := Row{ + Name: get(cs.name), + Email: strings.ToLower(get(cs.email)), + Phone: get(cs.phone), + } + if row.Name == "" { + return row, "name is required" + } + + if row.Email != "" { + if _, err := mail.ParseAddress(row.Email); err != nil { + return row, "invalid email" + } + } + if row.Phone != "" { + stripped := stripPhone(row.Phone) + if !phoneRe.MatchString(stripped) { + return row, "phone must be in international format with country code (e.g. +447700900123) — local numbers starting with 0 won't work for SMS or WhatsApp" + } + // Normalise: ensure stored form always starts with "+". + if !strings.HasPrefix(stripped, "+") { + stripped = "+" + stripped + } + row.Phone = stripped + } + if raw := get(cs.plusOnes); raw != "" { + n, err := strconv.Atoi(raw) + if err != nil || n < 0 { + return row, "plus_ones must be a non-negative integer" + } + row.PlusOnes = n + } + return row, "" +} + +var phoneStripper = strings.NewReplacer(" ", "", "-", "", "(", "", ")", "", " ", "") + +func stripPhone(s string) string { + return phoneStripper.Replace(s) +} + +// TemplateCSV is the sample file served at /events/{id}/guests/import/template. +// Phone numbers MUST include the country code (e.g. +44 for UK, +233 for +// Ghana). Local-format numbers like "0244..." or "07700..." will be +// rejected at upload — the sample below shows the expected shape. +const TemplateCSV = "name,email,phone,plus_ones\n" + + "Alex Doe,alex@example.com,+447700900123,1\n" + + "Sam Patel,sam@example.com,,0\n" + + "Jordan Lee,,+15551234567,2\n" + + "Mira Patel,mira@example.com,+233244123456,0\n" diff --git a/internal/csvimport/csvimport_test.go b/internal/csvimport/csvimport_test.go new file mode 100644 index 0000000..2c04b1c --- /dev/null +++ b/internal/csvimport/csvimport_test.go @@ -0,0 +1,157 @@ +package csvimport + +import ( + "strings" + "testing" +) + +func TestParseHappyPath(t *testing.T) { + in := `name,email,phone,plus_ones +Alex Doe,alex@example.com,+447700900123,1 +Sam Patel,SAM@example.com,,0 +Jordan Lee,,+1 (555) 123-4567,2 +` + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if got, want := len(r.Rows), 3; got != want { + t.Fatalf("rows: got %d want %d (errors=%+v)", got, want, r.Errors) + } + if r.Rows[1].Email != "sam@example.com" { + t.Errorf("email not lowercased: %q", r.Rows[1].Email) + } + if r.Rows[2].Phone != "+15551234567" { + t.Errorf("phone not stripped: %q", r.Rows[2].Phone) + } + if r.Rows[0].Phone != "+447700900123" { + t.Errorf("phone should keep leading +: %q", r.Rows[0].Phone) + } +} + +func TestParsePhoneNormalisedToPlus(t *testing.T) { + // E.164 without explicit "+" is accepted and normalised to include one. + in := "name,phone\nAlex,447700900123\n" + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(r.Rows) != 1 || r.Rows[0].Phone != "+447700900123" { + t.Fatalf("expected normalised phone, got: rows=%+v errors=%+v", r.Rows, r.Errors) + } +} + +func TestParsePhoneRejectsLocalFormat(t *testing.T) { + // Local UK / GH style numbers (leading 0, no country code) must be + // rejected — they break SMS routing and WhatsApp click-to-chat. + in := `name,phone +UK Local,07700900123 +GH Local,0244123456 +With Plus And Zero,+0244123456 +` + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(r.Errors) != 3 { + t.Fatalf("expected 3 errors for local-format phones, got %d: %+v", len(r.Errors), r.Errors) + } +} + +func TestParseHeaderVariants(t *testing.T) { + cases := []string{ + "Name,Email,Phone,Plus Ones\nMira,m@x.com,,1\n", + "Guest Name,E-Mail,Telephone,+1\nMira,m@x.com,,1\n", + "full_name,email_address,mobile,plusones\nMira,m@x.com,,1\n", + } + for i, in := range cases { + t.Run(string(rune('a'+i)), func(t *testing.T) { + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(r.Rows) != 1 { + t.Fatalf("rows=%d errors=%+v", len(r.Rows), r.Errors) + } + if r.Rows[0].PlusOnes != 1 { + t.Errorf("plusones not detected: %+v", r.Rows[0]) + } + }) + } +} + +func TestParseRowValidation(t *testing.T) { + in := `name,email,phone,plus_ones +,a@x.com,,1 +Valid Guest,not-an-email,,0 +Phone Person,,abc,0 +Negative,,,-1 +Email Only,e@x.com,, +` + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(r.Errors) != 4 { + t.Fatalf("expected 4 errors, got %d: %+v", len(r.Errors), r.Errors) + } + // "Email Only" row is valid: name present, email parses, phone+plus blank. + if len(r.Rows) != 1 || r.Rows[0].Name != "Email Only" { + t.Fatalf("expected 1 valid row, got %+v", r.Rows) + } +} + +func TestParseUTF8BOM(t *testing.T) { + in := "\xEF\xBB\xBFname,email\nMira,m@x.com\n" + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" { + t.Fatalf("BOM not stripped: %+v", r.Rows) + } +} + +func TestParseUTF16LE(t *testing.T) { + // "name,email\nMira,m@x.com\n" in UTF-16 LE with BOM. + in := []byte{0xFF, 0xFE} + for _, r := range "name,email\nMira,m@x.com\n" { + in = append(in, byte(r), byte(r>>8)) + } + r, err := Parse(strings.NewReader(string(in)), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" { + t.Fatalf("UTF-16 LE not decoded: %+v", r.Rows) + } +} + +func TestParseEmptyTrailingRows(t *testing.T) { + in := "name,email\nMira,m@x.com\n,,\n\n,\n" + r, err := Parse(strings.NewReader(in), Options{}) + if err != nil { + t.Fatalf("parse: %v", err) + } + if r.TotalCount != 1 { + t.Fatalf("trailing blanks counted: TotalCount=%d", r.TotalCount) + } +} + +func TestParseMissingNameHeader(t *testing.T) { + in := "email,phone\na@x.com,\n" + if _, err := Parse(strings.NewReader(in), Options{}); err == nil { + t.Fatal("expected error for missing name column") + } +} + +func TestParseMaxRows(t *testing.T) { + var b strings.Builder + b.WriteString("name\n") + for i := 0; i < 11; i++ { + b.WriteString("X\n") + } + if _, err := Parse(strings.NewReader(b.String()), Options{MaxRows: 10}); err == nil { + t.Fatal("expected error when exceeding MaxRows") + } +} diff --git a/internal/domain/user.go b/internal/domain/user.go index c8fce8f..a461c70 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -8,14 +8,33 @@ import ( ) type User struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + PasswordHash string `json:"-"` + EmailVerified bool `json:"email_verified"` + EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"` + DeletedAt *time.Time `json:"-"` + TermsAcceptedAt *time.Time `json:"terms_accepted_at,omitempty"` + PrivacyPolicyAcceptedAt *time.Time `json:"privacy_policy_accepted_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TermsAccepted reports whether the user has accepted both the terms +// of service and the privacy policy. Both must be present for the user +// to use the dashboard once enforcement is enabled. +func (u *User) TermsAccepted() bool { + return u != nil && u.TermsAcceptedAt != nil && u.PrivacyPolicyAcceptedAt != nil } var ( - ErrUserNotFound = errors.New("user not found") - ErrEmailTaken = errors.New("email already in use") + ErrUserNotFound = errors.New("user not found") + ErrEmailTaken = errors.New("email already in use") + ErrEmailNotVerified = errors.New("email not verified") + ErrAuthTokenNotFound = errors.New("auth token not found") + ErrAuthTokenConsumed = errors.New("auth token already used") + ErrAuthTokenExpired = errors.New("auth token expired") + ErrRefreshTokenRevoked = errors.New("refresh token revoked") + ErrAccountLocked = errors.New("account locked due to too many failed login attempts") ) diff --git a/internal/natspub/client.go b/internal/natspub/client.go index 6fb6f1f..680cd07 100644 --- a/internal/natspub/client.go +++ b/internal/natspub/client.go @@ -92,6 +92,15 @@ func (c *Client) PublishRSVPConfirmed(ctx context.Context, evt RSVPConfirmed) er return c.publishJSON(ctx, SubjectRSVPConfirmed, evt, evt.RSVPID) } +func (c *Client) PublishInvitationSend(ctx context.Context, evt InvitationSend) error { + if evt.IssuedAt.IsZero() { + evt.IssuedAt = time.Now().UTC() + } + // Dedup by token id — re-issuing a token (currently disallowed by the + // unique constraint, but defensive) won't double-send the email. + return c.publishJSON(ctx, SubjectInvitationSend, evt, evt.TokenID) +} + func (c *Client) publishJSON(ctx context.Context, subject string, payload any, dedupeKey uuid.UUID) error { body, err := json.Marshal(payload) if err != nil { diff --git a/internal/natspub/events.go b/internal/natspub/events.go index 9d6b04e..eb7fc46 100644 --- a/internal/natspub/events.go +++ b/internal/natspub/events.go @@ -38,3 +38,20 @@ type RSVPConfirmed struct { RiskScore *int `json:"risk_score,omitempty"` SubmittedAt time.Time `json:"submitted_at"` } + +// InvitationSend asks the notifier to dispatch a guest invitation email. +// Carries everything the email template needs so the worker doesn't have +// to re-fetch event/guest details from Postgres on every send. +type InvitationSend struct { + EventID uuid.UUID `json:"event_id"` + GuestID uuid.UUID `json:"guest_id"` + TokenID uuid.UUID `json:"token_id"` + GuestName string `json:"guest_name"` + GuestEmail string `json:"guest_email"` + HostName string `json:"host_name"` + EventName string `json:"event_name"` + Venue string `json:"venue,omitempty"` + EventDate time.Time `json:"event_date"` + Link string `json:"link"` + IssuedAt time.Time `json:"issued_at"` +} diff --git a/internal/natspub/invitation_subscriber.go b/internal/natspub/invitation_subscriber.go new file mode 100644 index 0000000..4f943ff --- /dev/null +++ b/internal/natspub/invitation_subscriber.go @@ -0,0 +1,64 @@ +package natspub + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "time" + + "github.com/nats-io/nats.go/jetstream" +) + +type InvitationSendHandler func(ctx context.Context, evt InvitationSend) error + +type InvitationSendSubscriber struct { + logger *slog.Logger + consumer jetstream.Consumer + handler InvitationSendHandler +} + +func NewInvitationSendSubscriber( + ctx context.Context, + c *Client, + durable string, + handler InvitationSendHandler, + logger *slog.Logger, +) (*InvitationSendSubscriber, error) { + cons, err := c.js.CreateOrUpdateConsumer(ctx, StreamName, jetstream.ConsumerConfig{ + Durable: durable, + Name: durable, + FilterSubject: SubjectInvitationSend, + AckPolicy: jetstream.AckExplicitPolicy, + DeliverPolicy: jetstream.DeliverAllPolicy, + MaxDeliver: 5, + AckWait: 30 * time.Second, + }) + if err != nil { + return nil, fmt.Errorf("create consumer %s: %w", durable, err) + } + return &InvitationSendSubscriber{logger: logger, consumer: cons, handler: handler}, nil +} + +func (s *InvitationSendSubscriber) Start(ctx context.Context) (jetstream.ConsumeContext, error) { + cc, err := s.consumer.Consume(func(msg jetstream.Msg) { + var evt InvitationSend + if err := json.Unmarshal(msg.Data(), &evt); err != nil { + s.logger.Error("decode invitation.send", "err", err) + _ = msg.Term() + return + } + hctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if err := s.handler(hctx, evt); err != nil { + s.logger.Error("handle invitation.send", "err", err) + _ = msg.NakWithDelay(5 * time.Second) + return + } + _ = msg.Ack() + }) + if err != nil { + return nil, fmt.Errorf("consume: %w", err) + } + return cc, nil +} diff --git a/internal/notification/email_ses.go b/internal/notification/email_ses.go new file mode 100644 index 0000000..c2ee207 --- /dev/null +++ b/internal/notification/email_ses.go @@ -0,0 +1,137 @@ +package notification + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sesv2" + "github.com/aws/aws-sdk-go-v2/service/sesv2/types" +) + +// SESConfig is the surface area for picking an SES sender. ConfigurationSet +// is optional but recommended in production — it's where bounce + complaint +// SNS topics get wired so the webhook handler has something to consume. +type SESConfig struct { + Region string + FromEmail string + FromName string + ConfigurationSet string + PublicBaseURL string // for unsubscribe links in templates +} + +// SESEmailSender sends transactional emails (verification + reset for the +// auth flows, plus invitation/confirmation/reminder for guests) via Amazon +// SESv2. The same client serves both audiences so callers don't end up +// with two SES configurations to maintain. +type SESEmailSender struct { + client *sesv2.Client + tpls *Templates + from string + configSet *string + baseURL string +} + +// NewSESEmailSender returns a configured SES sender, or an error if the +// AWS SDK can't bootstrap. The caller typically constructs this once at +// startup and reuses it for the lifetime of the process. +func NewSESEmailSender(ctx context.Context, cfg SESConfig, tpls *Templates) (*SESEmailSender, error) { + if cfg.FromEmail == "" { + return nil, fmt.Errorf("ses: FromEmail required") + } + if cfg.Region == "" { + cfg.Region = "us-east-1" + } + + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(cfg.Region)) + if err != nil { + return nil, fmt.Errorf("ses: load aws config: %w", err) + } + client := sesv2.NewFromConfig(awsCfg) + + from := cfg.FromEmail + if cfg.FromName != "" { + from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail) + } + var cs *string + if cfg.ConfigurationSet != "" { + cs = aws.String(cfg.ConfigurationSet) + } + + return &SESEmailSender{ + client: client, + tpls: tpls, + from: from, + configSet: cs, + baseURL: strings.TrimRight(cfg.PublicBaseURL, "/"), + }, nil +} + +// --- auth.EmailSender implementation --- + +// SendVerification renders the verification template and posts it to SES. +func (s *SESEmailSender) SendVerification(ctx context.Context, to, name, link string) error { + return s.sendTemplated(ctx, to, "Verify your GuestGuard email", + TmplVerification, map[string]any{ + "Name": name, + "Link": link, + }) +} + +// SendPasswordReset renders the reset template and posts it to SES. +func (s *SESEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error { + return s.sendTemplated(ctx, to, "Reset your GuestGuard password", + TmplPasswordReset, map[string]any{ + "Name": name, + "Link": link, + "ExpiryHumane": "1 hour", + }) +} + +// SendGuest is used by the notifier worker for invitation / confirmation / +// reminder emails — anything addressed at a guest. +func (s *SESEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error) { + return s.sendTemplatedReturnID(ctx, to, subject, name, data) +} + +// --- internals --- + +func (s *SESEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error { + _, err := s.sendTemplatedReturnID(ctx, to, subject, name, data) + return err +} + +func (s *SESEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + if data == nil { + data = map[string]any{} + } + data["Subject"] = subject + html, text, err := s.tpls.Render(name, data) + if err != nil { + return "", err + } + + out, err := s.client.SendEmail(ctx, &sesv2.SendEmailInput{ + FromEmailAddress: aws.String(s.from), + Destination: &types.Destination{ToAddresses: []string{to}}, + ConfigurationSetName: s.configSet, + Content: &types.EmailContent{ + Simple: &types.Message{ + Subject: &types.Content{Data: aws.String(subject), Charset: aws.String("UTF-8")}, + Body: &types.Body{ + Html: &types.Content{Data: aws.String(html), Charset: aws.String("UTF-8")}, + Text: &types.Content{Data: aws.String(text), Charset: aws.String("UTF-8")}, + }, + }, + }, + }) + if err != nil { + return "", fmt.Errorf("ses: send: %w", err) + } + if out.MessageId == nil { + return "", nil + } + return *out.MessageId, nil +} diff --git a/internal/notification/factory.go b/internal/notification/factory.go new file mode 100644 index 0000000..dbdff80 --- /dev/null +++ b/internal/notification/factory.go @@ -0,0 +1,97 @@ +package notification + +import ( + "context" + "log/slog" +) + +// EmailBackend names the chosen email delivery channel for telemetry + +// startup logging. Mostly a debugging aid — code paths don't branch on +// this value. +type EmailBackend string + +const ( + BackendResend EmailBackend = "resend" + BackendSMTP EmailBackend = "smtp" + BackendSES EmailBackend = "ses" + BackendLog EmailBackend = "log" +) + +// EmailSenderConfig collects every email-related env var so the picker +// has a single, ordered place to decide which backend wins. Priority is +// Resend > SMTP > SES > Log — the first one with non-empty creds is used. +type EmailSenderConfig struct { + Resend ResendConfig + SMTP SMTPConfig + SES SESConfig +} + +// CombinedEmailSender satisfies both the auth.EmailSender interface (for +// verification + reset emails) and GuestEmailDispatcher (for invitation, +// confirmation, reminder). One concrete value handles both audiences so +// callers don't end up with two configurations. +type CombinedEmailSender interface { + SendVerification(ctx context.Context, to, name, link string) error + SendPasswordReset(ctx context.Context, to, name, link string) error + SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) +} + +// PickEmailSender returns the configured email sender + which backend was +// chosen. Falls back to a logger stub if nothing is configured, so the +// service stays bootable in stripped-down dev environments. +func PickEmailSender(ctx context.Context, cfg EmailSenderConfig, tpls *Templates, logger *slog.Logger) (CombinedEmailSender, EmailBackend, error) { + switch { + case cfg.Resend.APIKey != "": + s, err := NewResendEmailSender(cfg.Resend, tpls) + if err != nil { + return nil, "", err + } + return s, BackendResend, nil + case cfg.SMTP.Host != "": + s, err := NewSMTPEmailSender(cfg.SMTP, tpls) + if err != nil { + return nil, "", err + } + return s, BackendSMTP, nil + case cfg.SES.FromEmail != "": + s, err := NewSESEmailSender(ctx, cfg.SES, tpls) + if err != nil { + return nil, "", err + } + return s, BackendSES, nil + } + return &logCombinedSender{logger: logger, tpls: tpls}, BackendLog, nil +} + +// logCombinedSender is the dev fallback. Verification + reset emails come +// through as structured log lines (preserving the Block A behaviour); +// guest emails get rendered + dumped so engineers can eyeball the output. +type logCombinedSender struct { + logger *slog.Logger + tpls *Templates +} + +func (l *logCombinedSender) SendVerification(_ context.Context, to, name, link string) error { + l.logger.Info("auth email (stub): verification", "to", to, "name", name, "link", link) + return nil +} + +func (l *logCombinedSender) SendPasswordReset(_ context.Context, to, name, link string) error { + l.logger.Info("auth email (stub): password reset", "to", to, "name", name, "link", link) + return nil +} + +func (l *logCombinedSender) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + if data == nil { + data = map[string]any{} + } + data["Subject"] = subject + _, text, err := l.tpls.Render(name, data) + if err != nil { + return "", err + } + l.logger.Info("guest email (stub)", + "to", to, "subject", subject, "template", string(name), "text_body", text, + ) + return "log:" + string(name), nil +} diff --git a/internal/notification/notification.go b/internal/notification/notification.go index 270a0eb..c36f3ba 100644 --- a/internal/notification/notification.go +++ b/internal/notification/notification.go @@ -64,12 +64,13 @@ func NewRepo(db *storage.DB) *Repo { } type RecordParams struct { - GuestID uuid.UUID - Channel Channel - Type Type - Status Status - ProviderID string - Error string + GuestID uuid.UUID + Channel Channel + Type Type + Status Status + ProviderID string // human-friendly id (e.g. "log:xyz") + ProviderMessageID string // provider's message id (Twilio SID, SES MessageId) + Error string } func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) { @@ -77,6 +78,10 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) { if p.ProviderID != "" { providerID = &p.ProviderID } + var providerMsgID *string + if p.ProviderMessageID != "" { + providerMsgID = &p.ProviderMessageID + } var errStr *string if p.Error != "" { errStr = &p.Error @@ -90,14 +95,15 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) { const q = ` INSERT INTO notifications (guest_id, channel, type, status, provider_id, - attempts, last_attempt, delivered_at, error) - VALUES ($1, $2, $3, $4, $5, 1, now(), $6, $7) + provider_message_id, attempts, last_attempt, + delivered_at, error) + VALUES ($1, $2, $3, $4, $5, $6, 1, now(), $7, $8) RETURNING id ` var id uuid.UUID err := r.pool.QueryRow(ctx, q, p.GuestID, string(p.Channel), string(p.Type), string(p.Status), - providerID, deliveredAt, errStr, + providerID, providerMsgID, deliveredAt, errStr, ).Scan(&id) if err != nil { return uuid.Nil, fmt.Errorf("record notification: %w", err) @@ -105,6 +111,35 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) { return id, nil } +// MarkBounce records a bounce on the notification row identified by the +// provider's message id. Called from webhook handlers. +func (r *Repo) MarkBounce(ctx context.Context, providerMessageID, bounceType string) error { + _, err := r.pool.Exec(ctx, ` + UPDATE notifications + SET status = 'bounced', bounce_type = $2, error = COALESCE(error, '') + WHERE provider_message_id = $1 + `, providerMessageID, bounceType) + return err +} + +// MarkComplaint records a complaint (spam report) for the same row. +func (r *Repo) MarkComplaint(ctx context.Context, providerMessageID string) error { + _, err := r.pool.Exec(ctx, ` + UPDATE notifications SET complained = TRUE WHERE provider_message_id = $1 + `, providerMessageID) + return err +} + +// MarkDelivered moves a row from 'sent' to 'delivered' when the provider's +// delivery status webhook fires. +func (r *Repo) MarkDelivered(ctx context.Context, providerMessageID string) error { + _, err := r.pool.Exec(ctx, ` + UPDATE notifications SET status = 'delivered', delivered_at = now() + WHERE provider_message_id = $1 AND status NOT IN ('bounced','failed') + `, providerMessageID) + return err +} + // LogSender pretends to send and just logs. Useful for Phase 3 demos and // tests; concrete providers (Twilio/SES) plug in later. type LogSender struct{} diff --git a/internal/notification/resend_sender.go b/internal/notification/resend_sender.go new file mode 100644 index 0000000..6873205 --- /dev/null +++ b/internal/notification/resend_sender.go @@ -0,0 +1,134 @@ +package notification + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "time" +) + +// ResendConfig configures the Resend HTTP sender. APIKey is required; +// FromEmail is the sending address on a domain you've verified in the +// Resend dashboard. +type ResendConfig struct { + APIKey string + FromEmail string + FromName string + + // HTTPClient overrides the default client (mostly so tests can point + // at httptest.Server). Leave nil for production. + HTTPClient *http.Client + BaseURL string // overrideable for tests; defaults to https://api.resend.com +} + +// ResendEmailSender posts emails through https://api.resend.com/emails. +// Implements auth.EmailSender + GuestEmailDispatcher. +type ResendEmailSender struct { + cfg ResendConfig + tpls *Templates + from string + client *http.Client + url string +} + +func NewResendEmailSender(cfg ResendConfig, tpls *Templates) (*ResendEmailSender, error) { + if cfg.APIKey == "" { + return nil, errors.New("resend: APIKey required") + } + if cfg.FromEmail == "" { + return nil, errors.New("resend: FromEmail required") + } + from := cfg.FromEmail + if cfg.FromName != "" { + from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail) + } + cli := cfg.HTTPClient + if cli == nil { + cli = &http.Client{Timeout: 15 * time.Second} + } + base := cfg.BaseURL + if base == "" { + base = "https://api.resend.com" + } + return &ResendEmailSender{cfg: cfg, tpls: tpls, from: from, client: cli, url: base + "/emails"}, nil +} + +// --- auth.EmailSender --- + +func (s *ResendEmailSender) SendVerification(ctx context.Context, to, name, link string) error { + _, err := s.sendTemplated(ctx, to, "Verify your GuestGuard email", + TmplVerification, map[string]any{"Name": name, "Link": link}) + return err +} + +func (s *ResendEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error { + _, err := s.sendTemplated(ctx, to, "Reset your GuestGuard password", + TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"}) + return err +} + +// --- GuestEmailDispatcher --- + +func (s *ResendEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + return s.sendTemplated(ctx, to, subject, name, data) +} + +// --- internals --- + +type resendRequest struct { + From string `json:"from"` + To []string `json:"to"` + Subject string `json:"subject"` + HTML string `json:"html"` + Text string `json:"text"` +} + +type resendResponse struct { + ID string `json:"id"` + Message string `json:"message,omitempty"` +} + +func (s *ResendEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + if data == nil { + data = map[string]any{} + } + data["Subject"] = subject + html, text, err := s.tpls.Render(name, data) + if err != nil { + return "", err + } + + body, _ := json.Marshal(resendRequest{ + From: s.from, + To: []string{to}, + Subject: subject, + HTML: html, + Text: text, + }) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+s.cfg.APIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return "", fmt.Errorf("resend: do: %w", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if resp.StatusCode >= 300 { + return "", fmt.Errorf("resend: status %d: %s", resp.StatusCode, string(respBody)) + } + var parsed resendResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + return "", fmt.Errorf("resend: parse: %w", err) + } + return parsed.ID, nil +} diff --git a/internal/notification/resend_sender_test.go b/internal/notification/resend_sender_test.go new file mode 100644 index 0000000..c1078db --- /dev/null +++ b/internal/notification/resend_sender_test.go @@ -0,0 +1,89 @@ +package notification + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestResendSenderHappyPath(t *testing.T) { + tpls, err := NewTemplates() + if err != nil { + t.Fatalf("NewTemplates: %v", err) + } + + var gotPath, gotAuth, gotContentType string + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotContentType = r.Header.Get("Content-Type") + b, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(b, &gotBody) + _ = json.NewEncoder(w).Encode(map[string]string{"id": "resend-abc-123"}) + })) + t.Cleanup(srv.Close) + + s, err := NewResendEmailSender(ResendConfig{ + APIKey: "secret-key", + FromEmail: "no-reply@example.test", + FromName: "GuestGuard", + BaseURL: srv.URL, + }, tpls) + if err != nil { + t.Fatalf("NewResendEmailSender: %v", err) + } + + id, err := s.SendGuest(context.Background(), "to@example.test", "You're invited", + TmplInvitation, map[string]any{ + "GuestName": "Mira", "HostName": "Kay", "EventName": "Beach Day", + "Link": "https://example.test/rsvp/x", + }) + if err != nil { + t.Fatalf("SendGuest: %v", err) + } + if id != "resend-abc-123" { + t.Fatalf("provider id: got %q want %q", id, "resend-abc-123") + } + if gotPath != "/emails" { + t.Errorf("path: got %q want /emails", gotPath) + } + if gotAuth != "Bearer secret-key" { + t.Errorf("auth: got %q", gotAuth) + } + if gotContentType != "application/json" { + t.Errorf("content-type: got %q", gotContentType) + } + if gotBody["from"] != "GuestGuard " { + t.Errorf("from: got %v", gotBody["from"]) + } + if !strings.Contains(gotBody["html"].(string), "Beach Day") { + t.Errorf("html missing event name: %v", gotBody["html"]) + } + if !strings.Contains(gotBody["text"].(string), "Mira") { + t.Errorf("text missing guest name: %v", gotBody["text"]) + } +} + +func TestResendSenderErrorPropagates(t *testing.T) { + tpls, _ := NewTemplates() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message":"invalid api key"}`)) + })) + t.Cleanup(srv.Close) + + s, err := NewResendEmailSender(ResendConfig{ + APIKey: "bad", FromEmail: "x@y.test", BaseURL: srv.URL, + }, tpls) + if err != nil { + t.Fatalf("NewResendEmailSender: %v", err) + } + if err := s.SendVerification(context.Background(), "z@y.test", "Z", "https://link"); err == nil { + t.Fatal("expected error on 401") + } +} diff --git a/internal/notification/router.go b/internal/notification/router.go new file mode 100644 index 0000000..7feb32e --- /dev/null +++ b/internal/notification/router.go @@ -0,0 +1,70 @@ +package notification + +import ( + "context" + "fmt" + "log/slog" + + "github.com/google/uuid" +) + +// Router dispatches an OutboundMessage to the channel-appropriate sender. +// The notifier worker holds one Router and stays oblivious to whether the +// concrete backend is SES, Twilio, or a logger stub. +type Router struct { + email Sender + sms Sender +} + +func NewRouter(email, sms Sender) *Router { + return &Router{email: email, sms: sms} +} + +func (r *Router) Send(ctx context.Context, msg OutboundMessage) (string, error) { + switch msg.Channel { + case ChannelEmail: + if r.email == nil { + return "", fmt.Errorf("no email sender configured") + } + return r.email.Send(ctx, msg) + case ChannelSMS: + if r.sms == nil { + return "", fmt.Errorf("no sms sender configured") + } + return r.sms.Send(ctx, msg) + } + return "", fmt.Errorf("router: unknown channel %q", msg.Channel) +} + +// LogGuestEmailDispatcher is the dev-mode dispatcher that renders the +// templated email and logs both bodies. Useful for local docker-compose +// before SES is configured — gives engineers the rendered output without +// needing a real inbox. +type LogGuestEmailDispatcher struct { + logger *slog.Logger + tpls *Templates +} + +func NewLogGuestEmailDispatcher(logger *slog.Logger, tpls *Templates) *LogGuestEmailDispatcher { + return &LogGuestEmailDispatcher{logger: logger, tpls: tpls} +} + +func (d *LogGuestEmailDispatcher) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + if data == nil { + data = map[string]any{} + } + data["Subject"] = subject + _, text, err := d.tpls.Render(name, data) + if err != nil { + return "", err + } + id := "log:" + uuid.NewString() + d.logger.Info("guest email (stub)", + "to", to, + "subject", subject, + "template", string(name), + "provider_message_id", id, + "text_body", text, + ) + return id, nil +} diff --git a/internal/notification/sender_email.go b/internal/notification/sender_email.go new file mode 100644 index 0000000..774cb4d --- /dev/null +++ b/internal/notification/sender_email.go @@ -0,0 +1,104 @@ +package notification + +import ( + "context" + "errors" + "fmt" + + "github.com/google/uuid" +) + +// GuestEmailDispatcher is the abstraction the notifier uses to send a +// templated email to a single guest. Both SES (production) and Log (dev) +// satisfy it. +type GuestEmailDispatcher interface { + SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error) +} + +// EmailSender is the notification.Sender for ChannelEmail. It maps the +// generic OutboundMessage shape used by the notifier worker onto our +// templated SES / Log path, and honours the suppression list (a guest who +// unsubscribed must never receive another email). +type EmailSender struct { + dispatcher GuestEmailDispatcher + suppression *SuppressionRepo +} + +func NewEmailSender(d GuestEmailDispatcher, s *SuppressionRepo) *EmailSender { + return &EmailSender{dispatcher: d, suppression: s} +} + +func (e *EmailSender) Send(ctx context.Context, msg OutboundMessage) (string, error) { + if msg.GuestID == uuid.Nil { + return "", errors.New("missing guest id") + } + if msg.Channel != ChannelEmail { + return "", fmt.Errorf("EmailSender does not handle channel %q", msg.Channel) + } + to, _ := msg.Metadata["to"].(string) + if to == "" { + return "", errors.New("email recipient missing from metadata.to") + } + + if e.suppression != nil { + yep, err := e.suppression.IsSuppressed(ctx, to) + if err != nil { + // On lookup error, fail safe by NOT sending — better than + // emailing an unsubscribed address. + return "", fmt.Errorf("suppression lookup: %w", err) + } + if yep { + return "suppressed:" + to, nil + } + } + + tmpl := templateForType(msg.Type) + if tmpl == "" { + return "", fmt.Errorf("no template for type %q", msg.Type) + } + subject := msg.Subject + if subject == "" { + subject = defaultSubject(msg.Type, msg.Metadata) + } + data := map[string]any{} + for k, v := range msg.Metadata { + data[k] = v + } + return e.dispatcher.SendGuest(ctx, to, subject, tmpl, data) +} + +func templateForType(t Type) TemplateName { + switch t { + case TypeInvitation: + return TmplInvitation + case TypeConfirmation: + return TmplConfirmation + case TypeReminder: + return TmplReminder + case TypeVerification: + return TmplVerification + } + return "" +} + +func defaultSubject(t Type, meta map[string]any) string { + event, _ := meta["EventName"].(string) + switch t { + case TypeInvitation: + if event != "" { + return "You're invited — " + event + } + return "You're invited" + case TypeConfirmation: + if event != "" { + return "RSVP confirmed — " + event + } + return "RSVP confirmed" + case TypeReminder: + if event != "" { + return "Reminder: " + event + } + return "Reminder" + } + return "GuestGuard" +} diff --git a/internal/notification/sender_twilio.go b/internal/notification/sender_twilio.go new file mode 100644 index 0000000..b6e8fd3 --- /dev/null +++ b/internal/notification/sender_twilio.go @@ -0,0 +1,108 @@ +package notification + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/twilio/twilio-go" + openapi "github.com/twilio/twilio-go/rest/api/v2010" +) + +// TwilioConfig configures the SMS sender. AccountSID + AuthToken pair +// authenticate the REST client; FromNumber is the verified Twilio number. +type TwilioConfig struct { + AccountSID string + AuthToken string + FromNumber string + MaxAttempts int +} + +// TwilioSender implements notification.Sender for ChannelSMS. Retries on +// transient errors with exponential backoff (1s, 5s, 30s, 5m, 30m). Twilio +// surfaces a numeric ErrorCode for permanent failures (e.g. 21610 +// unsubscribed, 21408 disabled region) — those return immediately. +type TwilioSender struct { + client *twilio.RestClient + from string + maxAttempt int +} + +func NewTwilioSender(cfg TwilioConfig) (*TwilioSender, error) { + if cfg.AccountSID == "" || cfg.AuthToken == "" || cfg.FromNumber == "" { + return nil, errors.New("twilio: AccountSID / AuthToken / FromNumber are required") + } + cli := twilio.NewRestClientWithParams(twilio.ClientParams{ + Username: cfg.AccountSID, + Password: cfg.AuthToken, + }) + max := cfg.MaxAttempts + if max <= 0 { + max = 5 + } + return &TwilioSender{client: cli, from: cfg.FromNumber, maxAttempt: max}, nil +} + +func (t *TwilioSender) Send(ctx context.Context, msg OutboundMessage) (string, error) { + if msg.GuestID == uuid.Nil { + return "", errors.New("missing guest id") + } + if msg.Channel != ChannelSMS { + return "", fmt.Errorf("TwilioSender does not handle channel %q", msg.Channel) + } + to, _ := msg.Metadata["phone"].(string) + if to == "" { + return "", errors.New("sms recipient missing from metadata.phone") + } + body := msg.Body + if body == "" { + body = msg.Subject + } + if body == "" { + return "", errors.New("sms body is empty") + } + + params := &openapi.CreateMessageParams{} + params.SetTo(to) + params.SetFrom(t.from) + params.SetBody(body) + + backoff := []time.Duration{0, time.Second, 5 * time.Second, 30 * time.Second, 5 * time.Minute, 30 * time.Minute} + var lastErr error + for attempt := 0; attempt < t.maxAttempt; attempt++ { + if attempt < len(backoff) && backoff[attempt] > 0 { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(backoff[attempt]): + } + } + resp, err := t.client.Api.CreateMessage(params) + if err == nil && resp != nil && resp.Sid != nil { + return *resp.Sid, nil + } + lastErr = err + if !isTwilioRetryable(err) { + return "", fmt.Errorf("twilio: send: %w", err) + } + } + return "", fmt.Errorf("twilio: send (after %d attempts): %w", t.maxAttempt, lastErr) +} + +// isTwilioRetryable returns true for transient failures (network, 5xx). +// Twilio's permanent error codes (21xxx range) are not retried. +func isTwilioRetryable(err error) bool { + if err == nil { + return false + } + msg := err.Error() + // Cheap heuristic: permanent codes are in the 21xxx range; everything + // else (timeouts, 503s, DNS hiccups) is fair game. + if strings.Contains(msg, "21610") || strings.Contains(msg, "21408") || strings.Contains(msg, "21211") { + return false + } + return true +} diff --git a/internal/notification/smtp_sender.go b/internal/notification/smtp_sender.go new file mode 100644 index 0000000..e7b60ee --- /dev/null +++ b/internal/notification/smtp_sender.go @@ -0,0 +1,238 @@ +package notification + +import ( + "context" + "crypto/rand" + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "net" + "net/smtp" + "strconv" + "strings" + "time" +) + +// SMTPConfig describes one SMTP relay. Username/Password are optional — +// local relays like Mailpit accept anonymous SMTP. TLS modes: +// +// - "starttls" upgrade after EHLO (most relays, default) +// - "implicit" TLS handshake before SMTP (port 465) +// - "none" plain socket — only acceptable on a trusted LAN +type SMTPConfig struct { + Host string + Port int + Username string + Password string + FromEmail string + FromName string + TLS string +} + +// SMTPEmailSender implements auth.EmailSender (verification/reset) AND +// GuestEmailDispatcher (invitation/confirmation/reminder) on top of any +// SMTP relay. Used for Mailpit in dev; works against Gmail, Fastmail, etc. +// in production if the user prefers plain SMTP over an HTTP API. +type SMTPEmailSender struct { + cfg SMTPConfig + tpls *Templates + from string +} + +func NewSMTPEmailSender(cfg SMTPConfig, tpls *Templates) (*SMTPEmailSender, error) { + if cfg.Host == "" { + return nil, errors.New("smtp: Host required") + } + if cfg.Port <= 0 { + cfg.Port = 587 + } + if cfg.FromEmail == "" { + return nil, errors.New("smtp: FromEmail required") + } + if cfg.TLS == "" { + cfg.TLS = "starttls" + } + from := cfg.FromEmail + if cfg.FromName != "" { + from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail) + } + return &SMTPEmailSender{cfg: cfg, tpls: tpls, from: from}, nil +} + +// --- auth.EmailSender --- + +func (s *SMTPEmailSender) SendVerification(ctx context.Context, to, name, link string) error { + return s.sendTemplated(ctx, to, "Verify your GuestGuard email", + TmplVerification, map[string]any{"Name": name, "Link": link}) +} + +func (s *SMTPEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error { + return s.sendTemplated(ctx, to, "Reset your GuestGuard password", + TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"}) +} + +// --- GuestEmailDispatcher --- + +func (s *SMTPEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + return s.sendTemplatedReturnID(ctx, to, subject, name, data) +} + +// --- internals --- + +func (s *SMTPEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error { + _, err := s.sendTemplatedReturnID(ctx, to, subject, name, data) + return err +} + +func (s *SMTPEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) { + if data == nil { + data = map[string]any{} + } + data["Subject"] = subject + html, text, err := s.tpls.Render(name, data) + if err != nil { + return "", err + } + msgID := generateMessageID(s.cfg.FromEmail) + body := buildMIMEMessage(mimeMessage{ + MessageID: msgID, + From: s.from, + To: to, + Subject: subject, + Text: text, + HTML: html, + }) + if err := s.dial(ctx, []string{to}, body); err != nil { + return "", err + } + return msgID, nil +} + +func (s *SMTPEmailSender) dial(ctx context.Context, to []string, body []byte) error { + addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(s.cfg.Port)) + d := &net.Dialer{Timeout: 10 * time.Second} + deadline, ok := ctx.Deadline() + if ok { + d.Deadline = deadline + } + + var ( + conn net.Conn + err error + ) + switch strings.ToLower(s.cfg.TLS) { + case "implicit": + conn, err = tls.DialWithDialer(d, "tcp", addr, &tls.Config{ServerName: s.cfg.Host}) + default: + conn, err = d.DialContext(ctx, "tcp", addr) + } + if err != nil { + return fmt.Errorf("smtp: dial: %w", err) + } + + c, err := smtp.NewClient(conn, s.cfg.Host) + if err != nil { + conn.Close() + return fmt.Errorf("smtp: new client: %w", err) + } + defer c.Close() + + if strings.ToLower(s.cfg.TLS) == "starttls" { + if ok, _ := c.Extension("STARTTLS"); ok { + if err := c.StartTLS(&tls.Config{ServerName: s.cfg.Host}); err != nil { + return fmt.Errorf("smtp: starttls: %w", err) + } + } + } + + if s.cfg.Username != "" { + auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host) + if err := c.Auth(auth); err != nil { + return fmt.Errorf("smtp: auth: %w", err) + } + } + + if err := c.Mail(s.cfg.FromEmail); err != nil { + return fmt.Errorf("smtp: MAIL FROM: %w", err) + } + for _, rcpt := range to { + if err := c.Rcpt(rcpt); err != nil { + return fmt.Errorf("smtp: RCPT TO %s: %w", rcpt, err) + } + } + w, err := c.Data() + if err != nil { + return fmt.Errorf("smtp: DATA: %w", err) + } + if _, err := w.Write(body); err != nil { + _ = w.Close() + return fmt.Errorf("smtp: write body: %w", err) + } + if err := w.Close(); err != nil { + return fmt.Errorf("smtp: close body: %w", err) + } + return c.Quit() +} + +type mimeMessage struct { + MessageID string + From string + To string + Subject string + Text string + HTML string +} + +// buildMIMEMessage assembles an RFC 5322 message with a multipart/alternative +// body so receiving clients pick HTML when they can and fall back to text +// otherwise. +func buildMIMEMessage(m mimeMessage) []byte { + boundary := randomBoundary() + var b strings.Builder + b.WriteString("Message-ID: <" + m.MessageID + ">\r\n") + b.WriteString("Date: " + time.Now().UTC().Format(time.RFC1123Z) + "\r\n") + b.WriteString("From: " + m.From + "\r\n") + b.WriteString("To: " + m.To + "\r\n") + b.WriteString("Subject: " + m.Subject + "\r\n") + b.WriteString("MIME-Version: 1.0\r\n") + b.WriteString("Content-Type: multipart/alternative; boundary=" + boundary + "\r\n") + b.WriteString("\r\n") + + b.WriteString("--" + boundary + "\r\n") + b.WriteString("Content-Type: text/plain; charset=UTF-8\r\n") + b.WriteString("Content-Transfer-Encoding: 8bit\r\n") + b.WriteString("\r\n") + b.WriteString(m.Text) + if !strings.HasSuffix(m.Text, "\n") { + b.WriteString("\r\n") + } + + b.WriteString("--" + boundary + "\r\n") + b.WriteString("Content-Type: text/html; charset=UTF-8\r\n") + b.WriteString("Content-Transfer-Encoding: 8bit\r\n") + b.WriteString("\r\n") + b.WriteString(m.HTML) + if !strings.HasSuffix(m.HTML, "\n") { + b.WriteString("\r\n") + } + + b.WriteString("--" + boundary + "--\r\n") + return []byte(b.String()) +} + +func randomBoundary() string { + var buf [16]byte + _, _ = rand.Read(buf[:]) + return "gg=" + hex.EncodeToString(buf[:]) +} + +func generateMessageID(from string) string { + var buf [12]byte + _, _ = rand.Read(buf[:]) + domain := "guestguard.local" + if at := strings.LastIndexByte(from, '@'); at >= 0 && at+1 < len(from) { + domain = from[at+1:] + } + return hex.EncodeToString(buf[:]) + "@" + domain +} diff --git a/internal/notification/smtp_sender_test.go b/internal/notification/smtp_sender_test.go new file mode 100644 index 0000000..b0f43e9 --- /dev/null +++ b/internal/notification/smtp_sender_test.go @@ -0,0 +1,49 @@ +package notification + +import ( + "strings" + "testing" +) + +func TestBuildMIMEMessageStructure(t *testing.T) { + body := buildMIMEMessage(mimeMessage{ + MessageID: "abc@example.test", + From: "GuestGuard ", + To: "to@example.test", + Subject: "Verify your GuestGuard email", + Text: "Hi Mira, please verify.", + HTML: "

Hi Mira, please verify.

", + }) + s := string(body) + checks := []string{ + "Message-ID: ", + "From: GuestGuard ", + "To: to@example.test", + "Subject: Verify your GuestGuard email", + "MIME-Version: 1.0", + "Content-Type: multipart/alternative; boundary=", + "Content-Type: text/plain; charset=UTF-8", + "Content-Type: text/html; charset=UTF-8", + "Hi Mira, please verify.", + "

Hi Mira, please verify.

", + } + for _, want := range checks { + if !strings.Contains(s, want) { + t.Errorf("MIME body missing %q\n---\n%s", want, s) + } + } +} + +func TestGenerateMessageIDIncludesDomain(t *testing.T) { + id := generateMessageID("no-reply@example.test") + if !strings.HasSuffix(id, "@example.test") { + t.Fatalf("message id has wrong domain: %s", id) + } +} + +func TestGenerateMessageIDFallback(t *testing.T) { + id := generateMessageID("not-an-email") + if !strings.HasSuffix(id, "@guestguard.local") { + t.Fatalf("expected fallback domain: %s", id) + } +} diff --git a/internal/notification/suppressions.go b/internal/notification/suppressions.go new file mode 100644 index 0000000..c78e87d --- /dev/null +++ b/internal/notification/suppressions.go @@ -0,0 +1,95 @@ +package notification + +import ( + "context" + "errors" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/alchemistkay/guestguard/internal/storage" +) + +// SuppressionSource categorises why an address landed on the suppression +// list. Bounces + complaints come from provider webhooks; "user" is set +// when a guest clicks an unsubscribe link. +type SuppressionSource string + +const ( + SuppressionBounce SuppressionSource = "bounce" + SuppressionComplaint SuppressionSource = "complaint" + SuppressionManual SuppressionSource = "manual" + SuppressionUser SuppressionSource = "user" +) + +// SuppressionRepo manages the unsubscribes table — a flat suppression list +// of email addresses that must never receive another email regardless of +// notification type. +type SuppressionRepo struct { + pool *pgxpool.Pool +} + +func NewSuppressionRepo(db *storage.DB) *SuppressionRepo { + return &SuppressionRepo{pool: db.Pool} +} + +// IsSuppressed returns true if `email` is on the suppression list. +// Empty / unparseable addresses are treated as not-suppressed; the caller +// is expected to validate before sending. +func (r *SuppressionRepo) IsSuppressed(ctx context.Context, email string) (bool, error) { + email = normaliseEmail(email) + if email == "" { + return false, nil + } + var exists bool + err := r.pool.QueryRow(ctx, + `SELECT EXISTS (SELECT 1 FROM unsubscribes WHERE email = $1)`, + email, + ).Scan(&exists) + if err != nil { + return false, err + } + return exists, nil +} + +// Add records `email` on the suppression list. Idempotent — repeated calls +// keep the earliest entry's timestamp. +func (r *SuppressionRepo) Add(ctx context.Context, email, reason string, src SuppressionSource) error { + email = normaliseEmail(email) + if email == "" { + return errors.New("empty email") + } + _, err := r.pool.Exec(ctx, ` + INSERT INTO unsubscribes (email, reason, source) + VALUES ($1, NULLIF($2, ''), $3) + ON CONFLICT (email) DO NOTHING + `, email, reason, string(src)) + return err +} + +// Get returns the suppression record for `email`, or pgx.ErrNoRows if not +// found. Mostly used by tests / admin tooling. +func (r *SuppressionRepo) Get(ctx context.Context, email string) (string, SuppressionSource, error) { + var reason *string + var source string + err := r.pool.QueryRow(ctx, + `SELECT reason, source FROM unsubscribes WHERE email = $1`, + normaliseEmail(email), + ).Scan(&reason, &source) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", "", err + } + return "", "", err + } + r2 := "" + if reason != nil { + r2 = *reason + } + return r2, SuppressionSource(source), nil +} + +func normaliseEmail(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} diff --git a/internal/notification/templates.go b/internal/notification/templates.go new file mode 100644 index 0000000..4f1f58a --- /dev/null +++ b/internal/notification/templates.go @@ -0,0 +1,116 @@ +package notification + +import ( + "bytes" + "embed" + "fmt" + htmltemplate "html/template" + "io/fs" + "path" + "strings" + texttemplate "text/template" +) + +//go:embed templates +var templatesFS embed.FS + +// TemplateName names one of the email templates. There must be a +// matching `.html` and `.txt` under templates/. +type TemplateName string + +const ( + TmplVerification TemplateName = "verification" + TmplPasswordReset TemplateName = "reset" + TmplInvitation TemplateName = "invitation" + TmplConfirmation TemplateName = "confirmation" + TmplReminder TemplateName = "reminder" +) + +// Templates renders branded transactional emails for both HTML and +// plaintext bodies. Loaded once at construction; thread-safe afterwards. +// +// Each page-level HTML template gets its own *html/template.Template +// holding a copy of the `base.html` partial. This avoids html/template's +// per-template contextual-escape pass interfering between pages that all +// define a sibling named "body". +type Templates struct { + html map[string]*htmltemplate.Template // page-name (no ext) -> root + text *texttemplate.Template +} + +func NewTemplates() (*Templates, error) { + root, err := fs.Sub(templatesFS, "templates") + if err != nil { + return nil, fmt.Errorf("templates fs: %w", err) + } + + baseHTML, err := fs.ReadFile(root, "base.html") + if err != nil { + return nil, fmt.Errorf("read base.html: %w", err) + } + + out := &Templates{ + html: make(map[string]*htmltemplate.Template), + text: texttemplate.New("__root__"), + } + + walk := func(p string, d fs.DirEntry, _ error) error { + if d == nil || d.IsDir() { + return nil + } + base := path.Base(p) + if base == "base.html" { + return nil // partial — folded into each page template below + } + b, err := fs.ReadFile(root, p) + if err != nil { + return err + } + switch { + case strings.HasSuffix(p, ".html"): + name := strings.TrimSuffix(base, ".html") + t := htmltemplate.New(name) + if _, err := t.Parse(string(baseHTML)); err != nil { + return fmt.Errorf("parse _base for %s: %w", p, err) + } + if _, err := t.Parse(string(b)); err != nil { + return fmt.Errorf("parse %s: %w", p, err) + } + out.html[name] = t + case strings.HasSuffix(p, ".txt"): + if _, err := out.text.New(base).Parse(string(b)); err != nil { + return fmt.Errorf("parse %s: %w", p, err) + } + } + return nil + } + if err := fs.WalkDir(root, ".", walk); err != nil { + return nil, err + } + return out, nil +} + +// Render returns (htmlBody, textBody) for the named template using data. +func (t *Templates) Render(name TemplateName, data map[string]any) (htmlBody, textBody string, err error) { + if data == nil { + data = map[string]any{} + } + if _, ok := data["Subject"]; !ok { + data["Subject"] = "GuestGuard" + } + + htmlTpl, ok := t.html[string(name)] + if !ok { + return "", "", fmt.Errorf("unknown html template %q", name) + } + + var hBuf, tBuf bytes.Buffer + // Each page-root template's entry point is "base" (defined by base.html). + if err := htmlTpl.ExecuteTemplate(&hBuf, "base", data); err != nil { + return "", "", fmt.Errorf("render html %s: %w", name, err) + } + if err := t.text.ExecuteTemplate(&tBuf, string(name)+".txt", data); err != nil { + return "", "", fmt.Errorf("render text %s: %w", name, err) + } + return hBuf.String(), tBuf.String(), nil +} diff --git a/internal/notification/templates/base.html b/internal/notification/templates/base.html new file mode 100644 index 0000000..469791b --- /dev/null +++ b/internal/notification/templates/base.html @@ -0,0 +1,36 @@ +{{define "base"}} + + + + +{{.Subject}} + + + + +
+ + + + + + + + + + +
+ + GuestGuard +
+ {{block "body" .}}{{end}} +
+
+

+ You're receiving this because of activity on your GuestGuard account. + {{if .UnsubscribeLink}}
If you'd rather not get emails like this, unsubscribe here.{{end}} +

+
+
+ +{{end}} diff --git a/internal/notification/templates/confirmation.html b/internal/notification/templates/confirmation.html new file mode 100644 index 0000000..8c433ac --- /dev/null +++ b/internal/notification/templates/confirmation.html @@ -0,0 +1,11 @@ +{{define "body"}} +

RSVP received

+

Hi {{.GuestName}},

+

Thanks for letting {{.HostName}} know — your RSVP for {{.EventName}} is confirmed as {{.Response}}{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.

+{{if or .Venue .EventDate}} +

+ {{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}} +

+{{end}} +

You'll get a reminder closer to the date. If your plans change, use the same invitation link to update your reply.

+{{end}} diff --git a/internal/notification/templates/confirmation.txt b/internal/notification/templates/confirmation.txt new file mode 100644 index 0000000..3e98dde --- /dev/null +++ b/internal/notification/templates/confirmation.txt @@ -0,0 +1,10 @@ +Hi {{.GuestName}}, + +Your RSVP for "{{.EventName}}" is confirmed as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}. + +{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}} + +You'll get a reminder closer to the date. If your plans change, use the +same invitation link to update your reply. + +— GuestGuard diff --git a/internal/notification/templates/invitation.html b/internal/notification/templates/invitation.html new file mode 100644 index 0000000..c922e85 --- /dev/null +++ b/internal/notification/templates/invitation.html @@ -0,0 +1,15 @@ +{{define "body"}} +

✦ You're invited

+

{{.EventName}}

+{{if or .Venue .EventDate}} +

+ {{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}} +

+{{end}} +

Hi {{.GuestName}}, {{.HostName}} would love to know if you can make it. Use the personal link below to RSVP.

+

+ RSVP now +

+

Or paste this URL into your browser:

+

{{.Link}}

+{{end}} diff --git a/internal/notification/templates/invitation.txt b/internal/notification/templates/invitation.txt new file mode 100644 index 0000000..fc8bf65 --- /dev/null +++ b/internal/notification/templates/invitation.txt @@ -0,0 +1,11 @@ +You're invited — {{.EventName}} + +Hi {{.GuestName}}, + +{{.HostName}} would love to know if you can make it{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}. + +RSVP here: + + {{.Link}} + +— GuestGuard diff --git a/internal/notification/templates/reminder.html b/internal/notification/templates/reminder.html new file mode 100644 index 0000000..5f9f068 --- /dev/null +++ b/internal/notification/templates/reminder.html @@ -0,0 +1,7 @@ +{{define "body"}} +

Reminder: {{.EventName}}

+

Hi {{.GuestName}},

+

Just a quick reminder that {{.EventName}} is coming up{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.

+{{if .Response}}

You're down as {{.Response}}{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.

{{end}} +

Need to change your plans? Use your invitation link.

+{{end}} diff --git a/internal/notification/templates/reminder.txt b/internal/notification/templates/reminder.txt new file mode 100644 index 0000000..b0b1d28 --- /dev/null +++ b/internal/notification/templates/reminder.txt @@ -0,0 +1,7 @@ +Hi {{.GuestName}}, + +Reminder: {{.EventName}}{{if .EventDate}} — {{.EventDate}}{{end}}{{if .Venue}} at {{.Venue}}{{end}}. + +{{if .Response}}You're down as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.{{end}} + +— GuestGuard diff --git a/internal/notification/templates/reset.html b/internal/notification/templates/reset.html new file mode 100644 index 0000000..df9402f --- /dev/null +++ b/internal/notification/templates/reset.html @@ -0,0 +1,11 @@ +{{define "body"}} +

Reset your password

+

Hi {{.Name}},

+

We received a request to reset the password on your GuestGuard account. Tap the button to choose a new one — the link is valid for {{.ExpiryHumane}}.

+

+ Choose a new password +

+

Or paste this URL into your browser:

+

{{.Link}}

+

If you didn't ask to reset your password, you can ignore this email — your current password is unchanged.

+{{end}} diff --git a/internal/notification/templates/reset.txt b/internal/notification/templates/reset.txt new file mode 100644 index 0000000..882abe0 --- /dev/null +++ b/internal/notification/templates/reset.txt @@ -0,0 +1,11 @@ +Hi {{.Name}}, + +We received a request to reset your GuestGuard password. The link is valid +for {{.ExpiryHumane}}: + + {{.Link}} + +If you didn't ask to reset your password, you can ignore this email — your +current password is unchanged. + +— GuestGuard diff --git a/internal/notification/templates/verification.html b/internal/notification/templates/verification.html new file mode 100644 index 0000000..b28333d --- /dev/null +++ b/internal/notification/templates/verification.html @@ -0,0 +1,11 @@ +{{define "body"}} +

Verify your email

+

Hi {{.Name}}, welcome to GuestGuard.

+

To finish setting up your account, please confirm this is your email address.

+

+ Verify email +

+

Or paste this URL into your browser:

+

{{.Link}}

+

If you didn't sign up for GuestGuard, you can ignore this email.

+{{end}} diff --git a/internal/notification/templates/verification.txt b/internal/notification/templates/verification.txt new file mode 100644 index 0000000..93e537d --- /dev/null +++ b/internal/notification/templates/verification.txt @@ -0,0 +1,9 @@ +Hi {{.Name}}, welcome to GuestGuard. + +Please verify your email by visiting: + + {{.Link}} + +If you didn't sign up for GuestGuard, you can ignore this email. + +— GuestGuard diff --git a/internal/notification/templates_test.go b/internal/notification/templates_test.go new file mode 100644 index 0000000..9a8b527 --- /dev/null +++ b/internal/notification/templates_test.go @@ -0,0 +1,101 @@ +package notification + +import ( + "strings" + "testing" +) + +func TestRenderAllTemplates(t *testing.T) { + tpls, err := NewTemplates() + if err != nil { + t.Fatalf("NewTemplates: %v", err) + } + + cases := []struct { + name TemplateName + data map[string]any + wantHTML []string // substrings expected in HTML body + wantText []string // substrings expected in text body + }{ + { + name: TmplVerification, + data: map[string]any{ + "Name": "Kay", + "Link": "https://example.test/verify-email?token=x", + "Subject": "Verify your GuestGuard email", + "UnsubscribeLink": "https://example.test/unsubscribe/abc", + }, + wantHTML: []string{"Verify your email", "Kay", "Verify email", "https://example.test/verify-email?token=x", "unsubscribe here"}, + wantText: []string{"Hi Kay", "https://example.test/verify-email?token=x"}, + }, + { + name: TmplPasswordReset, + data: map[string]any{ + "Name": "Kay", + "Link": "https://example.test/reset-password/abc", + "ExpiryHumane": "1 hour", + }, + wantHTML: []string{"Reset your password", "1 hour", "https://example.test/reset-password/abc"}, + wantText: []string{"reset your GuestGuard password", "1 hour"}, + }, + { + name: TmplInvitation, + data: map[string]any{ + "GuestName": "Mira", + "HostName": "Kay", + "EventName": "Beach Day", + "Venue": "Ocean Park", + "EventDate": "Sat 14 Jun, 4pm", + "Link": "https://example.test/rsvp/tok_x", + }, + wantHTML: []string{"You're invited", "Beach Day", "Mira", "Ocean Park", "RSVP now", "https://example.test/rsvp/tok_x"}, + wantText: []string{"Beach Day", "Mira", "RSVP here", "https://example.test/rsvp/tok_x"}, + }, + { + name: TmplConfirmation, + data: map[string]any{ + "GuestName": "Mira", + "HostName": "Kay", + "EventName": "Beach Day", + "Venue": "Ocean Park", + "EventDate": "Sat 14 Jun, 4pm", + "Response": "attending", + "PlusOnes": 2, + }, + wantHTML: []string{"RSVP received", "Beach Day", "attending", "2 plus-ones"}, + wantText: []string{"Beach Day", "attending", "+2"}, + }, + { + name: TmplReminder, + data: map[string]any{ + "GuestName": "Mira", + "EventName": "Beach Day", + "EventDate": "Sat 14 Jun, 4pm", + "Venue": "Ocean Park", + "Response": "attending", + "PlusOnes": 1, + }, + wantHTML: []string{"Reminder", "Beach Day", "Ocean Park", "1 plus-one"}, + wantText: []string{"Reminder", "Beach Day", "(+1)"}, + }, + } + + for _, tc := range cases { + t.Run(string(tc.name), func(t *testing.T) { + html, text, err := tpls.Render(tc.name, tc.data) + if err != nil { + t.Fatalf("render: %v", err) + } + for _, s := range tc.wantHTML { + if !strings.Contains(html, s) { + t.Errorf("html missing %q\n---\n%s", s, html) + } + } + for _, s := range tc.wantText { + if !strings.Contains(text, s) { + t.Errorf("text missing %q\n---\n%s", s, text) + } + } + }) + } +} diff --git a/internal/notification/unsubscribe.go b/internal/notification/unsubscribe.go new file mode 100644 index 0000000..f91013f --- /dev/null +++ b/internal/notification/unsubscribe.go @@ -0,0 +1,60 @@ +package notification + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "errors" + "strings" +) + +// UnsubscribeSigner mints + verifies tamper-proof unsubscribe tokens. Each +// token encodes the email address and an HMAC-SHA256 of the address under +// a server-side secret. The token has no TTL — unsubscribe links should +// keep working forever. +// +// Token shape: base64url(email) + "." + base64url(hmac) +type UnsubscribeSigner struct { + secret []byte +} + +func NewUnsubscribeSigner(secret string) *UnsubscribeSigner { + return &UnsubscribeSigner{secret: []byte(secret)} +} + +// Sign returns a URL-safe token for the email address. Empty input → empty +// token (caller should validate input first). +func (s *UnsubscribeSigner) Sign(email string) string { + if email == "" { + return "" + } + email = normaliseEmail(email) + mac := hmac.New(sha256.New, s.secret) + mac.Write([]byte(email)) + return base64.RawURLEncoding.EncodeToString([]byte(email)) + + "." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} + +// Verify decodes the token and confirms the HMAC matches, returning the +// owning email address. +func (s *UnsubscribeSigner) Verify(token string) (string, error) { + dot := strings.IndexByte(token, '.') + if dot < 0 { + return "", errors.New("malformed unsubscribe token") + } + emailB, err := base64.RawURLEncoding.DecodeString(token[:dot]) + if err != nil { + return "", err + } + sigB, err := base64.RawURLEncoding.DecodeString(token[dot+1:]) + if err != nil { + return "", err + } + mac := hmac.New(sha256.New, s.secret) + mac.Write(emailB) + want := mac.Sum(nil) + if !hmac.Equal(sigB, want) { + return "", errors.New("unsubscribe signature mismatch") + } + return string(emailB), nil +} diff --git a/internal/ratelimit/middleware.go b/internal/ratelimit/middleware.go new file mode 100644 index 0000000..936376d --- /dev/null +++ b/internal/ratelimit/middleware.go @@ -0,0 +1,69 @@ +package ratelimit + +import ( + "encoding/json" + "log/slog" + "net/http" + "strconv" + "time" +) + +// KeyFunc derives the rate-limit key from a request — e.g. IP, IP+email, +// authenticated user id, path token, etc. An empty return value bypasses +// the limiter (handy when a key isn't available yet, like an email that +// hasn't been parsed out of the body). +type KeyFunc func(r *http.Request) string + +// Rule names a single sliding-window budget. +type Rule struct { + Name string + Limit int + Window time.Duration +} + +// Middleware returns an http middleware that applies `rule` to incoming +// requests using `keyFn`. On limit, it writes 429 with a Retry-After header +// and a JSON body. If the limiter itself errors, requests are allowed +// (fail-open) — degraded protection is better than total outage. +func (l *Limiter) Middleware(rule Rule, keyFn KeyFunc, logger *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := keyFn(r) + if key == "" { + next.ServeHTTP(w, r) + return + } + res, err := l.Allow(r.Context(), rule.Name, key, rule.Limit, rule.Window) + if err != nil { + if logger != nil { + logger.Warn("ratelimit error (failing open)", + "rule", rule.Name, "err", err) + } + next.ServeHTTP(w, r) + return + } + if !res.Allowed { + retrySecs := int(res.RetryAfter.Round(time.Second).Seconds()) + if retrySecs < 1 { + retrySecs = 1 + } + w.Header().Set("Retry-After", strconv.Itoa(retrySecs)) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "rate limit exceeded", + "retry_after": retrySecs, + }) + if logger != nil { + logger.Info("ratelimit blocked", + "rule", rule.Name, + "key", key, + "retry_after", retrySecs, + ) + } + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000..058fae5 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,134 @@ +// Package ratelimit implements a sliding-window rate limiter backed by +// Redis sorted sets. Each call is atomic via a Lua script: it sweeps +// entries older than the window, returns the current count, and (when +// under the limit) records the new hit. Block C's "INCR + EXPIRE or a Lua +// script for atomicity" requirement. +package ratelimit + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +// Result is the outcome of one Allow check. +type Result struct { + Allowed bool + Count int // current count within the window (post-increment when allowed) + Limit int + RetryAfter time.Duration // populated when Allowed=false +} + +// Limiter checks rate-limit budgets against Redis. +type Limiter struct { + client *redis.Client + script *redis.Script + prefix string + now func() time.Time +} + +// New builds a Limiter against the given Redis client. The prefix namespaces +// all keys (defaults to "rl"). +func New(client *redis.Client, prefix string) *Limiter { + if prefix == "" { + prefix = "rl" + } + return &Limiter{ + client: client, + script: redis.NewScript(slidingWindowScript), + prefix: prefix, + now: time.Now, + } +} + +// Allow consumes one unit of budget under (name, key) against `limit` events +// per `window`. Returns Allowed=true and the new count, or Allowed=false +// with RetryAfter set to roughly the duration until the oldest hit ages out. +func (l *Limiter) Allow(ctx context.Context, name, key string, limit int, window time.Duration) (Result, error) { + if limit <= 0 { + return Result{}, errors.New("ratelimit: limit must be positive") + } + if window <= 0 { + return Result{}, errors.New("ratelimit: window must be positive") + } + member, err := randomMember() + if err != nil { + return Result{}, err + } + + now := l.now().UnixMilli() + windowMS := window.Milliseconds() + redisKey := fmt.Sprintf("%s:%s:%s", l.prefix, name, key) + + out, err := l.script.Run(ctx, l.client, + []string{redisKey}, + now, windowMS, limit, member, + ).Int64Slice() + if err != nil { + return Result{}, fmt.Errorf("ratelimit: redis: %w", err) + } + if len(out) != 3 { + return Result{}, fmt.Errorf("ratelimit: bad lua reply: %v", out) + } + + r := Result{ + Count: int(out[1]), + Limit: limit, + } + if out[0] == 0 { + r.Allowed = true + } else { + r.RetryAfter = time.Duration(out[2]) * time.Millisecond + if r.RetryAfter <= 0 { + r.RetryAfter = time.Second + } + } + return r, nil +} + +func randomMember() (string, error) { + var buf [12]byte + if _, err := rand.Read(buf[:]); err != nil { + return "", err + } + return hex.EncodeToString(buf[:]), nil +} + +// Sliding-window check + record, atomic in Redis. +// +// KEYS[1] = bucket key +// ARGV[1] = now (unix ms) +// ARGV[2] = window (ms) +// ARGV[3] = limit +// ARGV[4] = unique member to insert when allowed +// returns { blocked, count, retryAfterMs } +const slidingWindowScript = ` +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local window = tonumber(ARGV[2]) +local limit = tonumber(ARGV[3]) +local member = ARGV[4] +local cutoff = now - window + +redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff) +local count = redis.call('ZCARD', key) + +if count >= limit then + local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') + local retry = window + if oldest[2] then + retry = (tonumber(oldest[2]) + window) - now + if retry < 1 then retry = 1 end + end + return {1, count, retry} +end + +redis.call('ZADD', key, now, member) +redis.call('PEXPIRE', key, window) +return {0, count + 1, 0} +` diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..081a38d --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,110 @@ +package ratelimit + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func newTestLimiter(t *testing.T) (*Limiter, *miniredis.Miniredis) { + t.Helper() + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("miniredis: %v", err) + } + t.Cleanup(mr.Close) + cli := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { _ = cli.Close() }) + return New(cli, "test"), mr +} + +func TestLimiterAllowsBelowLimit(t *testing.T) { + l, _ := newTestLimiter(t) + ctx := context.Background() + for i := 1; i <= 3; i++ { + r, err := l.Allow(ctx, "signup", "1.2.3.4", 3, time.Minute) + if err != nil { + t.Fatalf("allow #%d: %v", i, err) + } + if !r.Allowed { + t.Fatalf("hit %d should be allowed, got %+v", i, r) + } + if r.Count != i { + t.Fatalf("hit %d count: got %d want %d", i, r.Count, i) + } + } +} + +func TestLimiterBlocksAtLimit(t *testing.T) { + l, _ := newTestLimiter(t) + ctx := context.Background() + for i := 0; i < 3; i++ { + if _, err := l.Allow(ctx, "signup", "ip", 3, time.Minute); err != nil { + t.Fatal(err) + } + } + r, err := l.Allow(ctx, "signup", "ip", 3, time.Minute) + if err != nil { + t.Fatal(err) + } + if r.Allowed { + t.Fatalf("4th hit should be blocked: %+v", r) + } + if r.RetryAfter <= 0 || r.RetryAfter > time.Minute { + t.Fatalf("retry-after out of range: %v", r.RetryAfter) + } +} + +func TestLimiterWindowSlides(t *testing.T) { + l, mr := newTestLimiter(t) + ctx := context.Background() + // Inject a controllable clock so we can advance time in miniredis + + // the limiter consistently. + base := time.Unix(1_700_000_000, 0) + l.now = func() time.Time { return base } + for i := 0; i < 3; i++ { + if _, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute); err != nil { + t.Fatal(err) + } + } + // Slide past the window. miniredis honours TTLs we already set so + // FastForward is the trustworthy primitive here. + l.now = func() time.Time { return base.Add(2 * time.Minute) } + mr.FastForward(2 * time.Minute) + + r, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute) + if err != nil { + t.Fatal(err) + } + if !r.Allowed { + t.Fatalf("expected allow after window: %+v", r) + } +} + +func TestLimiterIsolatesKeys(t *testing.T) { + l, _ := newTestLimiter(t) + ctx := context.Background() + // Exhaust budget for one key — others should not be affected. + for i := 0; i < 2; i++ { + if _, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute); err != nil { + t.Fatal(err) + } + } + blocked, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute) + if err != nil { + t.Fatal(err) + } + if blocked.Allowed { + t.Fatal("key a should be blocked") + } + other, err := l.Allow(ctx, "login", "b@x.com", 2, time.Minute) + if err != nil { + t.Fatal(err) + } + if !other.Allowed { + t.Fatalf("unrelated key b should still be allowed: %+v", other) + } +} diff --git a/internal/storage/auth_tokens.go b/internal/storage/auth_tokens.go new file mode 100644 index 0000000..0d1569a --- /dev/null +++ b/internal/storage/auth_tokens.go @@ -0,0 +1,267 @@ +package storage + +import ( + "context" + "errors" + "net/netip" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/alchemistkay/guestguard/internal/domain" +) + +// EmailVerificationRepo manages single-use email verification tokens. +type EmailVerificationRepo struct { + pool *pgxpool.Pool +} + +func NewEmailVerificationRepo(db *DB) *EmailVerificationRepo { + return &EmailVerificationRepo{pool: db.Pool} +} + +func (r *EmailVerificationRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error { + _, err := r.pool.Exec(ctx, ` + INSERT INTO email_verification_tokens (token_hash, user_id, expires_at) + VALUES ($1, $2, $3) + `, hash, userID, expiresAt) + return err +} + +// Consume atomically marks the token as used and returns the owning user_id. +// Returns ErrAuthTokenNotFound / ErrAuthTokenConsumed / ErrAuthTokenExpired. +func (r *EmailVerificationRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) { + const q = ` + UPDATE email_verification_tokens + SET consumed_at = now() + WHERE token_hash = $1 + AND consumed_at IS NULL + AND expires_at > now() + RETURNING user_id + ` + var uid uuid.UUID + if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool, + "SELECT consumed_at, expires_at FROM email_verification_tokens WHERE token_hash=$1", + hash) + } + return uuid.Nil, err + } + return uid, nil +} + +// PasswordResetRepo manages single-use password-reset tokens. +type PasswordResetRepo struct { + pool *pgxpool.Pool +} + +func NewPasswordResetRepo(db *DB) *PasswordResetRepo { + return &PasswordResetRepo{pool: db.Pool} +} + +func (r *PasswordResetRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error { + _, err := r.pool.Exec(ctx, ` + INSERT INTO password_reset_tokens (token_hash, user_id, expires_at) + VALUES ($1, $2, $3) + `, hash, userID, expiresAt) + return err +} + +func (r *PasswordResetRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) { + const q = ` + UPDATE password_reset_tokens + SET consumed_at = now() + WHERE token_hash = $1 + AND consumed_at IS NULL + AND expires_at > now() + RETURNING user_id + ` + var uid uuid.UUID + if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool, + "SELECT consumed_at, expires_at FROM password_reset_tokens WHERE token_hash=$1", + hash) + } + return uuid.Nil, err + } + return uid, nil +} + +// RefreshTokenRepo manages refresh-token rows. Refresh tokens are rotated: +// every refresh issues a new token and revokes the old one, recording the +// chain in `replaced_by` so we can detect replay (a revoked token being +// presented again triggers a family-wide revocation). +type RefreshTokenRepo struct { + pool *pgxpool.Pool +} + +func NewRefreshTokenRepo(db *DB) *RefreshTokenRepo { + return &RefreshTokenRepo{pool: db.Pool} +} + +type RefreshToken struct { + Hash string + UserID uuid.UUID + ExpiresAt time.Time + RevokedAt *time.Time + ReplacedBy *string + UserAgent string + IPAddress *netip.Addr + CreatedAt time.Time +} + +type CreateRefreshTokenParams struct { + Hash string + UserID uuid.UUID + ExpiresAt time.Time + UserAgent string + IPAddress string +} + +func (r *RefreshTokenRepo) Create(ctx context.Context, p CreateRefreshTokenParams) error { + ip := parseIP(p.IPAddress) + _, err := r.pool.Exec(ctx, ` + INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address) + VALUES ($1, $2, $3, NULLIF($4, ''), $5) + `, p.Hash, p.UserID, p.ExpiresAt, p.UserAgent, ip) + return err +} + +func (r *RefreshTokenRepo) Get(ctx context.Context, hash string) (*RefreshToken, error) { + const q = ` + SELECT token_hash, user_id, expires_at, revoked_at, replaced_by, + COALESCE(user_agent, ''), host(ip_address), created_at + FROM refresh_tokens WHERE token_hash = $1 + ` + var rt RefreshToken + var ipText *string + if err := r.pool.QueryRow(ctx, q, hash).Scan( + &rt.Hash, &rt.UserID, &rt.ExpiresAt, &rt.RevokedAt, &rt.ReplacedBy, + &rt.UserAgent, &ipText, &rt.CreatedAt, + ); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrAuthTokenNotFound + } + return nil, err + } + if ipText != nil && *ipText != "" { + if addr, err := netip.ParseAddr(*ipText); err == nil { + rt.IPAddress = &addr + } + } + return &rt, nil +} + +// Rotate atomically (in a transaction) marks the old token revoked and +// inserts the new one with replaced_by set. Returns ErrAuthTokenNotFound or +// ErrRefreshTokenRevoked if the old token is missing or already revoked. +func (r *RefreshTokenRepo) Rotate(ctx context.Context, oldHash string, next CreateRefreshTokenParams) error { + tx, err := r.pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + var revokedAt *time.Time + var userID uuid.UUID + var expiresAt time.Time + err = tx.QueryRow(ctx, ` + SELECT user_id, expires_at, revoked_at + FROM refresh_tokens WHERE token_hash = $1 FOR UPDATE + `, oldHash).Scan(&userID, &expiresAt, &revokedAt) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return domain.ErrAuthTokenNotFound + } + return err + } + if revokedAt != nil { + // Replay of a revoked refresh token — revoke the entire family. + if _, err := tx.Exec(ctx, ` + UPDATE refresh_tokens SET revoked_at = now() + WHERE user_id = $1 AND revoked_at IS NULL + `, userID); err != nil { + return err + } + if err := tx.Commit(ctx); err != nil { + return err + } + return domain.ErrRefreshTokenRevoked + } + if time.Now().After(expiresAt) { + return domain.ErrAuthTokenExpired + } + if next.UserID != userID { + return errors.New("refresh token user mismatch") + } + + ip := parseIP(next.IPAddress) + if _, err := tx.Exec(ctx, ` + INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address) + VALUES ($1, $2, $3, NULLIF($4, ''), $5) + `, next.Hash, next.UserID, next.ExpiresAt, next.UserAgent, ip); err != nil { + return err + } + if _, err := tx.Exec(ctx, ` + UPDATE refresh_tokens SET revoked_at = now(), replaced_by = $2 + WHERE token_hash = $1 + `, oldHash, next.Hash); err != nil { + return err + } + return tx.Commit(ctx) +} + +func (r *RefreshTokenRepo) Revoke(ctx context.Context, hash string) error { + tag, err := r.pool.Exec(ctx, ` + UPDATE refresh_tokens SET revoked_at = now() + WHERE token_hash = $1 AND revoked_at IS NULL + `, hash) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return domain.ErrAuthTokenNotFound + } + return nil +} + +func (r *RefreshTokenRepo) RevokeAllForUser(ctx context.Context, userID uuid.UUID) error { + _, err := r.pool.Exec(ctx, ` + UPDATE refresh_tokens SET revoked_at = now() + WHERE user_id = $1 AND revoked_at IS NULL + `, userID) + return err +} + +func parseIP(s string) any { + if s == "" { + return nil + } + addr, err := netip.ParseAddr(s) + if err != nil { + return nil + } + return addr.String() +} + +func classifyAuthTokenLookup(ctx context.Context, pool *pgxpool.Pool, q, hash string) error { + var consumedAt *time.Time + var expiresAt time.Time + if err := pool.QueryRow(ctx, q, hash).Scan(&consumedAt, &expiresAt); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return domain.ErrAuthTokenNotFound + } + return err + } + if consumedAt != nil { + return domain.ErrAuthTokenConsumed + } + if time.Now().After(expiresAt) { + return domain.ErrAuthTokenExpired + } + return domain.ErrAuthTokenNotFound +} diff --git a/internal/storage/events.go b/internal/storage/events.go index 2e73d1d..048f2b1 100644 --- a/internal/storage/events.go +++ b/internal/storage/events.go @@ -78,6 +78,24 @@ func (r *EventRepo) Get(ctx context.Context, id uuid.UUID) (*domain.Event, error return ev, nil } +// GetForHost is the authz-aware variant of Get. It returns ErrEventNotFound +// when the event either doesn't exist or doesn't belong to the host — by +// merging both cases we avoid leaking existence on cross-tenant lookups. +func (r *EventRepo) GetForHost(ctx context.Context, id, hostID uuid.UUID) (*domain.Event, error) { + const q = ` + SELECT id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at + FROM events WHERE id = $1 AND host_id = $2 + ` + ev, err := scanEvent(r.pool.QueryRow(ctx, q, id, hostID)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrEventNotFound + } + return nil, err + } + return ev, nil +} + func (r *EventRepo) List(ctx context.Context, hostID uuid.UUID, limit, offset int) ([]*domain.Event, error) { if limit <= 0 || limit > 200 { limit = 50 @@ -132,18 +150,18 @@ type UpdateEventParams struct { Status *domain.EventStatus } -func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParams) (*domain.Event, error) { +func (r *EventRepo) Update(ctx context.Context, id, hostID uuid.UUID, p UpdateEventParams) (*domain.Event, error) { const q = ` UPDATE events SET - name = COALESCE($2, name), - slug = COALESCE($3, slug), - event_date = COALESCE($4, event_date), - venue = COALESCE($5, venue), - max_capacity = COALESCE($6, max_capacity), - settings = COALESCE($7, settings), - status = COALESCE($8, status), + name = COALESCE($3, name), + slug = COALESCE($4, slug), + event_date = COALESCE($5, event_date), + venue = COALESCE($6, venue), + max_capacity = COALESCE($7, max_capacity), + settings = COALESCE($8, settings), + status = COALESCE($9, status), updated_at = now() - WHERE id = $1 + WHERE id = $1 AND host_id = $2 RETURNING id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at ` @@ -156,7 +174,7 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam settingsJSON = b } - row := r.pool.QueryRow(ctx, q, id, + row := r.pool.QueryRow(ctx, q, id, hostID, p.Name, p.Slug, p.EventDate, p.Venue, p.MaxCapacity, settingsJSON, p.Status, ) ev, err := scanEvent(row) @@ -173,8 +191,8 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam return ev, nil } -func (r *EventRepo) Delete(ctx context.Context, id uuid.UUID) error { - tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1`, id) +func (r *EventRepo) Delete(ctx context.Context, id, hostID uuid.UUID) error { + tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1 AND host_id = $2`, id, hostID) if err != nil { return err } diff --git a/internal/storage/guests.go b/internal/storage/guests.go index f178559..7f54619 100644 --- a/internal/storage/guests.go +++ b/internal/storage/guests.go @@ -3,6 +3,8 @@ package storage import ( "context" "errors" + "fmt" + "strings" "time" "github.com/google/uuid" @@ -88,6 +90,87 @@ func (r *GuestRepo) ListByEvent(ctx context.Context, eventID uuid.UUID, limit, o return out, rows.Err() } +// BulkImportRow is one normalised guest in a CSV import batch. +type BulkImportRow struct { + Name string + Email string // empty if absent + Phone string // empty if absent + PlusOnes int +} + +// BulkImportResult reports the outcome of a single import call. The +// SkippedEmails slice records the addresses we silently dropped because a +// guest already exists on the event — useful for the success summary. +type BulkImportResult struct { + Added int + Skipped int + SkippedEmails []string +} + +// BulkImportGuests inserts up to len(rows) guest rows into the event in a +// single transaction. Rows whose email matches an existing guest on the +// event are skipped (idempotent re-imports). Within the batch, duplicate +// emails after the first are also skipped. Either the entire batch +// commits or none of it does. +// +// Empty email is treated as "no email" and not deduped — those rows +// always insert (the host might be entering phone-only guests). +func (r *GuestRepo) BulkImportGuests(ctx context.Context, eventID uuid.UUID, rows []BulkImportRow) (BulkImportResult, error) { + res := BulkImportResult{} + if len(rows) == 0 { + return res, nil + } + + tx, err := r.pool.Begin(ctx) + if err != nil { + return res, err + } + defer tx.Rollback(ctx) + + // Fetch existing emails on the event into a set for O(1) dedup. + existing := map[string]struct{}{} + exRows, err := tx.Query(ctx, + `SELECT lower(email) FROM guests WHERE event_id = $1 AND email IS NOT NULL AND email <> ''`, + eventID) + if err != nil { + return res, fmt.Errorf("load existing emails: %w", err) + } + for exRows.Next() { + var e string + if err := exRows.Scan(&e); err != nil { + exRows.Close() + return res, err + } + existing[e] = struct{}{} + } + exRows.Close() + + const ins = ` + INSERT INTO guests (event_id, name, email, phone, plus_ones) + VALUES ($1, $2, NULLIF($3, ''), NULLIF($4, ''), $5) + ` + for _, row := range rows { + email := strings.ToLower(strings.TrimSpace(row.Email)) + if email != "" { + if _, dup := existing[email]; dup { + res.Skipped++ + res.SkippedEmails = append(res.SkippedEmails, email) + continue + } + existing[email] = struct{}{} + } + if _, err := tx.Exec(ctx, ins, eventID, row.Name, email, row.Phone, row.PlusOnes); err != nil { + return BulkImportResult{}, fmt.Errorf("insert guest %q: %w", row.Name, err) + } + res.Added++ + } + + if err := tx.Commit(ctx); err != nil { + return BulkImportResult{}, err + } + return res, nil +} + func scanGuest(s rowScanner) (*domain.Guest, error) { var g domain.Guest err := s.Scan( @@ -100,6 +183,135 @@ func scanGuest(s rowScanner) (*domain.Guest, error) { return &g, nil } +// UpdateGuestParams patches a guest. Nil fields are left untouched. +// An empty string for Email / Phone clears the column to NULL, matching +// the frontend "clear this field" UX. +type UpdateGuestParams struct { + Name *string + Email *string + Phone *string + PlusOnes *int +} + +// Update applies the patch to (guestID, eventID). Event scoping in the +// WHERE clause prevents a host from patching guests on another host's +// event even if they guess the guest_id. Returns ErrGuestNotFound when +// the guest doesn't exist on the event. +func (r *GuestRepo) Update(ctx context.Context, eventID, guestID uuid.UUID, p UpdateGuestParams) (*domain.Guest, error) { + sets := []string{} + args := []any{guestID, eventID} + add := func(col string, val any) { + args = append(args, val) + sets = append(sets, fmt.Sprintf("%s = $%d", col, len(args))) + } + if p.Name != nil { + add("name", strings.TrimSpace(*p.Name)) + } + if p.Email != nil { + if strings.TrimSpace(*p.Email) == "" { + sets = append(sets, "email = NULL") + } else { + add("email", strings.ToLower(strings.TrimSpace(*p.Email))) + } + } + if p.Phone != nil { + if strings.TrimSpace(*p.Phone) == "" { + sets = append(sets, "phone = NULL") + } else { + add("phone", strings.TrimSpace(*p.Phone)) + } + } + if p.PlusOnes != nil { + add("plus_ones", *p.PlusOnes) + } + if len(sets) == 0 { + return r.Get(ctx, guestID) + } + q := fmt.Sprintf(` + UPDATE guests SET %s + WHERE id = $1 AND event_id = $2 + RETURNING id, event_id, name, email, phone, plus_ones, dietary_notes, table_number, created_at + `, strings.Join(sets, ", ")) + g, err := scanGuest(r.pool.QueryRow(ctx, q, args...)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrGuestNotFound + } + return nil, err + } + return g, nil +} + +// Delete removes a guest from an event. Cascade-deletes any tokens, +// rsvps, access_logs, and notifications tied to the guest. Event scoping +// in the WHERE clause stops cross-tenant deletes. +func (r *GuestRepo) Delete(ctx context.Context, eventID, guestID uuid.UUID) error { + tag, err := r.pool.Exec(ctx, + `DELETE FROM guests WHERE id = $1 AND event_id = $2`, + guestID, eventID) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return domain.ErrGuestNotFound + } + return nil +} + +// GuestForInvitation is the minimum data the bulk-invite path needs about +// each candidate. Pulled in a single query joined against tokens so the +// caller knows up-front who's already received an invitation. +type GuestForInvitation struct { + ID uuid.UUID + Name string + Email string // empty when the guest has no email on file + HasToken bool +} + +// ListGuestsForInvitation returns every guest on `eventID`, joined with +// the tokens table so the caller can skip guests that already have one. +// When `onlyIDs` is non-nil/non-empty, the result is filtered to that +// subset (used for explicit selection in the bulk-send UI). +func (r *GuestRepo) ListGuestsForInvitation(ctx context.Context, eventID uuid.UUID, onlyIDs []uuid.UUID) ([]GuestForInvitation, error) { + var ( + rows pgx.Rows + err error + ) + if len(onlyIDs) == 0 { + rows, err = r.pool.Query(ctx, ` + SELECT g.id, g.name, COALESCE(g.email,''), + (t.id IS NOT NULL) AS has_token + FROM guests g + LEFT JOIN tokens t ON t.guest_id = g.id + WHERE g.event_id = $1 + ORDER BY g.created_at ASC + `, eventID) + } else { + rows, err = r.pool.Query(ctx, ` + SELECT g.id, g.name, COALESCE(g.email,''), + (t.id IS NOT NULL) AS has_token + FROM guests g + LEFT JOIN tokens t ON t.guest_id = g.id + WHERE g.event_id = $1 AND g.id = ANY($2) + ORDER BY g.created_at ASC + `, eventID, onlyIDs) + } + if err != nil { + return nil, err + } + defer rows.Close() + + var out []GuestForInvitation + for rows.Next() { + var g GuestForInvitation + if err := rows.Scan(&g.ID, &g.Name, &g.Email, &g.HasToken); err != nil { + return nil, err + } + out = append(out, g) + } + return out, rows.Err() +} + // GuestWithRSVP is the dashboard view: a guest plus the RSVP submitted // against their token, if any. RSVP fields are nil when no response yet. type GuestWithRSVP struct { diff --git a/internal/storage/migrations/0003_auth.down.sql b/internal/storage/migrations/0003_auth.down.sql new file mode 100644 index 0000000..0a82a6e --- /dev/null +++ b/internal/storage/migrations/0003_auth.down.sql @@ -0,0 +1,8 @@ +DROP TABLE IF EXISTS refresh_tokens; +DROP TABLE IF EXISTS password_reset_tokens; +DROP TABLE IF EXISTS email_verification_tokens; + +ALTER TABLE users + DROP COLUMN IF EXISTS email_verified_at, + DROP COLUMN IF EXISTS email_verified, + DROP COLUMN IF EXISTS password_hash; diff --git a/internal/storage/migrations/0003_auth.up.sql b/internal/storage/migrations/0003_auth.up.sql new file mode 100644 index 0000000..06898a3 --- /dev/null +++ b/internal/storage/migrations/0003_auth.up.sql @@ -0,0 +1,37 @@ +ALTER TABLE users + ADD COLUMN IF NOT EXISTS password_hash TEXT, + ADD COLUMN IF NOT EXISTS email_verified BOOLEAN NOT NULL DEFAULT FALSE, + ADD COLUMN IF NOT EXISTS email_verified_at TIMESTAMPTZ; + +CREATE TABLE IF NOT EXISTS email_verification_tokens ( + token_hash TEXT PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_user ON email_verification_tokens(user_id); + +CREATE TABLE IF NOT EXISTS password_reset_tokens ( + token_hash TEXT PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_password_reset_tokens_user ON password_reset_tokens(user_id); + +CREATE TABLE IF NOT EXISTS refresh_tokens ( + token_hash TEXT PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + expires_at TIMESTAMPTZ NOT NULL, + revoked_at TIMESTAMPTZ, + replaced_by TEXT REFERENCES refresh_tokens(token_hash) ON DELETE SET NULL, + user_agent TEXT, + ip_address INET, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_active ON refresh_tokens(user_id) WHERE revoked_at IS NULL; diff --git a/internal/storage/migrations/0004_notifications_d.down.sql b/internal/storage/migrations/0004_notifications_d.down.sql new file mode 100644 index 0000000..0873d56 --- /dev/null +++ b/internal/storage/migrations/0004_notifications_d.down.sql @@ -0,0 +1,8 @@ +DROP TABLE IF EXISTS unsubscribes; + +ALTER TABLE notifications + DROP COLUMN IF EXISTS complained, + DROP COLUMN IF EXISTS bounce_type, + DROP COLUMN IF EXISTS provider_message_id; + +DROP INDEX IF EXISTS idx_notifications_provider_message_id; diff --git a/internal/storage/migrations/0004_notifications_d.up.sql b/internal/storage/migrations/0004_notifications_d.up.sql new file mode 100644 index 0000000..b358bf4 --- /dev/null +++ b/internal/storage/migrations/0004_notifications_d.up.sql @@ -0,0 +1,23 @@ +-- Block D — real notifications: bounce / complaint tracking + suppression list. +-- The `delivered_at` column already exists from 0001. + +CREATE EXTENSION IF NOT EXISTS "citext"; + +ALTER TABLE notifications + ADD COLUMN IF NOT EXISTS provider_message_id TEXT, + ADD COLUMN IF NOT EXISTS bounce_type TEXT, -- 'permanent' | 'transient' | NULL + ADD COLUMN IF NOT EXISTS complained BOOLEAN NOT NULL DEFAULT FALSE; + +CREATE INDEX IF NOT EXISTS idx_notifications_provider_message_id + ON notifications(provider_message_id) + WHERE provider_message_id IS NOT NULL; + +-- Suppression list: any email present here gets a silent no-op on send. +-- Populated by bounce / complaint webhooks and by guest-initiated +-- unsubscribe clicks. +CREATE TABLE IF NOT EXISTS unsubscribes ( + email CITEXT PRIMARY KEY, + reason TEXT, + source TEXT NOT NULL DEFAULT 'manual', -- 'bounce' | 'complaint' | 'manual' | 'user' + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/internal/storage/migrations/0005_billing.down.sql b/internal/storage/migrations/0005_billing.down.sql new file mode 100644 index 0000000..30de797 --- /dev/null +++ b/internal/storage/migrations/0005_billing.down.sql @@ -0,0 +1,4 @@ +DROP INDEX IF EXISTS idx_subscriptions_subscription; +DROP INDEX IF EXISTS idx_subscriptions_customer; +DROP INDEX IF EXISTS uniq_subscriptions_active_user; +DROP TABLE IF EXISTS subscriptions; diff --git a/internal/storage/migrations/0005_billing.up.sql b/internal/storage/migrations/0005_billing.up.sql new file mode 100644 index 0000000..b860d90 --- /dev/null +++ b/internal/storage/migrations/0005_billing.up.sql @@ -0,0 +1,30 @@ +-- Block F — Stripe subscriptions. One row per Stripe customer + (optional) +-- active subscription. Free-tier hosts never get a row; their tier is +-- inferred at read time. + +CREATE TABLE IF NOT EXISTS subscriptions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + stripe_customer_id TEXT NOT NULL, + stripe_subscription_id TEXT, + tier TEXT NOT NULL CHECK (tier IN ('free','pro','business')), + status TEXT NOT NULL CHECK (status IN ('active','past_due','canceled','incomplete','trialing','unpaid')), + current_period_end TIMESTAMPTZ, + cancel_at_period_end BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- A user may have at most one *granting* subscription at a time. We +-- include trialing + past_due because those still convey access (past_due +-- is the grace period before Stripe gives up on the card). +CREATE UNIQUE INDEX IF NOT EXISTS uniq_subscriptions_active_user + ON subscriptions(user_id) + WHERE status IN ('active','past_due','trialing'); + +CREATE INDEX IF NOT EXISTS idx_subscriptions_customer + ON subscriptions(stripe_customer_id); + +CREATE INDEX IF NOT EXISTS idx_subscriptions_subscription + ON subscriptions(stripe_subscription_id) + WHERE stripe_subscription_id IS NOT NULL; diff --git a/internal/storage/migrations/0006_privacy.down.sql b/internal/storage/migrations/0006_privacy.down.sql new file mode 100644 index 0000000..b6602c1 --- /dev/null +++ b/internal/storage/migrations/0006_privacy.down.sql @@ -0,0 +1,6 @@ +DROP INDEX IF EXISTS idx_users_active_email; + +ALTER TABLE users + DROP COLUMN IF EXISTS privacy_policy_accepted_at, + DROP COLUMN IF EXISTS terms_accepted_at, + DROP COLUMN IF EXISTS deleted_at; diff --git a/internal/storage/migrations/0006_privacy.up.sql b/internal/storage/migrations/0006_privacy.up.sql new file mode 100644 index 0000000..13c610d --- /dev/null +++ b/internal/storage/migrations/0006_privacy.up.sql @@ -0,0 +1,20 @@ +-- Block H — privacy compliance. +-- +-- Adds the columns needed for: +-- - Right to erasure (DELETE /me): soft-delete first, hard-delete via +-- a future cron after a 30-day grace window so an accidental click +-- is recoverable. +-- - Terms / privacy-policy acceptance gate (set on signup; older +-- accounts re-prompted via the frontend on next login). + +ALTER TABLE users + ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS terms_accepted_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS privacy_policy_accepted_at TIMESTAMPTZ; + +-- Most lookups (login, /me, etc.) want to exclude soft-deleted users. +-- A partial index keeps the active subset fast without bloating writes +-- for the rare deleted rows. +CREATE INDEX IF NOT EXISTS idx_users_active_email + ON users(email) + WHERE deleted_at IS NULL; diff --git a/internal/storage/subscriptions.go b/internal/storage/subscriptions.go new file mode 100644 index 0000000..45429f6 --- /dev/null +++ b/internal/storage/subscriptions.go @@ -0,0 +1,223 @@ +package storage + +import ( + "context" + "errors" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Subscription mirrors the subscriptions table row. Stored as a thin +// projection of the Stripe state — we don't try to mirror every field, +// just what middleware + handlers need to decide access. +type Subscription struct { + ID uuid.UUID + UserID uuid.UUID + StripeCustomerID string + StripeSubscriptionID *string + Tier string + Status string + CurrentPeriodEnd *time.Time + CancelAtPeriodEnd bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// ErrSubscriptionNotFound is returned when no row matches the lookup. +var ErrSubscriptionNotFound = errors.New("subscription not found") + +type SubscriptionRepo struct { + pool *pgxpool.Pool +} + +func NewSubscriptionRepo(db *DB) *SubscriptionRepo { + return &SubscriptionRepo{pool: db.Pool} +} + +const subscriptionColumns = ` + id, user_id, stripe_customer_id, stripe_subscription_id, + tier, status, current_period_end, cancel_at_period_end, + created_at, updated_at +` + +// GetActiveByUser returns the user's currently-granting subscription +// (active / trialing / past_due). Returns ErrSubscriptionNotFound when +// the user has no row at all — caller treats that as free tier. +func (r *SubscriptionRepo) GetActiveByUser(ctx context.Context, userID uuid.UUID) (*Subscription, error) { + const q = ` + SELECT ` + subscriptionColumns + ` + FROM subscriptions + WHERE user_id = $1 + AND status IN ('active','past_due','trialing') + ORDER BY updated_at DESC + LIMIT 1 + ` + return r.scanOne(ctx, q, userID) +} + +// GetByCustomer fetches by Stripe customer id — webhooks use this since +// the event payload identifies the customer, not the user. +func (r *SubscriptionRepo) GetByCustomer(ctx context.Context, customerID string) (*Subscription, error) { + const q = ` + SELECT ` + subscriptionColumns + ` + FROM subscriptions WHERE stripe_customer_id = $1 + ORDER BY updated_at DESC LIMIT 1 + ` + return r.scanOne(ctx, q, customerID) +} + +// FindCustomerID returns the Stripe customer id we've already created +// for this user, or "" if none exists yet. Avoids creating duplicate +// Stripe customers across checkout sessions. +func (r *SubscriptionRepo) FindCustomerID(ctx context.Context, userID uuid.UUID) (string, error) { + const q = ` + SELECT stripe_customer_id FROM subscriptions + WHERE user_id = $1 ORDER BY created_at ASC LIMIT 1 + ` + var id string + if err := r.pool.QueryRow(ctx, q, userID).Scan(&id); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", nil + } + return "", err + } + return id, nil +} + +// UpsertParams collects everything an upsert needs. Pointer types denote +// "skip writing this column" (used when a webhook only carries partial +// data — we never want to clobber tier or period info we don't have). +type UpsertParams struct { + UserID uuid.UUID + StripeCustomerID string + StripeSubscriptionID *string + Tier *string + Status *string + CurrentPeriodEnd *time.Time + CancelAtPeriodEnd *bool +} + +// Upsert inserts a new row or updates an existing one keyed by +// stripe_customer_id. Used by both the checkout-success handler and the +// webhook subscription-lifecycle handler. +func (r *SubscriptionRepo) Upsert(ctx context.Context, p UpsertParams) (*Subscription, error) { + const q = ` + INSERT INTO subscriptions ( + user_id, stripe_customer_id, stripe_subscription_id, + tier, status, current_period_end, cancel_at_period_end + ) + VALUES ( + $1, $2, $3, + COALESCE($4, 'free'), COALESCE($5, 'incomplete'), + $6, COALESCE($7, FALSE) + ) + ON CONFLICT (id) DO NOTHING + RETURNING ` + subscriptionColumns + ` + ` + + row := r.pool.QueryRow(ctx, q, + p.UserID, p.StripeCustomerID, p.StripeSubscriptionID, + p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd, + ) + sub, err := scanSubscription(row) + if err == nil { + return sub, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return nil, err + } + + // Race or duplicate insert — fall back to an explicit update on the + // stripe_customer_id (the FK to Stripe's source of truth). + const upd = ` + UPDATE subscriptions SET + stripe_subscription_id = COALESCE($3, stripe_subscription_id), + tier = COALESCE($4, tier), + status = COALESCE($5, status), + current_period_end = COALESCE($6, current_period_end), + cancel_at_period_end = COALESCE($7, cancel_at_period_end), + updated_at = now() + WHERE user_id = $1 AND stripe_customer_id = $2 + RETURNING ` + subscriptionColumns + ` + ` + row = r.pool.QueryRow(ctx, upd, + p.UserID, p.StripeCustomerID, p.StripeSubscriptionID, + p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd, + ) + return scanSubscription(row) +} + +// UpdateByCustomer patches the subscription row keyed by Stripe customer +// id. Used by webhooks where we have the customer reference but not +// always the user id. +func (r *SubscriptionRepo) UpdateByCustomer(ctx context.Context, customerID string, p UpsertParams) error { + const q = ` + UPDATE subscriptions SET + stripe_subscription_id = COALESCE($2, stripe_subscription_id), + tier = COALESCE($3, tier), + status = COALESCE($4, status), + current_period_end = COALESCE($5, current_period_end), + cancel_at_period_end = COALESCE($6, cancel_at_period_end), + updated_at = now() + WHERE stripe_customer_id = $1 + ` + _, err := r.pool.Exec(ctx, q, + customerID, p.StripeSubscriptionID, + p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd, + ) + return err +} + +// CountEventsInCurrentMonth returns how many events the user has created +// since the 1st of the current UTC month. Used for free-tier "1 event / +// month" and Pro-tier "10 events / month" enforcement. +func (r *SubscriptionRepo) CountEventsInCurrentMonth(ctx context.Context, userID uuid.UUID) (int, error) { + const q = ` + SELECT count(*) FROM events + WHERE host_id = $1 + AND created_at >= date_trunc('month', now() AT TIME ZONE 'UTC') + ` + var n int + if err := r.pool.QueryRow(ctx, q, userID).Scan(&n); err != nil { + return 0, err + } + return n, nil +} + +// CountGuestsByEvent returns the current guest count for an event. Used +// for per-event guest cap enforcement. +func (r *SubscriptionRepo) CountGuestsByEvent(ctx context.Context, eventID uuid.UUID) (int, error) { + var n int + if err := r.pool.QueryRow(ctx, + `SELECT count(*) FROM guests WHERE event_id = $1`, eventID, + ).Scan(&n); err != nil { + return 0, err + } + return n, nil +} + +func (r *SubscriptionRepo) scanOne(ctx context.Context, q string, args ...any) (*Subscription, error) { + sub, err := scanSubscription(r.pool.QueryRow(ctx, q, args...)) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrSubscriptionNotFound + } + return nil, err + } + return sub, nil +} + +func scanSubscription(s rowScanner) (*Subscription, error) { + var sub Subscription + if err := s.Scan( + &sub.ID, &sub.UserID, &sub.StripeCustomerID, &sub.StripeSubscriptionID, + &sub.Tier, &sub.Status, &sub.CurrentPeriodEnd, &sub.CancelAtPeriodEnd, + &sub.CreatedAt, &sub.UpdatedAt, + ); err != nil { + return nil, err + } + return &sub, nil +} diff --git a/internal/storage/tokens.go b/internal/storage/tokens.go index 5333fa8..8a999a9 100644 --- a/internal/storage/tokens.go +++ b/internal/storage/tokens.go @@ -60,6 +60,38 @@ func (r *TokenRepo) GetByHash(ctx context.Context, hash string) (*domain.Token, return tk, nil } +// RotateForGuest replaces the guest's existing token with a freshly-minted +// one in a single transaction. The old token row is hard-deleted (the +// guests.id UNIQUE constraint requires it, and "the old link must stop +// working" is the point). Cascade-deletes the old access_logs rows that +// reference it via the token_id FK with ON DELETE SET NULL — those rows +// stay, with token_id nulled. +func (r *TokenRepo) RotateForGuest(ctx context.Context, p CreateTokenParams) (*domain.Token, error) { + tx, err := r.pool.Begin(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback(ctx) + + if _, err := tx.Exec(ctx, `DELETE FROM tokens WHERE guest_id = $1`, p.GuestID); err != nil { + return nil, err + } + const q = ` + INSERT INTO tokens (guest_id, token_hash, expires_at, status) + VALUES ($1, $2, $3, 'active') + RETURNING id, guest_id, token_hash, expires_at, status, used_at, created_at + ` + row := tx.QueryRow(ctx, q, p.GuestID, p.TokenHash, p.ExpiresAt) + tk, err := scanToken(row) + if err != nil { + return nil, err + } + if err := tx.Commit(ctx); err != nil { + return nil, err + } + return tk, nil +} + func (r *TokenRepo) MarkUsed(ctx context.Context, id uuid.UUID) error { tag, err := r.pool.Exec(ctx, ` UPDATE tokens SET status = 'used', used_at = now() diff --git a/internal/storage/users.go b/internal/storage/users.go index 432d943..1f9b19a 100644 --- a/internal/storage/users.go +++ b/internal/storage/users.go @@ -5,6 +5,7 @@ import ( "errors" "strings" + "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" @@ -20,12 +21,38 @@ func NewUserRepo(db *DB) *UserRepo { return &UserRepo{pool: db.Pool} } -func (r *UserRepo) Create(ctx context.Context, email, name string) (*domain.User, error) { +const userColumns = `id, email, name, + COALESCE(password_hash, '') AS password_hash, + email_verified, email_verified_at, + deleted_at, + terms_accepted_at, privacy_policy_accepted_at, + created_at, updated_at` + +type CreateUserParams struct { + Email string + Name string + PasswordHash string + AcceptTerms bool // when true, records terms + privacy acceptance now +} + +func (r *UserRepo) Create(ctx context.Context, p CreateUserParams) (*domain.User, error) { const q = ` - INSERT INTO users (email, name) VALUES ($1, $2) - RETURNING id, email, name, created_at, updated_at - ` - row := r.pool.QueryRow(ctx, q, strings.ToLower(strings.TrimSpace(email)), strings.TrimSpace(name)) + INSERT INTO users ( + email, name, password_hash, + terms_accepted_at, privacy_policy_accepted_at + ) + VALUES ( + $1, $2, NULLIF($3, ''), + CASE WHEN $4 THEN now() ELSE NULL END, + CASE WHEN $4 THEN now() ELSE NULL END + ) + RETURNING ` + userColumns + row := r.pool.QueryRow(ctx, q, + normaliseEmail(p.Email), + strings.TrimSpace(p.Name), + p.PasswordHash, + p.AcceptTerms, + ) u, err := scanUser(row) if err != nil { var pgErr *pgconn.PgError @@ -37,9 +64,12 @@ func (r *UserRepo) Create(ctx context.Context, email, name string) (*domain.User return u, nil } -func (r *UserRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) { - const q = `SELECT id, email, name, created_at, updated_at FROM users WHERE email = $1` - u, err := scanUser(r.pool.QueryRow(ctx, q, strings.ToLower(strings.TrimSpace(email)))) +// GetByID returns an active (non-soft-deleted) user. Soft-deleted users +// are treated as "not found" by the API surface — keeps the deletion +// flow safe by default. +func (r *UserRepo) GetByID(ctx context.Context, id uuid.UUID) (*domain.User, error) { + const q = `SELECT ` + userColumns + ` FROM users WHERE id = $1 AND deleted_at IS NULL` + u, err := scanUser(r.pool.QueryRow(ctx, q, id)) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, domain.ErrUserNotFound @@ -49,10 +79,112 @@ func (r *UserRepo) GetByEmail(ctx context.Context, email string) (*domain.User, return u, nil } +// GetByEmail mirrors GetByID — soft-deleted users vanish from email +// lookups (so signup/login don't match a tombstoned record). +func (r *UserRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) { + const q = `SELECT ` + userColumns + ` FROM users WHERE email = $1 AND deleted_at IS NULL` + u, err := scanUser(r.pool.QueryRow(ctx, q, normaliseEmail(email))) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, domain.ErrUserNotFound + } + return nil, err + } + return u, nil +} + +// SoftDelete marks the user as deleted and clears their PII-bearing +// fields. A nightly cron (TBD in ops) will hard-delete rows older than +// 30 days. Until then the row exists for audit + recovery if the user +// changes their mind. +func (r *UserRepo) SoftDelete(ctx context.Context, id uuid.UUID) error { + tag, err := r.pool.Exec(ctx, ` + UPDATE users SET + deleted_at = now(), + updated_at = now(), + -- Tombstone PII so the soft-deleted row can sit for 30 days + -- without holding the user's real email + name in cleartext. + -- The original values are gone from the API surface from the + -- moment SoftDelete returns. + email = 'deleted-' || id::text || '@deleted.local', + name = 'Deleted user', + password_hash = NULL + WHERE id = $1 AND deleted_at IS NULL + `, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return domain.ErrUserNotFound + } + return nil +} + +// AcceptTerms records that the user has consented to the current terms +// of service and privacy policy. Idempotent — re-accepting just resets +// the timestamp. +func (r *UserRepo) AcceptTerms(ctx context.Context, id uuid.UUID) error { + tag, err := r.pool.Exec(ctx, ` + UPDATE users SET + terms_accepted_at = now(), + privacy_policy_accepted_at = now(), + updated_at = now() + WHERE id = $1 AND deleted_at IS NULL + `, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return domain.ErrUserNotFound + } + return nil +} + +func (r *UserRepo) MarkEmailVerified(ctx context.Context, id uuid.UUID) error { + tag, err := r.pool.Exec(ctx, ` + UPDATE users + SET email_verified = TRUE, + email_verified_at = COALESCE(email_verified_at, now()), + updated_at = now() + WHERE id = $1 + `, id) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return domain.ErrUserNotFound + } + return nil +} + +func (r *UserRepo) UpdatePasswordHash(ctx context.Context, id uuid.UUID, hash string) error { + tag, err := r.pool.Exec(ctx, ` + UPDATE users SET password_hash = $2, updated_at = now() WHERE id = $1 + `, id, hash) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return domain.ErrUserNotFound + } + return nil +} + func scanUser(s rowScanner) (*domain.User, error) { var u domain.User - if err := s.Scan(&u.ID, &u.Email, &u.Name, &u.CreatedAt, &u.UpdatedAt); err != nil { + if err := s.Scan( + &u.ID, &u.Email, &u.Name, + &u.PasswordHash, + &u.EmailVerified, &u.EmailVerifiedAt, + &u.DeletedAt, + &u.TermsAcceptedAt, &u.PrivacyPolicyAcceptedAt, + &u.CreatedAt, &u.UpdatedAt, + ); err != nil { return nil, err } return &u, nil } + +func normaliseEmail(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} diff --git a/test/integration/auth_test.go b/test/integration/auth_test.go new file mode 100644 index 0000000..eb840e6 --- /dev/null +++ b/test/integration/auth_test.go @@ -0,0 +1,378 @@ +//go:build integration + +package integration_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/storage" +) + +const authTestPassword = "correct-horse-battery-staple" + +// recordingEmailSender captures the most recent verification / reset link so +// tests can finish the signup flow without a real inbox. +type recordingEmailSender struct { + verifyLink string + resetLink string +} + +func (s *recordingEmailSender) SendVerification(_ context.Context, _, _, link string) error { + s.verifyLink = link + return nil +} + +func (s *recordingEmailSender) SendPasswordReset(_ context.Context, _, _, link string) error { + s.resetLink = link + return nil +} + +func TestAuthFlow(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + emails := &recordingEmailSender{} + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: "test-secret-must-be-at-least-32-bytes-long-xx", + JWTIssuer: "guestguard-test", + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + EmailSender: emails, + }) + must(t, err, "build api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + email := uniqueEmail(t) + + t.Run("signup", func(t *testing.T) { + resp := post(t, client, srv.URL+"/auth/signup", map[string]string{ + "email": email, + "name": "Auth Test", + "password": authTestPassword, + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("signup status: %d", resp.StatusCode) + } + resp.Body.Close() + if emails.verifyLink == "" { + t.Fatal("verification email not captured") + } + }) + + t.Run("login before verify is forbidden", func(t *testing.T) { + resp := post(t, client, srv.URL+"/auth/login", map[string]string{ + "email": email, + "password": authTestPassword, + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 403, got %d: %s", resp.StatusCode, body) + } + }) + + t.Run("verify email", func(t *testing.T) { + token := tokenFromQuery(t, emails.verifyLink, "token") + resp := post(t, client, srv.URL+"/auth/verify-email", map[string]string{"token": token}) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("verify status: %d %s", resp.StatusCode, body) + } + }) + + t.Run("verify token replay rejected", func(t *testing.T) { + token := tokenFromQuery(t, emails.verifyLink, "token") + resp := post(t, client, srv.URL+"/auth/verify-email", map[string]string{"token": token}) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("replay should be 400, got %d", resp.StatusCode) + } + }) + + var firstAccess string + t.Run("login returns access + refresh cookie", func(t *testing.T) { + resp := post(t, client, srv.URL+"/auth/login", map[string]string{ + "email": email, + "password": authTestPassword, + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("login status: %d %s", resp.StatusCode, body) + } + var body struct { + AccessToken string `json:"access_token"` + } + must(t, json.NewDecoder(resp.Body).Decode(&body), "decode login") + if body.AccessToken == "" { + t.Fatal("missing access token") + } + firstAccess = body.AccessToken + assertRefreshCookieSet(t, srv.URL, jar) + }) + + t.Run("access token authorises /me", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/me", nil) + req.Header.Set("Authorization", "Bearer "+firstAccess) + resp, err := client.Do(req) + must(t, err, "GET /me") + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("/me status: %d %s", resp.StatusCode, body) + } + }) + + t.Run("refresh rotates tokens", func(t *testing.T) { + oldCookie := refreshCookieValue(t, srv.URL, jar) + + resp := post(t, client, srv.URL+"/auth/refresh", nil) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("refresh status: %d %s", resp.StatusCode, body) + } + var body struct { + AccessToken string `json:"access_token"` + } + must(t, json.NewDecoder(resp.Body).Decode(&body), "decode refresh") + if body.AccessToken == "" { + t.Fatal("missing new access token") + } + newCookie := refreshCookieValue(t, srv.URL, jar) + if newCookie == oldCookie { + t.Fatal("refresh did not rotate cookie") + } + + // Replay of the old refresh token must be rejected and revoke the family. + jar2, _ := cookiejar.New(nil) + client2 := &http.Client{Jar: jar2} + setRefreshCookie(t, srv.URL, jar2, oldCookie) + replay := post(t, client2, srv.URL+"/auth/refresh", nil) + replay.Body.Close() + if replay.StatusCode != http.StatusUnauthorized { + t.Fatalf("old refresh replay should be 401, got %d", replay.StatusCode) + } + + // And the family should be revoked: even the new (just-rotated) cookie + // no longer works. + familyReplay := post(t, client, srv.URL+"/auth/refresh", nil) + familyReplay.Body.Close() + if familyReplay.StatusCode != http.StatusUnauthorized { + t.Fatalf("family-revoked refresh should be 401, got %d", familyReplay.StatusCode) + } + }) + + // After family revocation, log back in to keep going. + resp := post(t, client, srv.URL+"/auth/login", map[string]string{ + "email": email, + "password": authTestPassword, + }) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("second login: %d", resp.StatusCode) + } + + t.Run("forgot-password emits link without leaking existence", func(t *testing.T) { + // Unknown email — still 202, no link sent. + emails.resetLink = "" + unknown := post(t, client, srv.URL+"/auth/forgot-password", map[string]string{ + "email": "nobody-" + uuid.NewString() + "@guestguard.test", + }) + unknown.Body.Close() + if unknown.StatusCode != http.StatusAccepted { + t.Fatalf("unknown forgot-password: %d", unknown.StatusCode) + } + if emails.resetLink != "" { + t.Fatal("reset link sent for unknown email") + } + + known := post(t, client, srv.URL+"/auth/forgot-password", map[string]string{ + "email": email, + }) + known.Body.Close() + if known.StatusCode != http.StatusAccepted { + t.Fatalf("known forgot-password: %d", known.StatusCode) + } + if emails.resetLink == "" { + t.Fatal("reset link not captured") + } + }) + + t.Run("reset password invalidates sessions", func(t *testing.T) { + token := tokenFromPath(t, emails.resetLink, "/reset-password/") + newPw := "new-correct-horse-battery-staple" + resp := post(t, client, srv.URL+"/auth/reset-password", map[string]string{ + "token": token, + "new_password": newPw, + }) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("reset status: %d", resp.StatusCode) + } + + // Old password fails. + bad := post(t, client, srv.URL+"/auth/login", map[string]string{ + "email": email, + "password": authTestPassword, + }) + bad.Body.Close() + if bad.StatusCode != http.StatusUnauthorized { + t.Fatalf("old password should 401, got %d", bad.StatusCode) + } + + // Existing refresh cookie should no longer work. + refresh := post(t, client, srv.URL+"/auth/refresh", nil) + refresh.Body.Close() + if refresh.StatusCode != http.StatusUnauthorized { + t.Fatalf("refresh after reset should 401, got %d", refresh.StatusCode) + } + + // New password works. + ok := post(t, client, srv.URL+"/auth/login", map[string]string{ + "email": email, + "password": newPw, + }) + ok.Body.Close() + if ok.StatusCode != http.StatusOK { + t.Fatalf("new password login: %d", ok.StatusCode) + } + }) + + t.Run("logout revokes refresh", func(t *testing.T) { + resp := post(t, client, srv.URL+"/auth/logout", nil) + resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("logout status: %d", resp.StatusCode) + } + refresh := post(t, client, srv.URL+"/auth/refresh", nil) + refresh.Body.Close() + if refresh.StatusCode != http.StatusUnauthorized { + t.Fatalf("refresh after logout should 401, got %d", refresh.StatusCode) + } + }) + + t.Run("requireAuth rejects invalid bearer", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/me", nil) + req.Header.Set("Authorization", "Bearer not-a-real-jwt") + resp, err := client.Do(req) + must(t, err, "GET /me bad token") + resp.Body.Close() + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("bad bearer should 401, got %d", resp.StatusCode) + } + }) +} + +// --- helpers --- + +func uniqueEmail(t *testing.T) string { + t.Helper() + return "auth-" + uuid.NewString() + "@guestguard.test" +} + +func post(t *testing.T, client *http.Client, url string, body any) *http.Response { + t.Helper() + var r io.Reader + if body != nil { + b, err := json.Marshal(body) + must(t, err, "marshal post body") + r = bytes.NewReader(b) + } + req, err := http.NewRequest(http.MethodPost, url, r) + must(t, err, "build post request") + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := client.Do(req) + must(t, err, "do post "+url) + return resp +} + +func tokenFromQuery(t *testing.T, link, key string) string { + t.Helper() + idx := strings.Index(link, key+"=") + if idx < 0 { + t.Fatalf("link missing %s: %s", key, link) + } + return link[idx+len(key)+1:] +} + +func tokenFromPath(t *testing.T, link, prefix string) string { + t.Helper() + idx := strings.LastIndex(link, prefix) + if idx < 0 { + t.Fatalf("link missing prefix %s: %s", prefix, link) + } + return link[idx+len(prefix):] +} + +func assertRefreshCookieSet(t *testing.T, baseURL string, jar http.CookieJar) { + t.Helper() + if refreshCookieValue(t, baseURL, jar) == "" { + t.Fatal("refresh cookie not set") + } +} + +func refreshCookieValue(t *testing.T, baseURL string, jar http.CookieJar) string { + t.Helper() + // jar.Cookies needs a URL whose path matches the cookie's Path (/auth). + u := baseURL + "/auth/refresh" + parsed, err := url.Parse(u) + must(t, err, "parse url") + for _, c := range jar.Cookies(parsed) { + if c.Name == "gg_refresh" { + return c.Value + } + } + return "" +} + +func setRefreshCookie(t *testing.T, baseURL string, jar http.CookieJar, value string) { + t.Helper() + parsed, err := url.Parse(baseURL + "/auth/refresh") + must(t, err, "parse url") + jar.SetCookies(parsed, []*http.Cookie{{ + Name: "gg_refresh", + Value: value, + Path: "/auth", + }}) +} diff --git a/test/integration/authz_test.go b/test/integration/authz_test.go new file mode 100644 index 0000000..ea88051 --- /dev/null +++ b/test/integration/authz_test.go @@ -0,0 +1,220 @@ +//go:build integration + +package integration_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestCrossTenantIsolation confirms that one authenticated host cannot +// read, modify, or extend another host's event. All cross-tenant attempts +// should return 404 — never 403 — so a probe can't tell whether a given +// UUID exists at all. +func TestCrossTenantIsolation(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + }) + must(t, err, "build api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostA := insertHost(t, ctx, db.Pool) + hostB := insertHost(t, ctx, db.Pool) + tokenA := issueHostToken(t, hostA) + tokenB := issueHostToken(t, hostB) + + eventA := createEvent(t, srv.URL, tokenA, "Host A's Event", "host-a-event") + + t.Run("list returns only own events", func(t *testing.T) { + out := struct { + Events []struct { + ID uuid.UUID `json:"id"` + } `json:"events"` + }{} + getJSONAuthed(t, srv.URL+"/events", tokenB, http.StatusOK, &out) + for _, e := range out.Events { + if e.ID == eventA { + t.Fatalf("host B saw host A's event in /events list") + } + } + }) + + t.Run("GET other host's event is 404", func(t *testing.T) { + assertStatus(t, http.MethodGet, fmt.Sprintf("%s/events/%s", srv.URL, eventA), + tokenB, nil, http.StatusNotFound) + }) + + t.Run("PATCH other host's event is 404", func(t *testing.T) { + body := map[string]any{"name": "hijacked"} + assertStatus(t, http.MethodPatch, fmt.Sprintf("%s/events/%s", srv.URL, eventA), + tokenB, body, http.StatusNotFound) + }) + + t.Run("DELETE other host's event is 404", func(t *testing.T) { + assertStatus(t, http.MethodDelete, fmt.Sprintf("%s/events/%s", srv.URL, eventA), + tokenB, nil, http.StatusNotFound) + }) + + t.Run("POST guest on other host's event is 404", func(t *testing.T) { + body := map[string]any{"name": "Mallory"} + assertStatus(t, http.MethodPost, fmt.Sprintf("%s/events/%s/guests", srv.URL, eventA), + tokenB, body, http.StatusNotFound) + }) + + t.Run("GET guests on other host's event is 404", func(t *testing.T) { + assertStatus(t, http.MethodGet, fmt.Sprintf("%s/events/%s/guests", srv.URL, eventA), + tokenB, nil, http.StatusNotFound) + }) + + t.Run("GET activity on other host's event is 404", func(t *testing.T) { + assertStatus(t, http.MethodGet, fmt.Sprintf("%s/events/%s/activity", srv.URL, eventA), + tokenB, nil, http.StatusNotFound) + }) + + t.Run("no bearer is 401", func(t *testing.T) { + assertStatus(t, http.MethodGet, srv.URL+"/events", "", nil, http.StatusUnauthorized) + }) + + t.Run("ws-ticket for other host's event is 404", func(t *testing.T) { + body := map[string]any{"event_id": eventA.String()} + assertStatus(t, http.MethodPost, srv.URL+"/auth/ws-ticket", + tokenB, body, http.StatusNotFound) + }) + + t.Run("ws-ticket then WS handshake requires matching path", func(t *testing.T) { + // Host A mints a ticket for their own event. + var ticketResp struct { + Ticket string `json:"ticket"` + } + postJSONAuthed(t, srv.URL+"/auth/ws-ticket", tokenA, + map[string]any{"event_id": eventA.String()}, + http.StatusOK, &ticketResp) + if ticketResp.Ticket == "" { + t.Fatal("empty ticket") + } + + // Trying to use that ticket on a *different* event id must fail. + bogus := uuid.New() + req, _ := http.NewRequest(http.MethodGet, + fmt.Sprintf("%s/ws/events/%s?ticket=%s", srv.URL, bogus, ticketResp.Ticket), nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + resp, err := http.DefaultClient.Do(req) + must(t, err, "WS handshake") + resp.Body.Close() + if resp.StatusCode != http.StatusForbidden && resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 403 or 401, got %d", resp.StatusCode) + } + }) + + t.Run("replaying a consumed ticket fails", func(t *testing.T) { + var ticketResp struct { + Ticket string `json:"ticket"` + } + postJSONAuthed(t, srv.URL+"/auth/ws-ticket", tokenA, + map[string]any{"event_id": eventA.String()}, + http.StatusOK, &ticketResp) + + // First handshake against the correct event id — consumes the ticket. + req1, _ := http.NewRequest(http.MethodGet, + fmt.Sprintf("%s/ws/events/%s?ticket=%s", srv.URL, eventA, ticketResp.Ticket), nil) + req1.Header.Set("Upgrade", "websocket") + req1.Header.Set("Connection", "Upgrade") + resp1, err := http.DefaultClient.Do(req1) + must(t, err, "WS handshake 1") + resp1.Body.Close() + + // Replay — ticket is already consumed. + req2, _ := http.NewRequest(http.MethodGet, + fmt.Sprintf("%s/ws/events/%s?ticket=%s", srv.URL, eventA, ticketResp.Ticket), nil) + req2.Header.Set("Upgrade", "websocket") + req2.Header.Set("Connection", "Upgrade") + resp2, err := http.DefaultClient.Do(req2) + must(t, err, "WS handshake 2") + resp2.Body.Close() + if resp2.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401 on replay, got %d", resp2.StatusCode) + } + }) +} + +func getJSONAuthed(t *testing.T, url, bearer string, wantStatus int, out any) { + t.Helper() + req, err := http.NewRequest(http.MethodGet, url, nil) + must(t, err, "build get") + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } + resp, err := http.DefaultClient.Do(req) + must(t, err, "do get "+url) + defer resp.Body.Close() + if resp.StatusCode != wantStatus { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("%s status=%d want=%d body=%s", url, resp.StatusCode, wantStatus, body) + } + if out != nil { + must(t, json.NewDecoder(resp.Body).Decode(out), "decode response from "+url) + } +} + +func assertStatus(t *testing.T, method, url, bearer string, body any, wantStatus int) { + t.Helper() + var rdr io.Reader + if body != nil { + b, _ := json.Marshal(body) + rdr = bytes.NewReader(b) + } + req, err := http.NewRequest(method, url, rdr) + must(t, err, "build "+method+" "+url) + if rdr != nil { + req.Header.Set("Content-Type", "application/json") + } + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } + resp, err := http.DefaultClient.Do(req) + must(t, err, "do "+method+" "+url) + defer resp.Body.Close() + if resp.StatusCode != wantStatus { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("%s %s status=%d want=%d body=%s", method, url, resp.StatusCode, wantStatus, b) + } +} diff --git a/test/integration/billing_test.go b/test/integration/billing_test.go new file mode 100644 index 0000000..fcb80b0 --- /dev/null +++ b/test/integration/billing_test.go @@ -0,0 +1,190 @@ +//go:build integration + +package integration_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestFreeTierEventLimit confirms free-tier hosts are capped at one +// event per calendar month with a 402 response carrying the upgrade +// payload the frontend uses to render the modal. +func TestFreeTierEventLimit(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost:3000", + }) + must(t, err, "api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + // Bypass insertHost (which auto-grants Business) so this host stays free. + hostID := insertFreeTierHost(t, ctx, db.Pool) + token := issueHostToken(t, hostID) + + // First event under the limit — should succeed. + _ = createEvent(t, srv.URL, token, "First", "free-first") + + // Second event must be 402 with the upgrade payload. + body, _ := json.Marshal(map[string]any{ + "name": "Second", + "slug": "free-second", + "event_date": time.Now().Add(30 * 24 * time.Hour).UTC().Format(time.RFC3339), + "venue": "Hall", + }) + req, _ := http.NewRequest(http.MethodPost, srv.URL+"/events", strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + must(t, err, "POST /events second") + defer resp.Body.Close() + if resp.StatusCode != http.StatusPaymentRequired { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 402 on 2nd event, got %d body=%s", resp.StatusCode, b) + } + var rb struct { + Error string `json:"error"` + Tier string `json:"tier"` + Used int `json:"used"` + Limit int `json:"limit"` + UpgradeURL string `json:"upgrade_url"` + } + must(t, json.NewDecoder(resp.Body).Decode(&rb), "decode 402 body") + if rb.Tier != "free" || rb.Used != 1 || rb.Limit != 1 { + t.Errorf("402 payload: %+v", rb) + } + if !strings.Contains(rb.UpgradeURL, "/dashboard/billing") { + t.Errorf("expected upgrade_url to point at /dashboard/billing, got %q", rb.UpgradeURL) + } +} + +// TestFreeTierGuestLimit confirms the per-event guest cap kicks in at +// the right number, again with a 402 payload. +func TestFreeTierGuestLimit(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost:3000", + }) + must(t, err, "api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostID := insertFreeTierHost(t, ctx, db.Pool) + token := issueHostToken(t, hostID) + eventID := createEvent(t, srv.URL, token, "Free Event", "free-guests-event") + + // Free tier allows 50 guests per event. Seed 50 directly so we don't + // pay 50 HTTP round-trips, then attempt one more via the API. + for i := 0; i < 50; i++ { + _, err := db.Pool.Exec(ctx, + `INSERT INTO guests (event_id, name) VALUES ($1, $2)`, + eventID, fmt.Sprintf("Seeded %d", i), + ) + must(t, err, "seed guest") + } + + // 51st guest must be 402. + body, _ := json.Marshal(map[string]any{"name": "Overflow"}) + req, _ := http.NewRequest(http.MethodPost, + fmt.Sprintf("%s/events/%s/guests", srv.URL, eventID), + strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + must(t, err, "POST /guests overflow") + defer resp.Body.Close() + if resp.StatusCode != http.StatusPaymentRequired { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 402 on 51st guest, got %d body=%s", resp.StatusCode, b) + } +} + +// TestBusinessTierBypassesLimits sanity-checks that a host with an +// active Business subscription can create more than the free-tier +// allowance — the enforcer code path that returns "unlimited". +func TestBusinessTierBypassesLimits(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, _, _, token := setupAuthedAPI(t, ctx) + // setupAuthedAPI already grants Business via insertHost. + for i := 0; i < 3; i++ { + _ = createEvent(t, srv.URL, token, fmt.Sprintf("Biz Event %d", i), + fmt.Sprintf("biz-event-%d", i)) + } +} + +// insertFreeTierHost mints a verified user WITHOUT granting any +// subscription — opposite of the default test helper. Used to exercise +// the free-tier enforcement path. +func insertFreeTierHost(t *testing.T, ctx context.Context, pool *pgxpool.Pool) uuid.UUID { + t.Helper() + var id uuid.UUID + err := pool.QueryRow(ctx, + `INSERT INTO users (email, name, email_verified, email_verified_at) + VALUES ($1, 'Free Tier', TRUE, now()) RETURNING id`, + fmt.Sprintf("free-%d@guestguard.test", time.Now().UnixNano()), + ).Scan(&id) + must(t, err, "insert free-tier host") + return id +} diff --git a/test/integration/bulk_invitation_test.go b/test/integration/bulk_invitation_test.go new file mode 100644 index 0000000..514bde9 --- /dev/null +++ b/test/integration/bulk_invitation_test.go @@ -0,0 +1,234 @@ +//go:build integration + +package integration_test + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/natspub" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestBulkIssueInvitations confirms the bulk endpoint: +// - mints tokens for guests without one +// - skips guests that already have a token +// - publishes invitation.send only for guests with an email +// - returns an accurate per-bucket summary +func TestBulkIssueInvitations(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + dsn := startPostgres(t, ctx) + natsURL := startNATS(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + natsClient, err := natspub.Connect(ctx, natsURL, logger) + must(t, err, "connect nats") + t.Cleanup(natsClient.Close) + + var pubCount atomic.Int32 + emails := make(chan string, 32) + sub, err := natspub.NewInvitationSendSubscriber(ctx, natsClient, "test-bulk-invitation", + func(_ context.Context, evt natspub.InvitationSend) error { + pubCount.Add(1) + select { + case emails <- evt.GuestEmail: + default: + } + return nil + }, logger) + must(t, err, "build subscriber") + cc, err := sub.Start(ctx) + must(t, err, "start subscriber") + t.Cleanup(cc.Stop) + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + InvitationPublisher: natsClient, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "https://gg.example.test", + }) + must(t, err, "build api") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostID := insertHost(t, ctx, db.Pool) + hostToken := issueHostToken(t, hostID) + eventID := createEvent(t, srv.URL, hostToken, "Bulk Event", "bulk-event") + + // Three guests with emails, two without. + var withEmail []uuid.UUID + for i := 0; i < 3; i++ { + id, _ := createGuestWithEmail(t, srv.URL, hostToken, eventID, fmt.Sprintf("Email-%d", i)) + withEmail = append(withEmail, id) + } + var noEmail []uuid.UUID + for i := 0; i < 2; i++ { + noEmail = append(noEmail, createGuest(t, srv.URL, hostToken, eventID, fmt.Sprintf("Phone-%d", i))) + } + + // Pre-issue one token for the first email guest to test the + // skipped_existing path. + issueToken(t, srv.URL, hostToken, eventID, withEmail[0]) + + // Call bulk endpoint (empty body = "all eligible"). + var result struct { + Issued int `json:"issued"` + Queued int `json:"queued"` + SkippedExisting int `json:"skipped_existing"` + SkippedNoEmail int `json:"skipped_no_email"` + Errors []struct { + GuestID string `json:"guest_id"` + Reason string `json:"reason"` + } `json:"errors"` + } + postJSONAuthed(t, fmt.Sprintf("%s/events/%s/guests/invitations/bulk", srv.URL, eventID), + hostToken, map[string]any{}, http.StatusOK, &result) + + if result.Issued != 4 { + t.Errorf("issued: got %d want 4 (2 emails + 2 no-email; the third email guest was pre-issued)", result.Issued) + } + if result.Queued != 2 { + t.Errorf("queued: got %d want 2", result.Queued) + } + if result.SkippedExisting != 1 { + t.Errorf("skipped_existing: got %d want 1", result.SkippedExisting) + } + if result.SkippedNoEmail != 2 { + t.Errorf("skipped_no_email: got %d want 2", result.SkippedNoEmail) + } + if len(result.Errors) != 0 { + t.Errorf("unexpected errors: %+v", result.Errors) + } + + // Wait for the two NATS messages. + deadline := time.After(10 * time.Second) + received := 0 +loop: + for received < 2 { + select { + case <-emails: + received++ + case <-deadline: + break loop + } + } + if received != 2 { + t.Fatalf("expected 2 invitation.send messages, got %d", received) + } + + // Re-running bulk should be a no-op (everyone now has a token). + var second struct { + Issued int `json:"issued"` + Queued int `json:"queued"` + SkippedExisting int `json:"skipped_existing"` + } + postJSONAuthed(t, fmt.Sprintf("%s/events/%s/guests/invitations/bulk", srv.URL, eventID), + hostToken, map[string]any{}, http.StatusOK, &second) + if second.Issued != 0 || second.Queued != 0 || second.SkippedExisting != 5 { + t.Errorf("re-run: got issued=%d queued=%d skipped_existing=%d (want 0/0/5)", + second.Issued, second.Queued, second.SkippedExisting) + } +} + +// TestBulkIssueExplicitSubset confirms guest_ids is honoured. +func TestBulkIssueExplicitSubset(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + dsn := startPostgres(t, ctx) + natsURL := startNATS(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + natsClient, err := natspub.Connect(ctx, natsURL, logger) + must(t, err, "connect nats") + t.Cleanup(natsClient.Close) + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + InvitationPublisher: natsClient, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "https://gg.example.test", + }) + must(t, err, "build api") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostID := insertHost(t, ctx, db.Pool) + hostToken := issueHostToken(t, hostID) + eventID := createEvent(t, srv.URL, hostToken, "Subset Event", "subset-event") + + var ids []uuid.UUID + for i := 0; i < 3; i++ { + id, _ := createGuestWithEmail(t, srv.URL, hostToken, eventID, fmt.Sprintf("G-%d", i)) + ids = append(ids, id) + } + + // Send to only the middle guest. + var result struct { + Issued int `json:"issued"` + Queued int `json:"queued"` + } + postJSONAuthed(t, + fmt.Sprintf("%s/events/%s/guests/invitations/bulk", srv.URL, eventID), + hostToken, + map[string]any{"guest_ids": []string{ids[1].String()}}, + http.StatusOK, &result) + if result.Issued != 1 || result.Queued != 1 { + t.Fatalf("subset send: got issued=%d queued=%d, want 1/1", result.Issued, result.Queued) + } + + // The other two should still be tokenless. + var hasToken int + must(t, db.Pool.QueryRow(ctx, + "SELECT COUNT(*) FROM tokens WHERE guest_id = ANY($1)", []uuid.UUID{ids[0], ids[2]}, + ).Scan(&hasToken), "count tokens for non-targets") + if hasToken != 0 { + t.Fatalf("expected 0 tokens for non-targets, got %d", hasToken) + } +} diff --git a/test/integration/csv_import_test.go b/test/integration/csv_import_test.go new file mode 100644 index 0000000..c1c5092 --- /dev/null +++ b/test/integration/csv_import_test.go @@ -0,0 +1,186 @@ +//go:build integration + +package integration_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestCsvImportFlow walks the happy path: preview, then commit, then a +// re-import to confirm dedup is honoured. +func TestCsvImportFlow(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, host, token := setupAuthedAPI(t, ctx) + eventID := createEvent(t, srv.URL, token, "Import Event", "import-event") + _ = host + + const csvOne = `name,email,phone,plus_ones +Alex,alex@example.test,+447700900111,1 +Sam,sam@example.test,,0 +Jordan,,+15551234567,2 +,,, +Mira,malformed-email,,0 +` + + // Preview. + var preview struct { + Rows []map[string]any `json:"rows"` + Errors []map[string]any `json:"errors"` + TotalCount int `json:"total_count"` + } + resp := postMultipart(t, srv.URL+"/events/"+eventID.String()+"/guests/import/preview", token, csvOne) + must(t, json.NewDecoder(resp.Body).Decode(&preview), "decode preview") + resp.Body.Close() + if len(preview.Rows) != 3 { + t.Fatalf("preview rows: got %d want 3 (errors=%+v)", len(preview.Rows), preview.Errors) + } + if len(preview.Errors) != 1 { + t.Fatalf("preview errors: got %d want 1: %+v", len(preview.Errors), preview.Errors) + } + + // Commit. + var commit struct { + Added int `json:"added"` + Skipped int `json:"skipped"` + SkippedEmails []string `json:"skipped_emails"` + } + resp = postMultipart(t, srv.URL+"/events/"+eventID.String()+"/guests/import", token, csvOne) + must(t, json.NewDecoder(resp.Body).Decode(&commit), "decode commit") + resp.Body.Close() + if commit.Added != 3 || commit.Skipped != 0 { + t.Fatalf("commit: added=%d skipped=%d (want 3/0)", commit.Added, commit.Skipped) + } + + // Re-import the same file — emails should dedup. + var commit2 struct { + Added int `json:"added"` + Skipped int `json:"skipped"` + SkippedEmails []string `json:"skipped_emails"` + } + resp = postMultipart(t, srv.URL+"/events/"+eventID.String()+"/guests/import", token, csvOne) + must(t, json.NewDecoder(resp.Body).Decode(&commit2), "decode commit2") + resp.Body.Close() + // Jordan has no email, so they re-import as a new row each time. + // Alex + Sam have emails and should be skipped. + if commit2.Skipped != 2 { + t.Fatalf("re-import dedup: skipped=%d want 2 (added=%d)", commit2.Skipped, commit2.Added) + } + + // Verify the row count in the DB matches expectations: 3 from first + // commit + 1 (Jordan again) from second. + var count int + must(t, db.Pool.QueryRow(ctx, + "SELECT count(*) FROM guests WHERE event_id = $1", eventID, + ).Scan(&count), "count guests") + if count != 4 { + t.Fatalf("guest count: got %d want 4", count) + } +} + +// TestCsvImportAtomicRollback verifies that a runtime error mid-batch +// leaves NO partial rows. We trigger this by injecting a name longer than +// the column allows (VARCHAR(255)) on row 3 of 4. +func TestCsvImportAtomicRollback(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, _, token := setupAuthedAPI(t, ctx) + eventID := createEvent(t, srv.URL, token, "Atomic Event", "atomic-event") + + longName := bytes.Repeat([]byte("Aa"), 200) // 400 chars > VARCHAR(255) + csv := "name,email\nAlice,a@example.test\nBob,b@example.test\n" + + string(longName) + ",c@example.test\nDave,d@example.test\n" + + resp := postMultipart(t, srv.URL+"/events/"+eventID.String()+"/guests/import", token, csv) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("expected 500 on row insert error, got %d body=%s", resp.StatusCode, body) + } + + var count int + must(t, db.Pool.QueryRow(ctx, + "SELECT count(*) FROM guests WHERE event_id = $1", eventID, + ).Scan(&count), "count after rollback") + if count != 0 { + t.Fatalf("expected 0 guests after rollback, got %d", count) + } +} + +// --- helpers shared with other tests in this dir --- + +// setupAuthedAPI builds a fresh API server + a verified host + bearer +// token. Tests that just need a logged-in host can use this directly. +func setupAuthedAPI(t *testing.T, ctx context.Context) (srv *httptest.Server, db *storage.DB, hostID [16]byte, bearer string) { + t.Helper() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + + var err error + db, err = storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + }) + must(t, err, "build api server") + srv = httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + host := insertHost(t, ctx, db.Pool) + hostID = host + bearer = issueHostToken(t, host) + return +} + +func postMultipart(t *testing.T, url, bearer, body string) *http.Response { + t.Helper() + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + fw, err := mw.CreateFormFile("file", "guests.csv") + must(t, err, "create form file") + _, err = fw.Write([]byte(body)) + must(t, err, "write csv") + must(t, mw.Close(), "close mw") + + req, err := http.NewRequest(http.MethodPost, url, &buf) + must(t, err, "build req") + req.Header.Set("Content-Type", mw.FormDataContentType()) + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } + resp, err := http.DefaultClient.Do(req) + must(t, err, "do multipart") + return resp +} diff --git a/test/integration/e2e_test.go b/test/integration/e2e_test.go index 71df7bf..7fc0092 100644 --- a/test/integration/e2e_test.go +++ b/test/integration/e2e_test.go @@ -25,6 +25,7 @@ import ( "google.golang.org/grpc" "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/auth" "github.com/alchemistkay/guestguard/internal/fraud" pb "github.com/alchemistkay/guestguard/internal/fraudpb" "github.com/alchemistkay/guestguard/internal/natspub" @@ -79,22 +80,32 @@ func TestE2EHappyPath(t *testing.T) { rsvpCounter := subscribeRSVPConfirmed(t, ctx, natsClient) - srv := httptest.NewServer(api.NewServer(api.ServerDeps{ - Logger: logger, - DB: db, - AccessPublisher: natsClient, - RSVPPublisher: natsClient, - FraudScorer: fraudClient, - TokenTTL: 24 * time.Hour, - }).Handler()) + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + FraudScorer: fraudClient, + TokenTTL: 24 * time.Hour, + JWTSecret: "test-secret-must-be-at-least-32-bytes-long-xx", + JWTIssuer: "guestguard-test", + AccessTokenTTL: 15 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + }) + must(t, err, "build api server") + srv := httptest.NewServer(apiSrv.Handler()) t.Cleanup(srv.Close) hostID := insertHost(t, ctx, db.Pool) + hostToken := issueHostToken(t, hostID) t.Run("async access flow flags access_logs", func(t *testing.T) { - eventID := createEvent(t, srv.URL, hostID, "Async Test", "async-test") - guestID := createGuest(t, srv.URL, eventID, "Async Guest") - token := issueToken(t, srv.URL, eventID, guestID) + eventID := createEvent(t, srv.URL, hostToken, "Async Test", "async-test") + guestID := createGuest(t, srv.URL, hostToken, eventID, "Async Guest") + token := issueToken(t, srv.URL, hostToken, eventID, guestID) accessResp := getAccess(t, srv.URL, token) @@ -119,9 +130,9 @@ func TestE2EHappyPath(t *testing.T) { }) t.Run("sync rsvp flow records rsvp and marks token used", func(t *testing.T) { - eventID := createEvent(t, srv.URL, hostID, "Sync Test", "sync-test") - guestID := createGuest(t, srv.URL, eventID, "Sync Guest") - token := issueToken(t, srv.URL, eventID, guestID) + eventID := createEvent(t, srv.URL, hostToken, "Sync Test", "sync-test") + guestID := createGuest(t, srv.URL, hostToken, eventID, "Sync Guest") + token := issueToken(t, srv.URL, hostToken, eventID, guestID) stub.SetNext(15, "low", nil) @@ -145,9 +156,9 @@ func TestE2EHappyPath(t *testing.T) { }) t.Run("sync rsvp flow blocks when fraud score is BLOCK", func(t *testing.T) { - eventID := createEvent(t, srv.URL, hostID, "Block Test", "block-test") - guestID := createGuest(t, srv.URL, eventID, "Block Guest") - token := issueToken(t, srv.URL, eventID, guestID) + eventID := createEvent(t, srv.URL, hostToken, "Block Test", "block-test") + guestID := createGuest(t, srv.URL, hostToken, eventID, "Block Guest") + token := issueToken(t, srv.URL, hostToken, eventID, guestID) stub.SetNext(95, "block", []string{"fingerprint differs from baseline", "ip address changed"}) @@ -274,33 +285,32 @@ func startStubFraudGRPC(t *testing.T) *stubFraud { // --- HTTP helpers --- -func createEvent(t *testing.T, base string, hostID uuid.UUID, name, slug string) uuid.UUID { +func createEvent(t *testing.T, base, accessToken string, name, slug string) uuid.UUID { t.Helper() body := map[string]any{ - "host_id": hostID.String(), "name": name, "slug": slug, "event_date": time.Now().Add(30 * 24 * time.Hour).UTC().Format(time.RFC3339), "venue": "Integration Hall", } var out struct{ ID uuid.UUID `json:"id"` } - postJSON(t, base+"/events", body, http.StatusCreated, &out) + postJSONAuthed(t, base+"/events", accessToken, body, http.StatusCreated, &out) return out.ID } -func createGuest(t *testing.T, base string, eventID uuid.UUID, name string) uuid.UUID { +func createGuest(t *testing.T, base, accessToken string, eventID uuid.UUID, name string) uuid.UUID { t.Helper() var out struct{ ID uuid.UUID `json:"id"` } - postJSON(t, fmt.Sprintf("%s/events/%s/guests", base, eventID), + postJSONAuthed(t, fmt.Sprintf("%s/events/%s/guests", base, eventID), accessToken, map[string]any{"name": name}, http.StatusCreated, &out) return out.ID } -func issueToken(t *testing.T, base string, eventID, guestID uuid.UUID) string { +func issueToken(t *testing.T, base, accessToken string, eventID, guestID uuid.UUID) string { t.Helper() var out struct{ Token string `json:"token"` } - postJSON(t, fmt.Sprintf("%s/events/%s/guests/%s/tokens", base, eventID, guestID), - nil, http.StatusCreated, &out) + postJSONAuthed(t, fmt.Sprintf("%s/events/%s/guests/%s/tokens", base, eventID, guestID), + accessToken, nil, http.StatusCreated, &out) return out.Token } @@ -340,6 +350,11 @@ func submitRSVP(t *testing.T, base, token string, body map[string]any) submitRSV } func postJSON(t *testing.T, url string, body any, wantStatus int, out any) { + t.Helper() + postJSONAuthed(t, url, "", body, wantStatus, out) +} + +func postJSONAuthed(t *testing.T, url, bearer string, body any, wantStatus int, out any) { t.Helper() var rdr io.Reader if body != nil { @@ -351,6 +366,9 @@ func postJSON(t *testing.T, url string, body any, wantStatus int, out any) { if rdr != nil { req.Header.Set("Content-Type", "application/json") } + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } resp, err := http.DefaultClient.Do(req) must(t, err, "do request "+url) defer resp.Body.Close() @@ -370,14 +388,48 @@ func insertHost(t *testing.T, ctx context.Context, pool *pgxpool.Pool) uuid.UUID t.Helper() var id uuid.UUID err := pool.QueryRow(ctx, - `INSERT INTO users (email, name) VALUES ($1, $2) RETURNING id`, + `INSERT INTO users (email, name, email_verified, email_verified_at) + VALUES ($1, $2, TRUE, now()) RETURNING id`, fmt.Sprintf("test-%d@guestguard.test", time.Now().UnixNano()), "Integration Host", ).Scan(&id) must(t, err, "insert host") + // Default test hosts to the Business tier so existing tests that + // create multiple events for one host aren't tripped up by the + // free-tier limit (1 event / month). Tests that specifically exercise + // the free-tier path skip this helper. + grantBusinessTier(t, ctx, pool, id) return id } +// grantBusinessTier inserts an active Business subscription row for the +// given user so tier-enforcement middleware grants unlimited events. +func grantBusinessTier(t *testing.T, ctx context.Context, pool *pgxpool.Pool, userID uuid.UUID) { + t.Helper() + _, err := pool.Exec(ctx, ` + INSERT INTO subscriptions (user_id, stripe_customer_id, tier, status) + VALUES ($1::uuid, 'cus_test_' || replace($1::uuid::text, '-', ''), 'business', 'active') + `, userID.String()) + must(t, err, "grant business tier") +} + +// issueHostToken mints a Bearer access token for an existing host using the +// same JWT secret/issuer the test API server was constructed with. This +// lets integration tests skip the signup/verify/login dance. +func issueHostToken(t *testing.T, hostID uuid.UUID) string { + t.Helper() + signer, err := auth.NewJWTSigner(testJWTSecret, 5*time.Minute, testJWTIssuer) + must(t, err, "build jwt signer") + tok, _, err := signer.Issue(hostID, time.Now()) + must(t, err, "issue jwt") + return tok +} + +const ( + testJWTSecret = "test-secret-must-be-at-least-32-bytes-long-xx" + testJWTIssuer = "guestguard-test" +) + func waitForFlagged(t *testing.T, ctx context.Context, pool *pgxpool.Pool, accessLogID uuid.UUID, wantScore int, wantFlagged bool) { t.Helper() deadline := time.Now().Add(10 * time.Second) diff --git a/test/integration/guest_crud_test.go b/test/integration/guest_crud_test.go new file mode 100644 index 0000000..771b450 --- /dev/null +++ b/test/integration/guest_crud_test.go @@ -0,0 +1,281 @@ +//go:build integration + +package integration_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/natspub" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestGuestUpdate confirms PATCH semantics: partial fields update, empty +// strings clear nullable columns, missing fields are left untouched. +func TestGuestUpdate(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, _, token := setupAuthedAPI(t, ctx) + eventID := createEvent(t, srv.URL, token, "Update Event", "update-event") + guestID, originalEmail := createGuestWithEmail(t, srv.URL, token, eventID, "Original Name") + + // Patch name + email together. + var updated struct { + Name string `json:"name"` + Email *string `json:"email"` + } + patchJSON(t, fmt.Sprintf("%s/events/%s/guests/%s", srv.URL, eventID, guestID), token, + map[string]any{"name": "Renamed", "email": "new-" + originalEmail}, + http.StatusOK, &updated) + if updated.Name != "Renamed" { + t.Errorf("name: got %q want Renamed", updated.Name) + } + if updated.Email == nil || !strings.HasPrefix(*updated.Email, "new-") { + t.Errorf("email: got %v", updated.Email) + } + + // Clear the email by sending empty string. domain.Guest tags Email + // as omitempty so a nil pointer doesn't serialize; check DB state + // directly instead of relying on the response shape. + patchJSON(t, fmt.Sprintf("%s/events/%s/guests/%s", srv.URL, eventID, guestID), token, + map[string]any{"email": ""}, http.StatusOK, nil) + var dbEmail *string + must(t, db.Pool.QueryRow(ctx, "SELECT email FROM guests WHERE id = $1", guestID).Scan(&dbEmail), + "fetch email after clear") + if dbEmail != nil { + t.Errorf("expected DB email cleared (NULL), got %q", *dbEmail) + } + + // Empty name is rejected. + assertStatus(t, http.MethodPatch, fmt.Sprintf("%s/events/%s/guests/%s", srv.URL, eventID, guestID), + token, map[string]any{"name": ""}, http.StatusBadRequest) +} + +// TestGuestDelete confirms the row goes away + cascade-deletes the token. +func TestGuestDelete(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, _, token := setupAuthedAPI(t, ctx) + eventID := createEvent(t, srv.URL, token, "Delete Event", "delete-event") + guestID, _ := createGuestWithEmail(t, srv.URL, token, eventID, "Bye") + + // Give the guest a token first so we can prove the cascade. + issueToken(t, srv.URL, token, eventID, guestID) + var hasToken int + must(t, db.Pool.QueryRow(ctx, "SELECT count(*) FROM tokens WHERE guest_id = $1", guestID).Scan(&hasToken), + "count tokens before delete") + if hasToken != 1 { + t.Fatalf("setup: expected 1 token, got %d", hasToken) + } + + assertStatus(t, http.MethodDelete, fmt.Sprintf("%s/events/%s/guests/%s", srv.URL, eventID, guestID), + token, nil, http.StatusNoContent) + + var remaining int + must(t, db.Pool.QueryRow(ctx, "SELECT count(*) FROM guests WHERE id = $1", guestID).Scan(&remaining), + "count after delete") + if remaining != 0 { + t.Errorf("guest still exists after delete") + } + must(t, db.Pool.QueryRow(ctx, "SELECT count(*) FROM tokens WHERE guest_id = $1", guestID).Scan(&hasToken), + "count tokens after delete") + if hasToken != 0 { + t.Errorf("expected cascade to delete the token, %d still exist", hasToken) + } + + // Re-deleting → 404. + assertStatus(t, http.MethodDelete, fmt.Sprintf("%s/events/%s/guests/%s", srv.URL, eventID, guestID), + token, nil, http.StatusNotFound) +} + +// TestTokenRotate confirms the old token stops working and the new one +// is recognised. Optionally re-publishes invitation.send when send_email +// is true. +func TestTokenRotate(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + dsn := startPostgres(t, ctx) + natsURL := startNATS(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + natsClient, err := natspub.Connect(ctx, natsURL, logger) + must(t, err, "connect nats") + t.Cleanup(natsClient.Close) + + var invitationCount atomic.Int32 + sub, err := natspub.NewInvitationSendSubscriber(ctx, natsClient, "test-rotate", + func(_ context.Context, _ natspub.InvitationSend) error { + invitationCount.Add(1) + return nil + }, logger) + must(t, err, "subscriber") + cc, err := sub.Start(ctx) + must(t, err, "start subscriber") + t.Cleanup(cc.Stop) + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + InvitationPublisher: natsClient, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "https://gg.example.test", + }) + must(t, err, "api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostID := insertHost(t, ctx, db.Pool) + hostToken := issueHostToken(t, hostID) + eventID := createEvent(t, srv.URL, hostToken, "Rotate Event", "rotate-event") + guestID, _ := createGuestWithEmail(t, srv.URL, hostToken, eventID, "Mira") + + // Initial issue — captures the original token + 1 invitation publish. + originalToken := issueToken(t, srv.URL, hostToken, eventID, guestID) + + // The /access endpoint accepts the original token before rotation. + resp, err := http.Get(srv.URL + "/access/" + originalToken) + must(t, err, "GET original access") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("original access pre-rotate: got %d want 200", resp.StatusCode) + } + + // Rotate WITHOUT email — just want a fresh link. + var rotated struct { + Token string `json:"token"` + InvitationQueued bool `json:"invitation_queued"` + } + postJSONAuthed(t, + fmt.Sprintf("%s/events/%s/guests/%s/tokens/rotate", srv.URL, eventID, guestID), + hostToken, + map[string]any{"send_email": false}, + http.StatusOK, &rotated) + if rotated.Token == "" || rotated.Token == originalToken { + t.Fatalf("rotated token should be fresh and non-empty (was %q, original %q)", rotated.Token, originalToken) + } + if rotated.InvitationQueued { + t.Error("expected invitation_queued=false when send_email=false") + } + + // Old token no longer works. + resp, err = http.Get(srv.URL + "/access/" + originalToken) + must(t, err, "GET original after rotate") + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + t.Errorf("old token should not authenticate after rotation, got 200") + } + + // New token works. + resp, err = http.Get(srv.URL + "/access/" + rotated.Token) + must(t, err, "GET new access") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("new token should authenticate: got %d", resp.StatusCode) + } + + // Rotate again WITH email — should bump the invitation count. + preCount := invitationCount.Load() + var rotated2 struct { + Token string `json:"token"` + InvitationQueued bool `json:"invitation_queued"` + } + postJSONAuthed(t, + fmt.Sprintf("%s/events/%s/guests/%s/tokens/rotate", srv.URL, eventID, guestID), + hostToken, + map[string]any{"send_email": true}, + http.StatusOK, &rotated2) + if !rotated2.InvitationQueued { + t.Error("expected invitation_queued=true when send_email=true") + } + // Wait briefly for the NATS round-trip. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if invitationCount.Load() > preCount { + break + } + time.Sleep(100 * time.Millisecond) + } + if invitationCount.Load() <= preCount { + t.Fatalf("expected invitation publish after rotate-with-send, count %d -> %d", preCount, invitationCount.Load()) + } +} + +func patchJSONRaw(t *testing.T, url, bearer string, body any, wantStatus int) []byte { + t.Helper() + b, _ := json.Marshal(body) + req, err := http.NewRequest(http.MethodPatch, url, strings.NewReader(string(b))) + must(t, err, "build patch") + req.Header.Set("Content-Type", "application/json") + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } + resp, err := http.DefaultClient.Do(req) + must(t, err, "do patch") + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != wantStatus { + t.Fatalf("%s status=%d want=%d body=%s", url, resp.StatusCode, wantStatus, respBody) + } + return respBody +} + +func patchJSON(t *testing.T, url, bearer string, body any, wantStatus int, out any) { + t.Helper() + b, _ := json.Marshal(body) + req, err := http.NewRequest(http.MethodPatch, url, strings.NewReader(string(b))) + must(t, err, "build patch") + req.Header.Set("Content-Type", "application/json") + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } + resp, err := http.DefaultClient.Do(req) + must(t, err, "do patch") + defer resp.Body.Close() + if resp.StatusCode != wantStatus { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("%s status=%d want=%d body=%s", url, resp.StatusCode, wantStatus, body) + } + if out != nil { + must(t, json.NewDecoder(resp.Body).Decode(out), "decode patch response") + } +} + +// Silence unused import warning if the future drops something. +var _ uuid.UUID \ No newline at end of file diff --git a/test/integration/invitation_send_test.go b/test/integration/invitation_send_test.go new file mode 100644 index 0000000..1cf8c4b --- /dev/null +++ b/test/integration/invitation_send_test.go @@ -0,0 +1,218 @@ +//go:build integration + +package integration_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/natspub" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestTokenIssuePublishesInvitation walks the new wire-up: issuing a +// token for a guest with an email-on-file should publish an +// invitation.send event over NATS, and the API response should reflect +// that the invitation was queued. +func TestTokenIssuePublishesInvitation(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + dsn := startPostgres(t, ctx) + natsURL := startNATS(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + natsClient, err := natspub.Connect(ctx, natsURL, logger) + must(t, err, "connect nats") + t.Cleanup(natsClient.Close) + + // Subscribe before issuing — JetStream replay is on by default but + // this guarantees we don't race the consumer setup. + var seen atomic.Int32 + captured := make(chan natspub.InvitationSend, 1) + sub, err := natspub.NewInvitationSendSubscriber(ctx, natsClient, "test-invitation-send", + func(ctx context.Context, evt natspub.InvitationSend) error { + seen.Add(1) + select { + case captured <- evt: + default: + } + return nil + }, logger) + must(t, err, "build subscriber") + cc, err := sub.Start(ctx) + must(t, err, "start subscriber") + t.Cleanup(cc.Stop) + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + InvitationPublisher: natsClient, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "https://gg.example.test", + }) + must(t, err, "build api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostID := insertHost(t, ctx, db.Pool) + token := issueHostToken(t, hostID) + + eventID := createEvent(t, srv.URL, token, "Invitation Test", "invitation-test") + guestID, guestEmail := createGuestWithEmail(t, srv.URL, token, eventID, "Mira") + + var issued struct { + Token string `json:"token"` + InvitationQueued bool `json:"invitation_queued"` + InvitationLink string `json:"invitation_link"` + } + postJSONAuthed(t, + fmt.Sprintf("%s/events/%s/guests/%s/tokens", srv.URL, eventID, guestID), + token, nil, http.StatusCreated, &issued) + + if !issued.InvitationQueued { + t.Fatalf("expected invitation_queued=true (response: %+v)", issued) + } + if !strings.HasPrefix(issued.InvitationLink, "https://gg.example.test/rsvp/tk_") { + t.Fatalf("invitation_link should use publicBaseURL: got %q", issued.InvitationLink) + } + + select { + case evt := <-captured: + if evt.GuestID.String() != guestID.String() { + t.Errorf("guest id: got %s want %s", evt.GuestID, guestID) + } + if evt.GuestEmail != guestEmail { + t.Errorf("guest email: got %s want %s", evt.GuestEmail, guestEmail) + } + if evt.EventName != "Invitation Test" { + t.Errorf("event name: got %s", evt.EventName) + } + if !strings.HasPrefix(evt.Link, "https://gg.example.test/rsvp/tk_") { + t.Errorf("link: got %s", evt.Link) + } + case <-time.After(10 * time.Second): + t.Fatalf("did not see invitation.send within 10s (seen=%d)", seen.Load()) + } +} + +// TestTokenIssueWithoutGuestEmailSkipsInvitation confirms that a guest +// with no email on file does NOT trigger a publish — the host still gets +// a copy-pasteable link. +func TestTokenIssueWithoutGuestEmailSkipsInvitation(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + dsn := startPostgres(t, ctx) + natsURL := startNATS(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + natsClient, err := natspub.Connect(ctx, natsURL, logger) + must(t, err, "connect nats") + t.Cleanup(natsClient.Close) + + var invitations atomic.Int32 + sub, err := natspub.NewInvitationSendSubscriber(ctx, natsClient, "test-no-email-invitation", + func(ctx context.Context, evt natspub.InvitationSend) error { + invitations.Add(1) + return nil + }, logger) + must(t, err, "build subscriber") + cc, err := sub.Start(ctx) + must(t, err, "start subscriber") + t.Cleanup(cc.Stop) + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + AccessPublisher: natsClient, + RSVPPublisher: natsClient, + InvitationPublisher: natsClient, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + }) + must(t, err, "build api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + + hostID := insertHost(t, ctx, db.Pool) + hostToken := issueHostToken(t, hostID) + + eventID := createEvent(t, srv.URL, hostToken, "No Email Event", "no-email-event") + // createGuest helper produces a guest with no email field set. + guestID := createGuest(t, srv.URL, hostToken, eventID, "Phone-only Guest") + + var issued struct { + InvitationQueued bool `json:"invitation_queued"` + } + postJSONAuthed(t, + fmt.Sprintf("%s/events/%s/guests/%s/tokens", srv.URL, eventID, guestID), + hostToken, nil, http.StatusCreated, &issued) + if issued.InvitationQueued { + t.Fatalf("expected invitation_queued=false for emailless guest") + } + + // Give NATS a moment to surface any (unwanted) message. + time.Sleep(500 * time.Millisecond) + if invitations.Load() != 0 { + t.Fatalf("expected 0 invitation.send messages, got %d", invitations.Load()) + } +} + +// createGuestWithEmail is a thin wrapper that adds an email field, since +// the existing helper omits it. +func createGuestWithEmail(t *testing.T, base, accessToken string, eventID uuid.UUID, name string) (uuid.UUID, string) { + t.Helper() + email := fmt.Sprintf("guest-%d@example.test", time.Now().UnixNano()) + body := map[string]any{"name": name, "email": email} + var out struct{ ID uuid.UUID `json:"id"` } + postJSONAuthed(t, fmt.Sprintf("%s/events/%s/guests", base, eventID), accessToken, + body, http.StatusCreated, &out) + return out.ID, email +} + +// Avoid the import being marked unused if a future refactor drops it. +var _ = json.Marshal diff --git a/test/integration/mailpit_test.go b/test/integration/mailpit_test.go new file mode 100644 index 0000000..3c2e5e4 --- /dev/null +++ b/test/integration/mailpit_test.go @@ -0,0 +1,122 @@ +//go:build integration + +package integration_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/alchemistkay/guestguard/internal/notification" +) + +// startMailpit launches a Mailpit container and returns (smtpHost, smtpPort, +// httpBaseURL). The HTTP API is what we query to assert delivery. +func startMailpit(t *testing.T, ctx context.Context) (string, int, string) { + t.Helper() + req := testcontainers.ContainerRequest{ + Image: "axllent/mailpit:latest", + ExposedPorts: []string{"1025/tcp", "8025/tcp"}, + WaitingFor: wait.ForListeningPort("8025/tcp").WithStartupTimeout(45 * time.Second), + } + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + must(t, err, "start mailpit container") + t.Cleanup(func() { _ = c.Terminate(context.Background()) }) + + host, err := c.Host(ctx) + must(t, err, "mailpit host") + smtpMP, err := c.MappedPort(ctx, "1025/tcp") + must(t, err, "mailpit smtp port") + httpMP, err := c.MappedPort(ctx, "8025/tcp") + must(t, err, "mailpit http port") + port, _ := strconv.Atoi(smtpMP.Port()) + return host, port, "http://" + host + ":" + httpMP.Port() +} + +// TestSMTPSenderAgainstMailpit sends a real verification email via the +// SMTP adapter and asserts the Mailpit HTTP API saw it land in the inbox. +// This is the closest thing to "did a real email arrive" we can run in CI. +func TestSMTPSenderAgainstMailpit(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + + smtpHost, smtpPort, httpBase := startMailpit(t, ctx) + + tpls, err := notification.NewTemplates() + must(t, err, "templates") + + sender, err := notification.NewSMTPEmailSender(notification.SMTPConfig{ + Host: smtpHost, + Port: smtpPort, + FromEmail: "noreply@guestguard.local", + FromName: "GuestGuard (dev)", + TLS: "none", // mailpit accepts plain SMTP + }, tpls) + must(t, err, "smtp sender") + + must(t, sender.SendVerification(ctx, "kay@example.test", "Kay", + "http://localhost:3000/verify-email?token=demo"), "send") + + // Mailpit exposes a /api/v1/messages list endpoint. Poll briefly since + // SMTP delivery is async. + var found bool + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + resp, err := http.Get(httpBase + "/api/v1/messages") + must(t, err, "mailpit list") + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + var list struct { + Messages []struct { + Subject string `json:"Subject"` + To []struct { + Address string `json:"Address"` + } `json:"To"` + ID string `json:"ID"` + } `json:"messages"` + } + if err := json.Unmarshal(body, &list); err != nil { + t.Fatalf("decode mailpit list: %v body=%s", err, body) + } + for _, m := range list.Messages { + if m.Subject == "Verify your GuestGuard email" { + for _, to := range m.To { + if to.Address == "kay@example.test" { + found = true + // Fetch the full message and confirm the verification + // link survived through MIME encoding. + full, err := http.Get(httpBase + "/api/v1/message/" + m.ID) + must(t, err, "fetch message") + b, _ := io.ReadAll(full.Body) + full.Body.Close() + if !strings.Contains(string(b), "verify-email?token=demo") { + t.Errorf("verification link missing from body: %s", b) + } + break + } + } + } + } + if found { + break + } + time.Sleep(200 * time.Millisecond) + } + if !found { + t.Fatalf("did not see verification email in mailpit within 10s") + } +} diff --git a/test/integration/migration_roundtrip_test.go b/test/integration/migration_roundtrip_test.go new file mode 100644 index 0000000..dabed3e --- /dev/null +++ b/test/integration/migration_roundtrip_test.go @@ -0,0 +1,128 @@ +//go:build integration + +package integration_test + +import ( + "context" + "os" + "path/filepath" + "sort" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" +) + +const migrationsDir = "../../internal/storage/migrations" + +// TestMigrationRoundtrip applies every up migration, runs every down in +// reverse, then applies the ups again, against a fresh Postgres +// container. Catches any down.sql that's missing, broken, or asymmetric +// with its up — Block G's "every migration has a tested down" check. +func TestMigrationRoundtrip(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + dsn := startPostgres(t, ctx) + pool, err := pgxpool.New(ctx, dsn) + must(t, err, "connect") + t.Cleanup(pool.Close) + + ups, downs := loadMigrations(t) + + // Phase 1: apply all ups in order. Mirrors what the API does at boot. + for _, m := range ups { + t.Logf("up: %s", m.version) + execAll(t, ctx, pool, m.sql) + } + // Sanity: the latest table from each migration exists. + for _, expected := range []string{ + "users", "rsvps", "refresh_tokens", "unsubscribes", "subscriptions", + } { + var exists bool + must(t, pool.QueryRow(ctx, + `SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name=$1)`, + expected, + ).Scan(&exists), "check "+expected) + if !exists { + t.Fatalf("after ups: table %q is missing", expected) + } + } + + // Phase 2: apply downs in REVERSE order. Each must execute without + // error (even though the result is allowed to be lossy — down + // migrations are not required to preserve data). + for i := len(downs) - 1; i >= 0; i-- { + m := downs[i] + t.Logf("down: %s", m.version) + execAll(t, ctx, pool, m.sql) + } + // All app tables should be gone now. + for _, gone := range []string{ + "users", "events", "guests", "tokens", "rsvps", + "access_logs", "notifications", "subscriptions", + } { + var exists bool + must(t, pool.QueryRow(ctx, + `SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name=$1)`, + gone, + ).Scan(&exists), "check "+gone) + if exists { + t.Errorf("after downs: %q still exists — incomplete down migration", gone) + } + } + + // Phase 3: re-apply all ups. This catches down migrations that + // leave hidden state (types, sequences, indexes, extensions) which + // would clash on the second up. + for _, m := range ups { + t.Logf("up2: %s", m.version) + execAll(t, ctx, pool, m.sql) + } +} + +type migration struct { + version string + sql string +} + +func loadMigrations(t *testing.T) (ups, downs []migration) { + t.Helper() + entries, err := os.ReadDir(migrationsDir) + must(t, err, "read migrations dir") + for _, e := range entries { + name := e.Name() + b, err := os.ReadFile(filepath.Join(migrationsDir, name)) + must(t, err, "read "+name) + m := migration{sql: string(b)} + switch { + case strings.HasSuffix(name, ".up.sql"): + m.version = strings.TrimSuffix(name, ".up.sql") + ups = append(ups, m) + case strings.HasSuffix(name, ".down.sql"): + m.version = strings.TrimSuffix(name, ".down.sql") + downs = append(downs, m) + } + } + sort.Slice(ups, func(i, j int) bool { return ups[i].version < ups[j].version }) + sort.Slice(downs, func(i, j int) bool { return downs[i].version < downs[j].version }) + if len(ups) != len(downs) { + t.Fatalf("up/down count mismatch: %d ups, %d downs — every migration needs a .down.sql", len(ups), len(downs)) + } + for i := range ups { + if ups[i].version != downs[i].version { + t.Fatalf("migration %s has no matching down (or vice versa)", ups[i].version) + } + } + return ups, downs +} + +func execAll(t *testing.T, ctx context.Context, pool *pgxpool.Pool, sql string) { + t.Helper() + _, err := pool.Exec(ctx, sql) + must(t, err, "exec migration") +} diff --git a/test/integration/notifications_test.go b/test/integration/notifications_test.go new file mode 100644 index 0000000..f4bac36 --- /dev/null +++ b/test/integration/notifications_test.go @@ -0,0 +1,230 @@ +//go:build integration + +package integration_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/notification" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// mustInsertEventAndGuest seeds the bare minimum rows the notifications +// webhook tests need to attach a notification to a real guest. +func mustInsertEventAndGuest(t *testing.T, ctx context.Context, db *storage.DB, hostID uuid.UUID) (uuid.UUID, uuid.UUID) { + t.Helper() + var eventID uuid.UUID + must(t, db.Pool.QueryRow(ctx, ` + INSERT INTO events (host_id, name, slug, event_date) + VALUES ($1, 'Notif Test', $2, now() + interval '30 day') + RETURNING id + `, hostID, fmt.Sprintf("notif-%d", time.Now().UnixNano())).Scan(&eventID), + "insert event") + var guestID uuid.UUID + must(t, db.Pool.QueryRow(ctx, ` + INSERT INTO guests (event_id, name, email) + VALUES ($1, 'Notif Guest', $2) + RETURNING id + `, eventID, fmt.Sprintf("notif-%d@example.test", time.Now().UnixNano())).Scan(&guestID), + "insert guest") + return eventID, guestID +} + +func setupNotificationsAPI(t *testing.T, ctx context.Context) (*httptest.Server, *storage.DB, *notification.UnsubscribeSigner) { + t.Helper() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + suppressions := notification.NewSuppressionRepo(db) + notifRepo := notification.NewRepo(db) + const secret = "test-unsubscribe-secret-at-least-32-bytes-long" + signer := notification.NewUnsubscribeSigner(secret) + + apiSrv, err := api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + NotificationRepo: notifRepo, + SuppressionRepo: suppressions, + UnsubscribeSigner: signer, + }) + must(t, err, "build api server") + srv := httptest.NewServer(apiSrv.Handler()) + t.Cleanup(srv.Close) + return srv, db, signer +} + +// TestUnsubscribeFlow exercises the signed-link end-to-end: preview surfaces +// the email, confirm writes the suppression row, and a tampered token is +// rejected. +func TestUnsubscribeFlow(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, signer := setupNotificationsAPI(t, ctx) + + email := "mira@example.test" + token := signer.Sign(email) + + // Preview returns the bound email. + var preview struct{ Email string } + getJSONAuthed(t, srv.URL+"/unsubscribe/"+token, "", http.StatusOK, &preview) + if preview.Email != email { + t.Fatalf("preview email: got %q want %q", preview.Email, email) + } + + // Confirm writes the row. + assertStatus(t, http.MethodPost, srv.URL+"/unsubscribe/"+token, "", nil, http.StatusOK) + + yep, err := notification.NewSuppressionRepo(db).IsSuppressed(ctx, email) + must(t, err, "check suppression") + if !yep { + t.Fatalf("expected email %s suppressed", email) + } + + // Tampered token is rejected. + tampered := token[:len(token)-2] + "xx" + assertStatus(t, http.MethodGet, srv.URL+"/unsubscribe/"+tampered, "", nil, http.StatusBadRequest) +} + +// TestSESBounceWebhook walks the inbound bounce → suppression chain. We +// build a notification row first (so MarkBounce has something to update), +// then post a Bounce envelope, then verify both the status flip and the +// suppression entry. +func TestSESBounceWebhook(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + srv, db, _ := setupNotificationsAPI(t, ctx) + + // Insert a fake guest + notification with a known provider_message_id. + hostID := insertHost(t, ctx, db.Pool) + var eventID, guestID = mustInsertEventAndGuest(t, ctx, db, hostID) + const msgID = "ses-fake-message-id-1234" + notifRepo := notification.NewRepo(db) + _, err := notifRepo.Record(ctx, notification.RecordParams{ + GuestID: guestID, + Channel: notification.ChannelEmail, + Type: notification.TypeInvitation, + Status: notification.StatusSent, + ProviderMessageID: msgID, + }) + must(t, err, "seed notification") + _ = eventID + + // SES → SNS envelope: outer "Notification" carries inner JSON as a string. + innerJSON, _ := json.Marshal(map[string]any{ + "notificationType": "Bounce", + "mail": map[string]any{"messageId": msgID}, + "bounce": map[string]any{ + "bounceType": "Permanent", + "bouncedRecipients": []map[string]any{ + {"emailAddress": "bouncer@example.test"}, + }, + }, + }) + envelope, _ := json.Marshal(map[string]any{ + "Type": "Notification", + "Message": string(innerJSON), + }) + + req, _ := http.NewRequest(http.MethodPost, srv.URL+"/webhooks/ses/notifications", + strings.NewReader(string(envelope))) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + must(t, err, "post ses webhook") + resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected 204, got %d", resp.StatusCode) + } + + // Notification row marked as bounced/permanent. + var status, bounceType string + must(t, db.Pool.QueryRow(ctx, + "SELECT status, bounce_type FROM notifications WHERE provider_message_id = $1", + msgID, + ).Scan(&status, &bounceType), "fetch notification") + if status != "bounced" || bounceType != "permanent" { + t.Fatalf("bad row: status=%s bounce_type=%s", status, bounceType) + } + + // Suppression row populated. + yep, err := notification.NewSuppressionRepo(db).IsSuppressed(ctx, "bouncer@example.test") + must(t, err, "check suppression") + if !yep { + t.Fatal("expected bouncer email suppressed") + } +} + +// TestTwilioStatusWebhook flips a row's status to delivered. +func TestTwilioStatusWebhook(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + srv, db, _ := setupNotificationsAPI(t, ctx) + + hostID := insertHost(t, ctx, db.Pool) + _, guestID := mustInsertEventAndGuest(t, ctx, db, hostID) + const sid = "SMfake0123456789" + _, err := notification.NewRepo(db).Record(ctx, notification.RecordParams{ + GuestID: guestID, + Channel: notification.ChannelSMS, + Type: notification.TypeInvitation, + Status: notification.StatusSent, + ProviderMessageID: sid, + }) + must(t, err, "seed notification") + + form := url.Values{} + form.Set("MessageSid", sid) + form.Set("MessageStatus", "delivered") + req, _ := http.NewRequest(http.MethodPost, srv.URL+"/webhooks/twilio/status", + strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := http.DefaultClient.Do(req) + must(t, err, "post twilio webhook") + resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("expected 204, got %d", resp.StatusCode) + } + + var status string + must(t, db.Pool.QueryRow(ctx, + "SELECT status FROM notifications WHERE provider_message_id = $1", sid, + ).Scan(&status), "fetch status") + if status != "delivered" { + t.Fatalf("expected delivered, got %s", status) + } +} diff --git a/test/integration/privacy_test.go b/test/integration/privacy_test.go new file mode 100644 index 0000000..9cf39cd --- /dev/null +++ b/test/integration/privacy_test.go @@ -0,0 +1,213 @@ +//go:build integration + +package integration_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/google/uuid" +) + +// TestDataExport confirms GET /me/data-export returns a JSON payload +// containing the user + their events + nested records, and rejects +// unauthenticated callers. +func TestDataExport(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, _, hostID, token := setupAuthedAPI(t, ctx) + eventID := createEvent(t, srv.URL, token, "Export Event", "export-event") + guestID, _ := createGuestWithEmail(t, srv.URL, token, eventID, "Export Guest") + _ = issueToken(t, srv.URL, token, eventID, guestID) + + // Unauthenticated → 401. + resp, err := http.Get(srv.URL + "/me/data-export") + must(t, err, "GET unauthed") + resp.Body.Close() + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("unauthed export should 401, got %d", resp.StatusCode) + } + + // Authed → JSON dump with the expected shape. + req, _ := http.NewRequest(http.MethodGet, srv.URL+"/me/data-export", nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err = http.DefaultClient.Do(req) + must(t, err, "GET authed") + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("authed export status=%d body=%s", resp.StatusCode, body) + } + if cd := resp.Header.Get("Content-Disposition"); cd == "" { + t.Errorf("missing Content-Disposition header — browser won't offer download") + } + body, _ := io.ReadAll(resp.Body) + + var out struct { + Format string `json:"format"` + User struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + } `json:"user"` + Events []struct { + ID uuid.UUID `json:"id"` + } `json:"events"` + Guests []struct { + ID uuid.UUID `json:"id"` + } `json:"guests"` + Tokens []struct { + ID uuid.UUID `json:"id"` + } `json:"tokens"` + } + must(t, json.Unmarshal(body, &out), "decode export") + + if out.Format != "guestguard.v1" { + t.Errorf("format: got %q want guestguard.v1", out.Format) + } + if out.User.ID.String() != uuid.UUID(hostID).String() { + t.Errorf("user id mismatch: got %s want %s", out.User.ID, uuid.UUID(hostID)) + } + if len(out.Events) != 1 || out.Events[0].ID != eventID { + t.Errorf("events: got %d entries, want 1 for the seeded event", len(out.Events)) + } + if len(out.Guests) != 1 || out.Guests[0].ID != guestID { + t.Errorf("guests: got %d entries, want 1", len(out.Guests)) + } + if len(out.Tokens) != 1 { + t.Errorf("tokens: got %d entries, want 1 (issued one above)", len(out.Tokens)) + } +} + +// TestDeleteMe walks the soft-delete flow: account row gets tombstoned, +// subsequent /me requests fail because the user is no longer findable, +// and re-signup with the same email succeeds (proves the unique index +// only constrains live users). +func TestDeleteMe(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, hostID, token := setupAuthedAPI(t, ctx) + // Capture the original email so we can prove the tombstone scrubbed it. + var originalEmail string + must(t, db.Pool.QueryRow(ctx, `SELECT email FROM users WHERE id = $1`, uuid.UUID(hostID)).Scan(&originalEmail), + "fetch original email") + + // Hit DELETE /me. + req, _ := http.NewRequest(http.MethodDelete, srv.URL+"/me", nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + must(t, err, "DELETE /me") + resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("DELETE /me status=%d want 204", resp.StatusCode) + } + + // DB: row is soft-deleted with PII scrubbed. + var deletedAt *time.Time + var emailAfter, nameAfter string + must(t, db.Pool.QueryRow(ctx, + `SELECT deleted_at, email, name FROM users WHERE id = $1`, uuid.UUID(hostID), + ).Scan(&deletedAt, &emailAfter, &nameAfter), "fetch after delete") + if deletedAt == nil { + t.Fatal("expected deleted_at set") + } + if emailAfter == originalEmail { + t.Errorf("expected email scrubbed, got %q", emailAfter) + } + if nameAfter != "Deleted user" { + t.Errorf("expected name='Deleted user', got %q", nameAfter) + } + + // API: subsequent /me with the same JWT returns 401 (user not found + // by GetByID since it filters on deleted_at). + req, _ = http.NewRequest(http.MethodGet, srv.URL+"/me", nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err = http.DefaultClient.Do(req) + must(t, err, "GET /me after delete") + resp.Body.Close() + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("post-delete /me should 401, got %d", resp.StatusCode) + } + + // Re-signup with the ORIGINAL email succeeds because the unique + // index is partial — soft-deleted rows don't block new accounts. + body := fmt.Sprintf(`{"email":%q,"name":"New owner","password":"correct-horse","accept_terms":true}`, originalEmail) + resp, err = http.Post(srv.URL+"/auth/signup", "application/json", stringReader(body)) + must(t, err, "POST /auth/signup after delete") + resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("re-signup status=%d (expected 201 — soft-deleted shouldn't block new signups)", resp.StatusCode) + } +} + +// TestAcceptTerms confirms a user created without terms acceptance can +// record it post-hoc via POST /me/accept-terms. +func TestAcceptTerms(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + srv, db, hostID, token := setupAuthedAPI(t, ctx) + + // insertHost (in e2e_test.go) doesn't set the terms columns, so a + // fresh host starts with NULLs. Confirm. + var acceptedAt *time.Time + must(t, db.Pool.QueryRow(ctx, + `SELECT terms_accepted_at FROM users WHERE id = $1`, uuid.UUID(hostID), + ).Scan(&acceptedAt), "fetch pre") + if acceptedAt != nil { + t.Fatal("setup: expected fresh host to have no terms_accepted_at") + } + + // Hit the endpoint. + req, _ := http.NewRequest(http.MethodPost, srv.URL+"/me/accept-terms", nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + must(t, err, "POST /me/accept-terms") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("accept-terms status=%d", resp.StatusCode) + } + + // DB: timestamps now set on both columns. + var terms, privacy *time.Time + must(t, db.Pool.QueryRow(ctx, + `SELECT terms_accepted_at, privacy_policy_accepted_at FROM users WHERE id = $1`, uuid.UUID(hostID), + ).Scan(&terms, &privacy), "fetch post") + if terms == nil || privacy == nil { + t.Errorf("expected both timestamps set: terms=%v privacy=%v", terms, privacy) + } +} + +func stringReader(s string) *stringReadCloser { + return &stringReadCloser{s: s} +} + +type stringReadCloser struct { + s string + pos int +} + +func (r *stringReadCloser) Read(p []byte) (int, error) { + if r.pos >= len(r.s) { + return 0, io.EOF + } + n := copy(p, r.s[r.pos:]) + r.pos += n + return n, nil +} +func (r *stringReadCloser) Close() error { return nil } diff --git a/test/integration/ratelimit_test.go b/test/integration/ratelimit_test.go new file mode 100644 index 0000000..7ec8a6d --- /dev/null +++ b/test/integration/ratelimit_test.go @@ -0,0 +1,205 @@ +//go:build integration + +package integration_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + + "github.com/alchemistkay/guestguard/internal/api" + "github.com/alchemistkay/guestguard/internal/storage" +) + +// TestRateLimitSignup confirms the per-IP signup limit kicks in at the +// configured threshold and includes a Retry-After header. +func TestRateLimitSignup(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + mr, err := miniredis.Run() + must(t, err, "miniredis") + t.Cleanup(mr.Close) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { _ = rdb.Close() }) + + srv := httptest.NewServer(must1(api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + Redis: rdb, + }))(t).Handler()) + t.Cleanup(srv.Close) + + // First 5 signups from the same IP succeed (httptest reuses 127.0.0.1). + for i := 0; i < 5; i++ { + body := map[string]any{ + "email": uniqueEmail(t), + "name": "Probe", + "password": "correct-horse", + } + postJSONAuthed(t, srv.URL+"/auth/signup", "", body, http.StatusCreated, nil) + } + + // 6th in the same window must be 429. + req := buildJSON(t, http.MethodPost, srv.URL+"/auth/signup", "", + map[string]any{"email": uniqueEmail(t), "name": "Probe", "password": "correct-horse"}) + resp, err := http.DefaultClient.Do(req) + must(t, err, "do 6th signup") + defer resp.Body.Close() + if resp.StatusCode != http.StatusTooManyRequests { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 429, got %d body=%s", resp.StatusCode, body) + } + if resp.Header.Get("Retry-After") == "" { + t.Fatal("missing Retry-After header on 429") + } +} + +// TestLoginLockout confirms 5 consecutive bad-password attempts trip the +// account-lock flag, login then 403s with a "locked" message, and only a +// password reset clears it. +func TestLoginLockout(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in -short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + t.Cleanup(cancel) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + dsn := startPostgres(t, ctx) + + db, err := storage.NewDB(ctx, dsn) + must(t, err, "connect db") + t.Cleanup(db.Close) + must(t, db.Migrate(ctx), "migrate") + + mr, err := miniredis.Run() + must(t, err, "miniredis") + t.Cleanup(mr.Close) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { _ = rdb.Close() }) + + emails := &recordingEmailSender{} + srv := httptest.NewServer(must1(api.NewServer(api.ServerDeps{ + Logger: logger, + DB: db, + TokenTTL: 24 * time.Hour, + JWTSecret: testJWTSecret, + JWTIssuer: testJWTIssuer, + AccessTokenTTL: 5 * time.Minute, + RefreshTokenTTL: 24 * time.Hour, + EmailVerificationTTL: 1 * time.Hour, + PasswordResetTTL: 1 * time.Hour, + PublicBaseURL: "http://localhost", + EmailSender: emails, + Redis: rdb, + }))(t).Handler()) + t.Cleanup(srv.Close) + + email := "lockout-" + uuid.NewString() + "@guestguard.test" + + // Sign up and verify so we have a working account. + postJSONAuthed(t, srv.URL+"/auth/signup", "", + map[string]any{"email": email, "name": "Lock Probe", "password": "correct-horse"}, + http.StatusCreated, nil) + token := tokenFromQuery(t, emails.verifyLink, "token") + postJSONAuthed(t, srv.URL+"/auth/verify-email", "", + map[string]any{"token": token}, http.StatusOK, nil) + + // 5 bad-password attempts — the lockout middleware engages at the 5th. + for i := 0; i < 5; i++ { + req := buildJSON(t, http.MethodPost, srv.URL+"/auth/login", "", + map[string]any{"email": email, "password": "wrong"}) + resp, err := http.DefaultClient.Do(req) + must(t, err, "do login") + resp.Body.Close() + } + + // Even correct password is now rejected with 403 "locked". + { + req := buildJSON(t, http.MethodPost, srv.URL+"/auth/login", "", + map[string]any{"email": email, "password": "correct-horse"}) + resp, err := http.DefaultClient.Do(req) + must(t, err, "do correct login") + defer resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 403 locked, got %d body=%s", resp.StatusCode, body) + } + } + + // Reset the password — should clear the lock. + emails.resetLink = "" + postJSONAuthed(t, srv.URL+"/auth/forgot-password", "", + map[string]any{"email": email}, http.StatusAccepted, nil) + if emails.resetLink == "" { + t.Fatal("reset link not captured") + } + resetToken := tokenFromPath(t, emails.resetLink, "/reset-password/") + postJSONAuthed(t, srv.URL+"/auth/reset-password", "", + map[string]any{"token": resetToken, "new_password": "new-correct-horse"}, + http.StatusOK, nil) + + // Now login succeeds with the new password. + postJSONAuthed(t, srv.URL+"/auth/login", "", + map[string]any{"email": email, "password": "new-correct-horse"}, + http.StatusOK, nil) +} + +func buildJSON(t *testing.T, method, url, bearer string, body any) *http.Request { + t.Helper() + var r io.Reader + if body != nil { + b, err := json.Marshal(body) + must(t, err, "marshal body") + r = bytes.NewReader(b) + } + req, err := http.NewRequest(method, url, r) + must(t, err, "build req") + if r != nil { + req.Header.Set("Content-Type", "application/json") + } + if bearer != "" { + req.Header.Set("Authorization", "Bearer "+bearer) + } + return req +} + +func must1[T any](v T, err error) func(*testing.T) T { + return func(t *testing.T) T { + t.Helper() + must(t, err, "api server") + return v + } +}