feat: ship Tier 1 — auth, authz, rate limits, real notifications, CSV import, billing, backups/DR, privacy

Closes every block in docs/TIER1_PLAN.md from the Claude-scope side. The
homelab / cloud setup steps (SES verification, restore drill, lawyer-
drafted ToS) remain operator-owned but are unblocked.

Block A — Authentication
- Migration 0003: password_hash, email_verified, email_verification_tokens,
  password_reset_tokens, refresh_tokens (with replaced_by family chain).
- Bcrypt hasher, HS256 JWT signer, single-use refresh tokens with rotation
  + replay-detection (revokes the family on reuse).
- /auth/signup, /login, /refresh, /logout, /verify-email,
  /forgot-password, /reset-password — enumeration-safe.
- requireAuth middleware + GET /me.
- Frontend useAuth/useApi with auto-refresh-on-401, login/signup/verify/
  forgot/reset pages, route-guard middleware.

Block B — Authorisation
- EventRepo.GetForHost; Update/Delete scoped by host_id.
- All host routes behind requireAuth + ownership; cross-tenant returns
  404 (no enumeration). ?host_id removed.
- WS auth via short-lived single-use tickets (POST /auth/ws-ticket).
- Tests: TestCrossTenantIsolation — 9 probes.

Block C — Rate limiting
- Redis sliding-window via Lua (atomic ZADD+ZCARD+PEXPIRE).
- Per-route limits matching the plan (signup IP, login IP+email, RSVP/
  access by token, events/guests/tokens by user_id).
- 429 with Retry-After header and JSON body.
- Auth lockout: 5 failed logins → account locked, only password reset
  clears it.
- Frontend: useErrMessage normalises 429 + locked messaging.

Block D — Real notifications
- Migration 0004: provider_message_id, bounce_type, complained columns
  + unsubscribes (CITEXT) suppression table.
- Branded HTML + plaintext templates for verification, reset, invitation,
  confirmation, reminder. Per-page templates avoid html/template's
  contextual-escape collisions.
- Senders: SESv2, Twilio (SMS), SMTP (Mailpit-friendly), Resend HTTP.
- PickEmailSender priority Resend > SMTP > SES > Log — system boots
  cleanly in dev with Mailpit; production flips one env var.
- Webhook endpoints (Twilio status + SES SNS) — bounces add to suppression;
  signature verification stubbed pending creds.
- Auto-send: POST /tokens publishes invitation.send; notifier renders +
  delivers via the configured backend; suppression list honoured.
- Bulk + per-row invitation flow: POST /events/{id}/guests/invitations/bulk
  returns per-guest tokens so phone-only guests can be SMS'd manually.
- Unsubscribe: signed HMAC token (no TTL) + /unsubscribe/[token] page.
- WhatsApp Option A+: wa.me click-to-chat wizard with per-guest progress
  tracking, isLikelyE164 validation, edit-from-wizard.
- Token rotate (POST /tokens/rotate) invalidates the old URL — used by
  the regenerate-link flow.
- Mailpit added to docker-compose for dev inbox.

Block E — CSV import
- Streaming parser: tolerant header detection, UTF-8 BOM + UTF-16 LE/BE
  decoding, row-level validation, 5,000-row cap.
- Strict E.164 phone validation with helpful error message.
- POST /preview + /import + GET /template; preview UI on event page;
  atomic per-batch with dedup on existing emails.

Phone capture across UI
- PhoneInput component: country picker (~50 ISO codes) + national input +
  live E.164 preview + inline length validation.
- Used in Add Guest and Edit Guest modals. Smart paste-handling extracts
  country code from full E.164 strings.

Block F — Billing (Stripe)
- Migration 0005: subscriptions table (user_id → tier/status/period_end +
  Stripe customer/sub ids). Partial unique index keeps one granting sub
  per user.
- internal/billing: Tier + Limits model (Free 1/50, Pro 10/1000, Business
  ∞/5000), Stripe SDK wrapper with IgnoreAPIVersionMismatch for newer
  account API versions.
- /billing/checkout-session, /billing/portal, /billing/status,
  /webhooks/stripe (signature-verified, lifecycle events).
- Tier enforcement: 402 on POST /events, /guests, /import with
  {error, reason, tier, used, limit, upgrade_url} body.
- Frontend: useBilling composable, /dashboard/billing page (current plan,
  usage bars, tier cards), global UpgradeModal triggered by useApi's
  402 interceptor.
- Customer portal kept for self-service cancel/payment-method changes.

Block G — Backups & DR (application side)
- Every migration has a tested .down.sql.
- TestMigrationRoundtrip applies all ups → all downs → all ups against a
  fresh container; catches asymmetric down migrations.
- cmd/restore-verify: 28-check post-restore invariant tool (schema
  presence, no orphans across 10 FK relationships, email uniqueness,
  single-active subscription, row-count snapshot).
- docs/RUNBOOK_RESTORE.md: 9-step restore procedure with RTO/RPO
  targets, drill instructions, rollback path.

Block H — Privacy compliance (application side)
- Migration 0006: deleted_at + terms_accepted_at + privacy_policy_accepted_at
  on users. Partial index on email for live-only uniqueness.
- GET /me/data-export — synchronous JSON dump (user, events, guests,
  tokens, rsvps, access_logs, notifications).
- DELETE /me — soft-delete with PII scrub + refresh-token revocation;
  re-signup with same email works.
- POST /me/accept-terms — idempotent consent recording.
- Frontend /privacy + /terms placeholder pages with substantive (pending
  legal review) copy; footer links; signup terms checkbox; TermsGateModal
  for accounts created before the rollout; export + delete buttons on
  /dashboard/billing.

Tests
- All migrations verified up/down/up.
- Integration suite: TestE2EHappyPath, TestAuthFlow, TestCrossTenantIsolation,
  TestRateLimitSignup, TestLoginLockout, TestUnsubscribeFlow,
  TestSESBounceWebhook, TestTwilioStatusWebhook, TestCsvImportFlow,
  TestCsvImportAtomicRollback, TestBulkIssueInvitations, TestBulkIssueExplicitSubset,
  TestTokenIssuePublishesInvitation, TestTokenIssueWithoutGuestEmailSkipsInvitation,
  TestGuestUpdate, TestGuestDelete, TestTokenRotate, TestSMTPSenderAgainstMailpit,
  TestFreeTierEventLimit, TestFreeTierGuestLimit, TestBusinessTierBypassesLimits,
  TestDataExport, TestDeleteMe, TestAcceptTerms, TestMigrationRoundtrip.
  Full suite runs in ~120s against real Postgres + NATS + Redis + Mailpit.
- Unit suite green across internal/auth, internal/csvimport,
  internal/notification, internal/ratelimit, internal/domain.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Kwaku Danso
2026-05-16 23:54:22 +01:00
parent a0ed34f860
commit 59b8781659
124 changed files with 13702 additions and 445 deletions
+3
View File
@@ -11,6 +11,9 @@ coverage.*
.env
.env.local
# Agent / per-developer config (launch.json with absolute paths, worktree state).
.claude/
.DS_Store
.idea/
.vscode/
+95 -4
View File
@@ -11,10 +11,15 @@ import (
"syscall"
"time"
"github.com/redis/go-redis/v9"
"github.com/alchemistkay/guestguard/internal/api"
"github.com/alchemistkay/guestguard/internal/auth"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/config"
"github.com/alchemistkay/guestguard/internal/fraud"
"github.com/alchemistkay/guestguard/internal/natspub"
"github.com/alchemistkay/guestguard/internal/notification"
"github.com/alchemistkay/guestguard/internal/storage"
)
@@ -56,6 +61,16 @@ func run() error {
}
defer natsClient.Close()
logger.Info("connecting to redis", "addr", cfg.RedisAddr)
rdb := redis.NewClient(&redis.Options{Addr: cfg.RedisAddr})
if err := rdb.Ping(rootCtx).Err(); err != nil {
logger.Warn("redis ping failed — rate limits + lockout disabled", "err", err)
_ = rdb.Close()
rdb = nil
} else {
defer rdb.Close()
}
logger.Info("dialing fraud engine", "addr", cfg.FraudGRPCAddr)
fraudClient, err := fraud.Dial(rootCtx, cfg.FraudGRPCAddr, cfg.FraudGRPCTimeout, logger)
if err != nil {
@@ -118,17 +133,93 @@ func run() error {
}
defer rsvpConsumeCtx.Stop()
srv := &http.Server{
Addr: cfg.HTTPAddr,
Handler: api.NewServer(api.ServerDeps{
// Notification senders. If SES creds are configured, route auth +
// guest emails through SES. Otherwise the log stub keeps the dev flow
// (verification link in API logs) intact.
tpls, err := notification.NewTemplates()
if err != nil {
return err
}
suppressions := notification.NewSuppressionRepo(db)
notifRepo := notification.NewRepo(db)
unsubSigner := notification.NewUnsubscribeSigner(cfg.UnsubscribeSecret)
emailSenderCombined, backend, err := notification.PickEmailSender(rootCtx, notification.EmailSenderConfig{
Resend: notification.ResendConfig{
APIKey: cfg.ResendAPIKey,
FromEmail: cfg.ResendFromEmail,
FromName: cfg.ResendFromName,
},
SMTP: notification.SMTPConfig{
Host: cfg.SMTPHost,
Port: cfg.SMTPPort,
Username: cfg.SMTPUsername,
Password: cfg.SMTPPassword,
FromEmail: cfg.SMTPFromEmail,
FromName: cfg.SMTPFromName,
TLS: cfg.SMTPTLS,
},
SES: notification.SESConfig{
Region: cfg.SESRegion,
FromEmail: cfg.SESFromEmail,
FromName: cfg.SESFromName,
ConfigurationSet: cfg.SESConfigurationSet,
PublicBaseURL: cfg.PublicBaseURL,
},
}, tpls, logger)
if err != nil {
return err
}
logger.Info("email backend selected", "backend", backend)
var emailSender auth.EmailSender = emailSenderCombined
stripeClient, err := billing.NewClient(billing.Config{
SecretKey: cfg.StripeSecretKey,
WebhookSecret: cfg.StripeWebhookSecret,
PriceProMonthly: cfg.StripePricePro,
PriceBusiness: cfg.StripePriceBusiness,
})
if err != nil {
return err
}
if stripeClient != nil && stripeClient.Enabled() {
logger.Info("billing enabled via stripe")
} else {
logger.Info("billing disabled — free tier limits apply to all users")
}
apiSrv, err := api.NewServer(api.ServerDeps{
Logger: logger,
DB: db,
Hub: hub,
AccessPublisher: natsClient,
RSVPPublisher: natsClient,
InvitationPublisher: natsClient,
FraudScorer: fraudClient,
TokenTTL: cfg.TokenTTL,
}).Handler(),
JWTSecret: cfg.JWTSecret,
JWTIssuer: cfg.JWTIssuer,
AccessTokenTTL: cfg.AccessTokenTTL,
RefreshTokenTTL: cfg.RefreshTokenTTL,
EmailVerificationTTL: cfg.EmailVerificationTTL,
PasswordResetTTL: cfg.PasswordResetTTL,
PublicBaseURL: cfg.PublicBaseURL,
RefreshCookieDomain: cfg.RefreshCookieDomain,
RefreshCookieSecure: cfg.RefreshCookieSecure,
Redis: rdb,
EmailSender: emailSender,
NotificationRepo: notifRepo,
SuppressionRepo: suppressions,
UnsubscribeSigner: unsubSigner,
StripeClient: stripeClient,
})
if err != nil {
return err
}
srv := &http.Server{
Addr: cfg.HTTPAddr,
Handler: apiSrv.Handler(),
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // 0 lets WS connections live; per-request handlers still bound by their own ctx
+141 -1
View File
@@ -48,7 +48,61 @@ func run() error {
defer natsClient.Close()
repo := notification.NewRepo(db)
sender := notification.LogSender{}
suppressions := notification.NewSuppressionRepo(db)
tpls, err := notification.NewTemplates()
if err != nil {
return err
}
// Email dispatcher: Resend > SMTP > SES > log stub, same picker as cmd/api.
combinedEmail, backend, err := notification.PickEmailSender(rootCtx, notification.EmailSenderConfig{
Resend: notification.ResendConfig{
APIKey: cfg.ResendAPIKey,
FromEmail: cfg.ResendFromEmail,
FromName: cfg.ResendFromName,
},
SMTP: notification.SMTPConfig{
Host: cfg.SMTPHost,
Port: cfg.SMTPPort,
Username: cfg.SMTPUsername,
Password: cfg.SMTPPassword,
FromEmail: cfg.SMTPFromEmail,
FromName: cfg.SMTPFromName,
TLS: cfg.SMTPTLS,
},
SES: notification.SESConfig{
Region: cfg.SESRegion,
FromEmail: cfg.SESFromEmail,
FromName: cfg.SESFromName,
ConfigurationSet: cfg.SESConfigurationSet,
PublicBaseURL: cfg.PublicBaseURL,
},
}, tpls, logger)
if err != nil {
return err
}
logger.Info("email backend selected", "backend", backend)
emailSender := notification.NewEmailSender(combinedEmail, suppressions)
// SMS: Twilio when creds are set, otherwise no-op log sender.
var smsSender notification.Sender
if cfg.TwilioAccountSID != "" && cfg.TwilioAuthToken != "" && cfg.TwilioFromNumber != "" {
t, err := notification.NewTwilioSender(notification.TwilioConfig{
AccountSID: cfg.TwilioAccountSID,
AuthToken: cfg.TwilioAuthToken,
FromNumber: cfg.TwilioFromNumber,
})
if err != nil {
return err
}
smsSender = t
logger.Info("twilio sms sender configured", "from", cfg.TwilioFromNumber)
} else {
smsSender = notification.LogSender{}
logger.Info("twilio not configured — SMS will use the log stub")
}
sender := notification.NewRouter(emailSender, smsSender)
rsvpSub, err := natspub.NewRSVPConfirmedSubscriber(
rootCtx, natsClient, "notifier-rsvp-confirmed",
@@ -82,6 +136,22 @@ func run() error {
}
defer fraudCC.Stop()
invitationSub, err := natspub.NewInvitationSendSubscriber(
rootCtx, natsClient, "notifier-invitation-send",
func(ctx context.Context, evt natspub.InvitationSend) error {
return handleInvitationSend(ctx, logger, repo, sender, evt)
},
logger,
)
if err != nil {
return err
}
invitationCC, err := invitationSub.Start(rootCtx)
if err != nil {
return err
}
defer invitationCC.Stop()
logger.Info("notifier started")
<-rootCtx.Done()
logger.Info("notifier shutting down")
@@ -201,6 +271,76 @@ func handleFraudScored(
return nil
}
// handleInvitationSend renders the invitation template + sends through
// the configured sender (Resend in prod, Mailpit in dev), then writes a
// notification row so the host can audit the delivery history.
func handleInvitationSend(
ctx context.Context,
logger *slog.Logger,
repo *notification.Repo,
sender notification.Sender,
evt natspub.InvitationSend,
) error {
if evt.GuestEmail == "" {
// Nothing to do — host-managed delivery.
return nil
}
eventDate := ""
if !evt.EventDate.IsZero() {
eventDate = evt.EventDate.Format("Mon, 02 Jan 2006 · 15:04")
}
msg := notification.OutboundMessage{
GuestID: evt.GuestID,
Channel: notification.ChannelEmail,
Type: notification.TypeInvitation,
Subject: "You're invited — " + evt.EventName,
Metadata: map[string]any{
"to": evt.GuestEmail,
"GuestName": evt.GuestName,
"HostName": evt.HostName,
"EventName": evt.EventName,
"Venue": evt.Venue,
"EventDate": eventDate,
"Link": evt.Link,
},
}
providerID, sendErr := sender.Send(ctx, msg)
status := notification.StatusSent
errStr := ""
if sendErr != nil {
status = notification.StatusFailed
errStr = sendErr.Error()
}
id, err := repo.Record(ctx, notification.RecordParams{
GuestID: evt.GuestID,
Channel: msg.Channel,
Type: msg.Type,
Status: status,
ProviderMessageID: providerID,
Error: errStr,
})
if err != nil {
return err
}
logger.Info("invitation dispatched",
"notification_id", id,
"guest_id", evt.GuestID,
"event_id", evt.EventID,
"to", evt.GuestEmail,
"status", status,
"provider_message_id", providerID,
)
if sendErr != nil {
return sendErr
}
return nil
}
func levelFor(env string) slog.Level {
if env == "development" {
return slog.LevelDebug
+267
View File
@@ -0,0 +1,267 @@
// restore-verify is a post-restore sanity tool. Point it at a freshly
// restored Postgres instance and it asserts that the schema and data
// are coherent — no orphan rows, expected uniqueness invariants hold,
// every table is present. Exits non-zero on any failure so it slots
// into a "restore drill" runbook step.
//
// Usage:
// GG_DATABASE_URL=postgres://... ./restore-verify [--verbose]
//
// The intent is "would I bet my Sunday on this restore being usable?".
// Failing fast here keeps a bad restore from being promoted to traffic.
package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
func main() {
verbose := flag.Bool("verbose", false, "print every check's result, not just failures")
flag.Parse()
dsn := os.Getenv("GG_DATABASE_URL")
if dsn == "" {
fmt.Fprintln(os.Stderr, "restore-verify: GG_DATABASE_URL is required")
os.Exit(2)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
fmt.Fprintf(os.Stderr, "restore-verify: connect: %v\n", err)
os.Exit(2)
}
defer pool.Close()
if err := pool.Ping(ctx); err != nil {
fmt.Fprintf(os.Stderr, "restore-verify: ping: %v\n", err)
os.Exit(2)
}
fmt.Println("restore-verify: checking", maskDSN(dsn))
fmt.Println()
checks := allChecks()
var failed []string
for _, c := range checks {
result, err := c.fn(ctx, pool)
if err != nil {
failed = append(failed, c.name)
fmt.Printf(" ✗ %-50s FAIL: %v\n", c.name, err)
continue
}
if *verbose {
fmt.Printf(" ✓ %-50s %s\n", c.name, result)
}
}
fmt.Println()
if len(failed) > 0 {
fmt.Printf("FAILED: %d check%s — %s\n",
len(failed), pluralS(len(failed)), strings.Join(failed, ", "))
os.Exit(1)
}
fmt.Printf("OK: all %d checks passed\n", len(checks))
}
type check struct {
name string
fn func(ctx context.Context, pool *pgxpool.Pool) (string, error)
}
func allChecks() []check {
return []check{
// --- schema presence ---
tableExists("users"),
tableExists("events"),
tableExists("guests"),
tableExists("tokens"),
tableExists("rsvps"),
tableExists("access_logs"),
tableExists("notifications"),
tableExists("schema_migrations"),
tableExists("email_verification_tokens"),
tableExists("password_reset_tokens"),
tableExists("refresh_tokens"),
tableExists("unsubscribes"),
tableExists("subscriptions"),
// --- migrations applied ---
{
name: "schema_migrations: 5+ migrations applied",
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
var n int
if err := pool.QueryRow(ctx,
`SELECT count(*) FROM schema_migrations`).Scan(&n); err != nil {
return "", err
}
if n < 5 {
return "", fmt.Errorf("only %d migrations recorded — incomplete restore", n)
}
return fmt.Sprintf("%d migrations", n), nil
},
},
// --- referential integrity (FK constraints catch most of this,
// but a bad logical dump or partial restore can sneak rows in) ---
noOrphans("events", "host_id", "users", "id"),
noOrphans("guests", "event_id", "events", "id"),
noOrphans("tokens", "guest_id", "guests", "id"),
noOrphans("rsvps", "guest_id", "guests", "id"),
noOrphans("access_logs", "guest_id", "guests", "id"),
noOrphans("notifications", "guest_id", "guests", "id"),
noOrphans("email_verification_tokens", "user_id", "users", "id"),
noOrphans("password_reset_tokens", "user_id", "users", "id"),
noOrphans("refresh_tokens", "user_id", "users", "id"),
noOrphans("subscriptions", "user_id", "users", "id"),
// --- domain invariants ---
{
name: "users.email is unique (case-insensitive)",
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
var dupes int
err := pool.QueryRow(ctx, `
SELECT count(*) FROM (
SELECT lower(email) FROM users GROUP BY lower(email) HAVING count(*) > 1
) t
`).Scan(&dupes)
if err != nil {
return "", err
}
if dupes > 0 {
return "", fmt.Errorf("%d duplicate email(s) found", dupes)
}
return "no duplicate emails", nil
},
},
{
name: "guests with rsvp_response have an existing rsvp row",
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
var n int
err := pool.QueryRow(ctx, `
SELECT count(*) FROM rsvps r
WHERE NOT EXISTS (SELECT 1 FROM guests g WHERE g.id = r.guest_id)
`).Scan(&n)
if err != nil {
return "", err
}
if n > 0 {
return "", fmt.Errorf("%d rsvp(s) reference a missing guest", n)
}
return "0 orphan rsvps", nil
},
},
{
name: "at most one granting subscription per user",
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
var n int
err := pool.QueryRow(ctx, `
SELECT count(*) FROM (
SELECT user_id FROM subscriptions
WHERE status IN ('active','past_due','trialing')
GROUP BY user_id HAVING count(*) > 1
) t
`).Scan(&n)
if err != nil {
return "", err
}
if n > 0 {
return "", fmt.Errorf("%d user(s) have multiple granting subscriptions", n)
}
return "all users single-active", nil
},
},
// --- soft constraints worth noticing (not failures, but logged) ---
{
name: "row counts snapshot",
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
parts := []string{}
for _, t := range []string{
"users", "events", "guests", "tokens", "rsvps",
"access_logs", "notifications", "subscriptions",
} {
var n int
if err := pool.QueryRow(ctx,
fmt.Sprintf("SELECT count(*) FROM %s", t)).Scan(&n); err != nil {
return "", err
}
parts = append(parts, fmt.Sprintf("%s=%d", t, n))
}
return strings.Join(parts, " "), nil
},
},
}
}
func tableExists(name string) check {
return check{
name: fmt.Sprintf("table %q exists", name),
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
var exists bool
err := pool.QueryRow(ctx,
`SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema='public' AND table_name=$1)`,
name,
).Scan(&exists)
if err != nil {
return "", err
}
if !exists {
return "", errors.New("missing")
}
return "ok", nil
},
}
}
func noOrphans(childTable, childFK, parentTable, parentPK string) check {
return check{
name: fmt.Sprintf("no orphans: %s.%s -> %s.%s", childTable, childFK, parentTable, parentPK),
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
q := fmt.Sprintf(`
SELECT count(*) FROM %s c
WHERE c.%s IS NOT NULL
AND NOT EXISTS (SELECT 1 FROM %s p WHERE p.%s = c.%s)
`, childTable, childFK, parentTable, parentPK, childFK)
var n int
if err := pool.QueryRow(ctx, q).Scan(&n); err != nil {
return "", err
}
if n > 0 {
return "", fmt.Errorf("%d orphan row(s)", n)
}
return "clean", nil
},
}
}
func maskDSN(dsn string) string {
// Crude: redact password between '://user:' and '@'.
at := strings.LastIndex(dsn, "@")
scheme := strings.Index(dsn, "://")
if at < 0 || scheme < 0 || at <= scheme {
return dsn
}
userInfo := dsn[scheme+3 : at]
if colon := strings.Index(userInfo, ":"); colon >= 0 {
userInfo = userInfo[:colon] + ":****"
}
return dsn[:scheme+3] + userInfo + dsn[at:]
}
func pluralS(n int) string {
if n == 1 {
return ""
}
return "s"
}
+50
View File
@@ -15,6 +15,28 @@ services:
timeout: 3s
retries: 10
mailpit:
image: axllent/mailpit:latest
ports:
- "1025:1025" # SMTP
- "8025:8025" # Web UI
healthcheck:
test: ["CMD", "wget", "-qO-", "http://localhost:8025/api/v1/info"]
interval: 5s
timeout: 3s
retries: 10
redis:
image: redis:7-alpine
command: ["redis-server", "--save", "", "--appendonly", "no"]
ports:
- "6379:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 10
nats:
image: nats:2.10-alpine
command:
@@ -45,7 +67,25 @@ services:
GG_NATS_URL: nats://nats:4222
GG_FRAUD_GRPC_ADDR: fraud-engine:9091
GG_FRAUD_GRPC_TIMEOUT: 250ms
GG_REDIS_ADDR: redis:6379
GG_SMTP_HOST: mailpit
GG_SMTP_PORT: "1025"
GG_SMTP_FROM_EMAIL: noreply@guestguard.local
GG_SMTP_FROM_NAME: GuestGuard (dev)
# Resend overrides SMTP when GG_RESEND_API_KEY is set in .env at the
# project root. Leave unset and Mailpit (above) handles delivery.
GG_RESEND_API_KEY: ${GG_RESEND_API_KEY:-}
GG_RESEND_FROM_EMAIL: ${GG_RESEND_FROM_EMAIL:-}
GG_RESEND_FROM_NAME: ${GG_RESEND_FROM_NAME:-GuestGuard}
GG_TOKEN_SECRET: dev-only-insecure-secret-change-me
GG_JWT_SECRET: dev-only-insecure-jwt-secret-change-me-32+bytes
GG_PUBLIC_BASE_URL: http://localhost:3000
# Stripe billing — empty values leave billing disabled. Set in
# .env at the project root to enable.
GG_STRIPE_SECRET_KEY: ${GG_STRIPE_SECRET_KEY:-}
GG_STRIPE_WEBHOOK_SECRET: ${GG_STRIPE_WEBHOOK_SECRET:-}
GG_STRIPE_PRICE_PRO: ${GG_STRIPE_PRICE_PRO:-}
GG_STRIPE_PRICE_BUSINESS: ${GG_STRIPE_PRICE_BUSINESS:-}
ports:
- "8080:8080"
depends_on:
@@ -53,6 +93,8 @@ services:
condition: service_healthy
nats:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
fraud-engine:
@@ -80,6 +122,14 @@ services:
GG_ENV: development
GG_DATABASE_URL: postgres://guestguard:guestguard@postgres:5432/guestguard?sslmode=disable
GG_NATS_URL: nats://nats:4222
GG_PUBLIC_BASE_URL: http://localhost:3000
GG_SMTP_HOST: mailpit
GG_SMTP_PORT: "1025"
GG_SMTP_FROM_EMAIL: noreply@guestguard.local
GG_SMTP_FROM_NAME: GuestGuard (dev)
GG_RESEND_API_KEY: ${GG_RESEND_API_KEY:-}
GG_RESEND_FROM_EMAIL: ${GG_RESEND_FROM_EMAIL:-}
GG_RESEND_FROM_NAME: ${GG_RESEND_FROM_NAME:-GuestGuard}
depends_on:
postgres:
condition: service_healthy
+251
View File
@@ -0,0 +1,251 @@
# Runbook — Postgres restore
This is the procedure to bring GuestGuard back from a Postgres backup
after data loss. It assumes the infra side of Block G (`pg_basebackup` +
WAL archiving to S3, daily logical dumps, cross-region replication) is
already in place — see the homelab repo for those.
The application side — migration down-scripts, the [`restore-verify`](../cmd/restore-verify/main.go)
tool, and this document — lives here in the GuestGuard repo so it ships
in lockstep with the schema.
---
## Targets
| Metric | Target |
|---|---|
| RTO (recovery time objective) | ≤ 1 hour from "go" decision to traffic-serving |
| RPO (recovery point objective) | ≤ 5 minutes of data loss (WAL ships every 60s, S3 PUT every 5min) |
If RTO is going to slip past 1 hour, escalate per the comms plan in `docs/INCIDENT_RESPONSE.md` (infra repo).
## When to invoke this
- Primary Postgres is unreachable AND the standby has also failed
- Logical corruption discovered (e.g., a bad migration deleted rows)
- Region-wide outage at the primary's location
- A "what if we restored last Tuesday" drill (see [Drill](#drill-procedure))
If only the primary is unreachable and the standby is healthy, promote
the standby (separate runbook). Don't use this procedure unnecessarily —
restores are expensive.
## Prerequisites
Before starting:
- [ ] Decision authority has approved the restore (CTO or on-call lead)
- [ ] Read access to the S3 backup bucket: `s3://guestguard-pg-backups`
- [ ] `psql`, `pg_basebackup`, `wal-g` (or chosen WAL tool) installed
- [ ] Empty target Postgres instance provisioned (Kubernetes Statefulset,
RDS, or homelab box — same major version as the backup)
- [ ] `GG_DATABASE_URL` env var ready for the new instance
- [ ] Maintenance page deployed to the frontend (`/dashboard` returns 503)
- [ ] API + notifier pods scaled to 0 (`kubectl scale --replicas=0`)
- [ ] This document open in another tab
## Steps
### 1. Stop write traffic
```bash
# k8s
kubectl scale deployment/guestguard-api --replicas=0
kubectl scale deployment/guestguard-notifier --replicas=0
# Confirm no connections to the (broken) primary
kubectl exec -n postgres guestguard-pg-0 -- psql -U postgres -c \
"SELECT count(*) FROM pg_stat_activity WHERE datname='guestguard'"
```
If using docker-compose locally: `docker compose stop api notifier`.
### 2. Identify the recovery point
Pick the latest backup that's known-good. For corruption scenarios,
this may mean going further back than the most recent dump.
```bash
# List base backups (most recent first)
wal-g backup-list 2>/dev/null | tail -10
# Pick the timestamp (e.g. base_000000010000000000000007) and decide
# the LSN target if doing point-in-time recovery
```
For corruption: pick the latest backup created **before** the corrupting
event. For "ransomware / bad migration", probably 12 days back.
### 3. Restore the base backup
```bash
# Replace BACKUP_NAME with the chosen base
wal-g backup-fetch /var/lib/postgresql/data BACKUP_NAME
# Configure recovery target (omit recovery_target_time for "latest")
cat >> /var/lib/postgresql/data/postgresql.conf <<EOF
restore_command = 'wal-g wal-fetch "%f" "%p"'
recovery_target_time = '2026-05-13 14:30:00 UTC' # set if doing PITR
EOF
touch /var/lib/postgresql/data/recovery.signal
```
### 4. Start Postgres and let it replay WAL
```bash
systemctl start postgresql # or your equivalent
# Watch the log — should see "consistent recovery state reached"
tail -f /var/log/postgresql/postgresql-16-main.log
```
Wait until recovery completes and Postgres is in normal (not recovery)
mode:
```bash
psql -c "SELECT pg_is_in_recovery()"
# Expected: f (false)
```
### 5. Verify the restored database
This is the critical gate before any application traffic touches it.
```bash
# Build the verifier (only needed once)
go build -o restore-verify ./cmd/restore-verify
# Run it against the restored instance
GG_DATABASE_URL='postgres://guestguard:CHANGE_ME@RESTORED_HOST:5432/guestguard?sslmode=require' \
./restore-verify --verbose
```
Expected output: `OK: all N checks passed`. The tool checks:
- All expected tables exist (users, events, guests, tokens, rsvps, etc.)
- All migrations recorded in `schema_migrations`
- No orphan rows across the ten FK relationships we care about
- `users.email` is still unique (case-insensitive)
- No more than one "granting" subscription per user
- Row-count snapshot (for sanity, not pass/fail)
**If any check fails: STOP.** The restore is corrupt — go back to step 2
with an earlier backup OR escalate.
### 6. Apply pending migrations
If the backup is from before a recent migration that shipped to prod,
catch up:
```bash
# The API auto-migrates on boot, but we want to apply migrations
# before traffic, so kick a one-off:
docker run --rm \
-e GG_DATABASE_URL='postgres://...' \
ghcr.io/alchemistkay/guestguard-api:latest \
/app/api --migrate-only
# Or via psql, applying each .up.sql in order if you don't have the image:
for m in 0001_init 0002_rsvps 0003_auth 0004_notifications_d 0005_billing; do
psql -f internal/storage/migrations/${m}.up.sql || break
done
```
Run `restore-verify` again after migrations to confirm everything's
still coherent.
### 7. Bring the API back up
```bash
kubectl scale deployment/guestguard-api --replicas=2
kubectl scale deployment/guestguard-notifier --replicas=1
# Watch the logs — expect "http server starting" + "billing enabled via stripe"
kubectl logs -f deployment/guestguard-api --tail=20
```
### 8. Smoke test
- [ ] Hit `/health` → 200
- [ ] Sign in as a known test user → dashboard loads, recent events visible
- [ ] Create a new event → succeeds, appears in list
- [ ] Tail API logs for 5 minutes → no 5xx storms
### 9. Re-enable traffic
- [ ] Remove the maintenance page from the frontend
- [ ] Announce restoration in the status channel + status page
- [ ] Note actual RTO + RPO achieved for the post-mortem
## Drill procedure
Run this monthly with no real outage to keep the team's hands warm.
1. Provision a throwaway Postgres instance (`postgres-drill-YYYYMM`).
2. Run steps 25 against it (skip 1, 7, 8, 9 — production stays untouched).
3. `restore-verify` MUST pass.
4. Bonus: spin up an API pointed at the drill DB on a one-off port and
walk through the smoke-test scenarios.
5. Tear down the drill DB.
6. Log the time taken in `docs/RESTORE_DRILL_LOG.md` (or wherever your
team tracks operational drills).
If any step fails during a drill, the production fail-over procedure is
**unreliable** — treat as a P1 to fix before the next real failure.
## Rollback (if restore is wrong)
If you complete the restore and discover it's the wrong data:
1. Scale API back to 0
2. Find the next earlier backup
3. Drop and recreate the database on the restored instance
4. Repeat from step 3
**Never** point production at a known-bad restored DB hoping to fix it
later — the API will write new data on top of the corruption and the
salvage gets exponentially harder.
## Migration down-scripts
Every `.up.sql` in `internal/storage/migrations/` has a matching
`.down.sql`. They're tested as part of CI and not exercised during
normal restores (the up-only sequence in step 6 is the path used).
They exist for:
- Drill scenarios where you want to "rewind" the schema
- Emergency rollback of a bad shipped migration
Down-script integrity: run the `TestMigrationRoundtrip` integration
test, which applies every migration up → down → up against a fresh
container.
## Application config that supports restored DBs
`GG_DATABASE_URL` is the single source of truth — no hardcoded
hostnames anywhere in the codebase. Verified by:
```bash
grep -rE 'postgres://|host=.*5432' --include='*.go' . | grep -v _test.go | grep -v config.go
# Expected: (empty)
```
If anything surfaces from that grep, file a bug — it'll bite the next
restore.
## Escalation
| Step fails | Who to call |
|---|---|
| Steps 14 | On-call infra lead |
| Step 5 (`restore-verify`) | On-call backend lead + DBA |
| Steps 78 (app won't start / smoke fails) | On-call backend lead |
| Drill failure | File P1 ticket, link the drill log |
## Change log
| Date | Author | Change |
|---|---|---|
| 2026-05-16 | kay | initial version (Block G) |
+26 -8
View File
@@ -1,13 +1,13 @@
<script setup lang="ts">
const { host, clear } = useHost()
const auth = useAuth()
const route = useRoute()
// GitHub icon is a "marketing" affordance — only show it on the public landing
// page. Inside the app it just clutters the chrome.
const showGithub = computed(() => route.path === '/')
function logout() {
clear()
async function signOut() {
await auth.logout()
navigateTo('/')
}
</script>
@@ -34,12 +34,17 @@ function logout() {
<path d="M12 0C5.374 0 0 5.373 0 12c0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23A11.509 11.509 0 0112 5.803c1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576C20.566 21.797 24 17.3 24 12c0-6.627-5.373-12-12-12z"/>
</svg>
</a>
<template v-if="host">
<button class="transition hover:text-zinc-100" @click="logout">Sign out</button>
<ClientOnly>
<template v-if="auth.isAuthenticated.value">
<NuxtLink to="/dashboard" class="transition hover:text-zinc-100">Dashboard</NuxtLink>
<NuxtLink to="/dashboard/billing" class="transition hover:text-zinc-100">Billing</NuxtLink>
<button class="transition hover:text-zinc-100" @click="signOut">Sign out</button>
</template>
<template v-else>
<NuxtLink to="/dashboard" class="transition hover:text-zinc-100">Sign in</NuxtLink>
<NuxtLink to="/login" class="transition hover:text-zinc-100">Sign in</NuxtLink>
<NuxtLink to="/signup" class="btn-primary !px-3 !py-1.5 text-xs">Get started</NuxtLink>
</template>
</ClientOnly>
</nav>
</div>
</header>
@@ -48,9 +53,22 @@ function logout() {
<NuxtPage />
</main>
<!-- Global "plan limit reached" prompt surfaces whenever any API
call returns 402. Lives at app-root so every page benefits
without per-page wiring. -->
<UpgradeModal />
<!-- Privacy / terms onboarding gate. Auto-shows when the signed-in
user hasn't accepted the current policies yet. No-op otherwise. -->
<TermsGateModal />
<footer class="mt-16 border-t border-zinc-900">
<div class="mx-auto max-w-6xl px-6 py-6 text-xs text-zinc-500">
© 2025 GuestGuard Hassle-free RSVPs for every occasion.
<div class="mx-auto flex max-w-6xl flex-wrap items-center justify-between gap-3 px-6 py-6 text-xs text-zinc-500">
<span>© 2025 GuestGuard Hassle-free RSVPs for every occasion.</span>
<span class="flex items-center gap-4">
<NuxtLink to="/privacy" class="hover:text-zinc-300">Privacy</NuxtLink>
<NuxtLink to="/terms" class="hover:text-zinc-300">Terms</NuxtLink>
</span>
</div>
</footer>
</div>
+217
View File
@@ -0,0 +1,217 @@
<script setup lang="ts">
interface ParsedRow {
Name: string
Email: string
Phone: string
PlusOnes: number
}
interface RowError {
row: number
reason: string
}
interface PreviewResponse {
rows: ParsedRow[]
errors?: RowError[]
total_count: number
}
interface ImportResponse {
added: number
skipped: number
skipped_emails?: string[]
errors?: RowError[]
total_count: number
}
const props = defineProps<{ eventId: string }>()
const emit = defineEmits<{ (e: 'imported'): void }>()
const config = useRuntimeConfig()
const apiBase = config.public.apiBase as string
type Stage = 'idle' | 'preview' | 'committing' | 'done'
const stage = ref<Stage>('idle')
const fileName = ref('')
const dragging = ref(false)
const error = ref<string | null>(null)
const preview = ref<PreviewResponse | null>(null)
const result = ref<ImportResponse | null>(null)
let pendingFile: File | null = null
function reset() {
stage.value = 'idle'
fileName.value = ''
preview.value = null
result.value = null
error.value = null
pendingFile = null
}
async function onFiles(files: FileList | File[] | null) {
if (!files || !files.length) return
const f = files[0]
if (f.size > 1024 * 1024) {
error.value = 'File is larger than 1MB.'
return
}
fileName.value = f.name
error.value = null
pendingFile = f
await runPreview()
}
async function runPreview() {
if (!pendingFile) return
try {
const fd = new FormData()
fd.append('file', pendingFile)
preview.value = await useApi<PreviewResponse>(
`/events/${props.eventId}/guests/import/preview`,
{ method: 'POST', body: fd },
)
stage.value = 'preview'
} catch (e: any) {
error.value = useErrMessage(e, 'Could not parse that CSV')
stage.value = 'idle'
pendingFile = null
}
}
async function commit() {
if (!pendingFile) return
stage.value = 'committing'
error.value = null
try {
const fd = new FormData()
fd.append('file', pendingFile)
result.value = await useApi<ImportResponse>(
`/events/${props.eventId}/guests/import`,
{ method: 'POST', body: fd },
)
stage.value = 'done'
emit('imported')
} catch (e: any) {
error.value = useErrMessage(e, 'Import failed')
stage.value = 'preview'
}
}
function onDrop(e: DragEvent) {
dragging.value = false
onFiles(e.dataTransfer?.files ?? null)
}
const templateUrl = computed(() => `${apiBase}/events/${props.eventId}/guests/import/template`)
async function downloadTemplate() {
// The endpoint requires a Bearer header, which a plain <a download> can't
// attach — fetch through useApi (which adds auth + handles refresh) then
// synthesise an anchor click on the resulting blob.
try {
const res: any = await useApi(`/events/${props.eventId}/guests/import/template`, {
method: 'GET',
})
// useApi auto-parses JSON; for CSV the response body is plain text.
const text = typeof res === 'string' ? res : JSON.stringify(res)
const blob = new Blob([text], { type: 'text/csv;charset=utf-8' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = 'guestguard-import-template.csv'
a.click()
URL.revokeObjectURL(url)
} catch (e: any) {
error.value = useErrMessage(e, 'Could not download template')
}
}
</script>
<template>
<!-- No outer .card chrome: this component is now embedded inside a
modal that supplies the surface. -->
<div>
<div class="mb-3 flex items-center justify-end">
<button class="text-xs text-zinc-400 hover:text-zinc-200" @click="downloadTemplate">
Download template
</button>
</div>
<!-- Stage 1: drag-drop zone -->
<div v-if="stage === 'idle'">
<label
class="flex cursor-pointer flex-col items-center justify-center rounded-lg border-2 border-dashed bg-zinc-900/50 p-6 text-center transition"
:class="dragging ? 'border-brand-500 bg-brand-500/5' : 'border-zinc-700 hover:border-zinc-600'"
@dragover.prevent="dragging = true"
@dragleave.prevent="dragging = false"
@drop.prevent="onDrop"
>
<span class="mb-1 text-sm text-zinc-200">Drop a CSV here or click to choose</span>
<span class="text-xs text-zinc-500">Up to 5,000 rows · 1MB max</span>
<input type="file" accept=".csv,text/csv" class="hidden" @change="onFiles(($event.target as HTMLInputElement).files)" />
</label>
<p class="mt-2 text-xs text-zinc-500">
Required column: <code>name</code>. Optional: <code>email</code>, <code>phone</code>, <code>plus_ones</code>.
</p>
<p v-if="error" class="mt-3 text-sm text-red-400">{{ error }}</p>
</div>
<!-- Stage 2: preview -->
<div v-else-if="stage === 'preview' && preview">
<p class="mb-2 text-sm">
<span class="text-zinc-200">{{ fileName }}</span>
<span class="text-zinc-500"> {{ preview.rows.length }} valid row{{ preview.rows.length === 1 ? '' : 's' }}, {{ preview.errors?.length || 0 }} error{{ preview.errors?.length === 1 ? '' : 's' }}.</span>
</p>
<div v-if="preview.errors && preview.errors.length" class="mb-3 max-h-32 overflow-auto rounded border border-amber-900/40 bg-amber-950/20 p-2 text-xs text-amber-200">
<p class="mb-1 font-medium">Rows with problems (these will be skipped):</p>
<ul class="space-y-0.5">
<li v-for="(err, i) in preview.errors" :key="i">
row {{ err.row }} {{ err.reason }}
</li>
</ul>
</div>
<div v-if="preview.rows.length" class="mb-3 max-h-64 overflow-auto rounded border border-zinc-800">
<table class="w-full text-xs">
<thead class="bg-zinc-900 text-zinc-400">
<tr>
<th class="px-2 py-1.5 text-left">Name</th>
<th class="px-2 py-1.5 text-left">Email</th>
<th class="px-2 py-1.5 text-left">Phone</th>
<th class="px-2 py-1.5 text-right">+1s</th>
</tr>
</thead>
<tbody class="divide-y divide-zinc-800">
<tr v-for="(r, i) in preview.rows.slice(0, 100)" :key="i" class="text-zinc-200">
<td class="px-2 py-1">{{ r.Name }}</td>
<td class="px-2 py-1 text-zinc-400">{{ r.Email || '—' }}</td>
<td class="px-2 py-1 text-zinc-400">{{ r.Phone || '—' }}</td>
<td class="px-2 py-1 text-right tabular-nums">{{ r.PlusOnes }}</td>
</tr>
</tbody>
</table>
<p v-if="preview.rows.length > 100" class="px-2 py-1 text-xs text-zinc-500">
+ {{ preview.rows.length - 100 }} more not shown.
</p>
</div>
<div class="flex items-center gap-2">
<button class="btn-primary" :disabled="!preview.rows.length || stage === 'committing'" @click="commit">
Looks good import {{ preview.rows.length }} guest{{ preview.rows.length === 1 ? '' : 's' }}
</button>
<button class="btn-ghost" @click="reset">Cancel</button>
</div>
<p v-if="error" class="mt-3 text-sm text-red-400">{{ error }}</p>
</div>
<!-- Stage 3: committing -->
<div v-else-if="stage === 'committing'" class="text-sm text-zinc-400">Importing</div>
<!-- Stage 4: done -->
<div v-else-if="stage === 'done' && result" class="text-sm">
<p class="mb-1 font-medium text-brand-300">Imported {{ result.added }} guest{{ result.added === 1 ? '' : 's' }}.</p>
<p v-if="result.skipped" class="text-zinc-400">Skipped {{ result.skipped }} duplicate{{ result.skipped === 1 ? '' : 's' }} (already on this event).</p>
<p v-if="result.errors?.length" class="text-zinc-400">{{ result.errors.length }} row{{ result.errors.length === 1 ? '' : 's' }} had errors and were not imported.</p>
<button class="btn-ghost mt-3" @click="reset">Import another file</button>
</div>
</div>
</template>
+240
View File
@@ -0,0 +1,240 @@
<script setup lang="ts">
// PhoneInput — country code picker + national digits input.
//
// Emits an E.164 string via v-model (e.g. "+233244123456"). Empty input
// emits an empty string. The country list covers the ~50 most likely
// origins for event guests; uncommon ones can still be typed by picking
// the closest country and entering the full number — the live preview
// shows the host what's being saved.
//
// Industry-standard UX: WhatsApp / Stripe / airline-booking pattern.
interface Country {
code: string // ISO-3166 alpha-2
name: string
dialCode: string // includes leading "+"
}
// Curated list — alphabetical-by-name, weighted toward common GuestGuard
// audience (UK/EU + Africa + diaspora destinations). The host can always
// type the full international number into another country's row if their
// country isn't here, but most won't need to.
const COUNTRIES: Country[] = [
{ code: 'AR', name: 'Argentina', dialCode: '+54' },
{ code: 'AU', name: 'Australia', dialCode: '+61' },
{ code: 'AT', name: 'Austria', dialCode: '+43' },
{ code: 'BE', name: 'Belgium', dialCode: '+32' },
{ code: 'BR', name: 'Brazil', dialCode: '+55' },
{ code: 'CM', name: 'Cameroon', dialCode: '+237' },
{ code: 'CA', name: 'Canada', dialCode: '+1' },
{ code: 'CL', name: 'Chile', dialCode: '+56' },
{ code: 'CN', name: 'China', dialCode: '+86' },
{ code: 'CI', name: "Côte d'Ivoire", dialCode: '+225' },
{ code: 'DK', name: 'Denmark', dialCode: '+45' },
{ code: 'EG', name: 'Egypt', dialCode: '+20' },
{ code: 'ET', name: 'Ethiopia', dialCode: '+251' },
{ code: 'FI', name: 'Finland', dialCode: '+358' },
{ code: 'FR', name: 'France', dialCode: '+33' },
{ code: 'DE', name: 'Germany', dialCode: '+49' },
{ code: 'GH', name: 'Ghana', dialCode: '+233' },
{ code: 'HK', name: 'Hong Kong', dialCode: '+852' },
{ code: 'IN', name: 'India', dialCode: '+91' },
{ code: 'ID', name: 'Indonesia', dialCode: '+62' },
{ code: 'IE', name: 'Ireland', dialCode: '+353' },
{ code: 'IL', name: 'Israel', dialCode: '+972' },
{ code: 'IT', name: 'Italy', dialCode: '+39' },
{ code: 'JP', name: 'Japan', dialCode: '+81' },
{ code: 'KE', name: 'Kenya', dialCode: '+254' },
{ code: 'MY', name: 'Malaysia', dialCode: '+60' },
{ code: 'MX', name: 'Mexico', dialCode: '+52' },
{ code: 'MA', name: 'Morocco', dialCode: '+212' },
{ code: 'NL', name: 'Netherlands', dialCode: '+31' },
{ code: 'NZ', name: 'New Zealand', dialCode: '+64' },
{ code: 'NG', name: 'Nigeria', dialCode: '+234' },
{ code: 'NO', name: 'Norway', dialCode: '+47' },
{ code: 'PH', name: 'Philippines', dialCode: '+63' },
{ code: 'PL', name: 'Poland', dialCode: '+48' },
{ code: 'PT', name: 'Portugal', dialCode: '+351' },
{ code: 'RW', name: 'Rwanda', dialCode: '+250' },
{ code: 'SA', name: 'Saudi Arabia', dialCode: '+966' },
{ code: 'SN', name: 'Senegal', dialCode: '+221' },
{ code: 'SG', name: 'Singapore', dialCode: '+65' },
{ code: 'ZA', name: 'South Africa', dialCode: '+27' },
{ code: 'KR', name: 'South Korea', dialCode: '+82' },
{ code: 'ES', name: 'Spain', dialCode: '+34' },
{ code: 'SE', name: 'Sweden', dialCode: '+46' },
{ code: 'CH', name: 'Switzerland', dialCode: '+41' },
{ code: 'TW', name: 'Taiwan', dialCode: '+886' },
{ code: 'TZ', name: 'Tanzania', dialCode: '+255' },
{ code: 'TH', name: 'Thailand', dialCode: '+66' },
{ code: 'TR', name: 'Turkey', dialCode: '+90' },
{ code: 'AE', name: 'UAE', dialCode: '+971' },
{ code: 'UG', name: 'Uganda', dialCode: '+256' },
{ code: 'GB', name: 'United Kingdom', dialCode: '+44' },
{ code: 'US', name: 'United States', dialCode: '+1' },
]
// Longer dial codes first so "+1" doesn't shadow "+1...". Sorted once.
const SORTED_BY_DIAL = [...COUNTRIES].sort((a, b) => b.dialCode.length - a.dialCode.length)
const props = defineProps<{
modelValue: string
/** ISO-3166 alpha-2 to use as default. Browser locale is consulted if omitted. */
defaultCountry?: string
}>()
const emit = defineEmits<{ (e: 'update:modelValue', value: string): void }>()
function findByCode(code: string): Country | undefined {
return COUNTRIES.find((c) => c.code === code.toUpperCase())
}
function parsePhone(v: string): { country: Country | null; national: string } {
if (!v) return { country: null, national: '' }
const trimmed = v.trim().replace(/[\s\-()]/g, '')
for (const c of SORTED_BY_DIAL) {
if (trimmed.startsWith(c.dialCode)) {
return { country: c, national: trimmed.slice(c.dialCode.length) }
}
}
// Fall through: leading +XX didn't match anything (or no leading +).
// Strip a leading + and a leading 0 from local-format numbers so the
// host's national digits show up cleanly in the input.
return { country: null, national: trimmed.replace(/^\+/, '').replace(/^0+/, '') }
}
function detectDefault(): Country {
if (props.defaultCountry) {
const c = findByCode(props.defaultCountry)
if (c) return c
}
if (typeof navigator !== 'undefined' && navigator.language) {
const region = navigator.language.split('-')[1] || ''
const c = findByCode(region)
if (c) return c
}
return findByCode('GB')!
}
const initial = parsePhone(props.modelValue)
const country = ref<Country>(initial.country || detectDefault())
const national = ref(initial.national)
// Digits the host typed, with leading 0s stripped (local-format helper).
const nationalDigits = computed(() => national.value.replace(/\D/g, '').replace(/^0+/, ''))
// E.164 composed from current state. Empty when no digits — keeps the
// stored value clean (don't save "+44" with no number).
const composed = computed(() => {
return nationalDigits.value ? `${country.value.dialCode}${nationalDigits.value}` : ''
})
// Inline validation. We don't try to encode per-country length rules
// (that's libphonenumber's job and overkill here) — instead we apply a
// generous floor/ceiling on the national digit count. Catches obvious
// typos without false-positives on shorter formats like Iceland (+354 7).
type Validation = 'empty' | 'short' | 'long' | 'ok'
const validation = computed<Validation>(() => {
const n = nationalDigits.value.length
if (n === 0) return 'empty'
if (n < 6) return 'short'
if (n > 13) return 'long'
return 'ok'
})
// Emit on user changes — guard reentrancy so an external prop update
// doesn't bounce back into our own watcher.
let emitting = false
watch(composed, (v) => {
emitting = true
emit('update:modelValue', v)
Promise.resolve().then(() => { emitting = false })
})
// Re-sync from prop changes (e.g., parent resets the form, edit modal
// reopens for a different guest). Skip when our own emit caused it.
watch(() => props.modelValue, (v) => {
if (emitting) return
const p = parsePhone(v)
if (p.country) country.value = p.country
national.value = p.national
})
// Smart input: if the host pastes (or types) a full E.164 with leading
// "+" into the national field, extract the country code and split it
// into the picker + the national digits. Rescues the common mistake of
// pasting "+233244123456" into the digits field instead of using the
// picker — without this, the value would get double-prefixed.
function onNationalInput(e: Event) {
const v = (e.target as HTMLInputElement).value
if (v.startsWith('+')) {
const parsed = parsePhone(v)
if (parsed.country) {
country.value = parsed.country
national.value = parsed.national
return
}
}
national.value = v
}
</script>
<template>
<div>
<div class="flex gap-2">
<select
v-model="country"
class="input w-28 shrink-0 cursor-pointer"
aria-label="Country code"
>
<option v-for="c in COUNTRIES" :key="c.code" :value="c">
{{ c.dialCode }} {{ c.code }}
</option>
</select>
<input
:value="national"
@input="onNationalInput"
type="tel"
inputmode="tel"
autocomplete="tel-national"
class="input flex-1"
:class="{
'border-amber-700/60 focus:border-amber-500 focus:ring-amber-500': validation === 'short' || validation === 'long',
}"
placeholder="Phone number"
:aria-invalid="validation === 'short' || validation === 'long' || undefined"
/>
</div>
<!-- Live feedback. Different message per validation state:
empty optional hint
short amber warning, hostsees it before pressing Save
long amber warning
ok green confirmation with the canonical E.164 visible -->
<p class="mt-1 flex items-center gap-1 text-xs">
<template v-if="validation === 'empty'">
<span class="text-zinc-500">
Optional include the country code so guests on any network can be reached.
</span>
</template>
<template v-else-if="validation === 'short'">
<svg class="h-3.5 w-3.5 shrink-0 text-amber-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fill-rule="evenodd" d="M8.485 3.495c.673-1.167 2.357-1.167 3.03 0l6.28 10.875c.673 1.167-.17 2.625-1.516 2.625H3.72c-1.347 0-2.189-1.458-1.515-2.625L8.485 3.495zM10 8a1 1 0 01.993.883L11 9v3a1 1 0 01-1.993.117L9 12V9a1 1 0 011-1zm0 6a1 1 0 110 2 1 1 0 010-2z" clip-rule="evenodd" />
</svg>
<span class="text-amber-300">Looks too short make sure you've entered all the digits.</span>
</template>
<template v-else-if="validation === 'long'">
<svg class="h-3.5 w-3.5 shrink-0 text-amber-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fill-rule="evenodd" d="M8.485 3.495c.673-1.167 2.357-1.167 3.03 0l6.28 10.875c.673 1.167-.17 2.625-1.516 2.625H3.72c-1.347 0-2.189-1.458-1.515-2.625L8.485 3.495zM10 8a1 1 0 01.993.883L11 9v3a1 1 0 01-1.993.117L9 12V9a1 1 0 011-1zm0 6a1 1 0 110 2 1 1 0 010-2z" clip-rule="evenodd" />
</svg>
<span class="text-amber-300">Looks too long — check for extra digits.</span>
</template>
<template v-else>
<svg class="h-3.5 w-3.5 shrink-0 text-brand-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd" />
</svg>
<span class="text-zinc-500">
Saved as <span class="font-mono text-zinc-300">{{ composed }}</span>
</span>
</template>
</p>
</div>
</template>
+80
View File
@@ -0,0 +1,80 @@
<script setup lang="ts">
// Shown the first time a host signs into the dashboard after T&C
// enforcement is rolled out. Existing accounts created before this
// feature don't have terms_accepted_at set — they're re-prompted once
// here, then they're set going forward.
//
// Lives in app.vue so any authed page can be the host's first stop.
const auth = useAuth()
const accepted = ref(false)
const submitting = ref(false)
const error = ref<string | null>(null)
// Only fires when the user is signed in AND we know they haven't
// accepted. /me payload includes terms_accepted_at; if absent → prompt.
const needsAcceptance = computed(() => {
const u = auth.user.value
return !!u && !u.terms_accepted_at && !u.privacy_policy_accepted_at
})
async function accept() {
if (!accepted.value) return
submitting.value = true
error.value = null
try {
await useApi('/me/accept-terms', { method: 'POST' })
// Refresh auth state so the user object reflects the new fields.
await auth.refresh()
} catch (e: any) {
error.value = useErrMessage(e, 'Could not record acceptance')
} finally {
submitting.value = false
}
}
</script>
<template>
<Teleport to="body">
<div
v-if="needsAcceptance"
class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm"
>
<div
role="dialog"
aria-modal="true"
aria-labelledby="terms-title"
class="w-full max-w-md rounded-lg border border-zinc-800 bg-zinc-900 p-5 shadow-2xl"
>
<h3 id="terms-title" class="mb-1 text-base font-semibold">One quick thing</h3>
<p class="mb-4 text-sm text-zinc-400">
We've updated how GuestGuard handles your data. Before you carry on,
please confirm you've read and agree to the current policies.
</p>
<label class="flex cursor-pointer items-start gap-2 text-sm text-zinc-200">
<input
v-model="accepted"
type="checkbox"
class="mt-0.5 h-4 w-4 cursor-pointer accent-brand-500"
/>
<span>
I agree to GuestGuard's
<NuxtLink to="/terms" target="_blank" class="text-brand-400 hover:text-brand-300">Terms of Service</NuxtLink>
and
<NuxtLink to="/privacy" target="_blank" class="text-brand-400 hover:text-brand-300">Privacy Policy</NuxtLink>.
</span>
</label>
<button
class="btn-primary mt-4 w-full disabled:opacity-50"
:disabled="!accepted || submitting"
@click="accept"
>
{{ submitting ? 'Saving' : 'Continue' }}
</button>
<p v-if="error" class="mt-3 text-sm text-red-400">{{ error }}</p>
</div>
</div>
</Teleport>
</template>
+131
View File
@@ -0,0 +1,131 @@
<script setup lang="ts">
// Global "plan limit reached" modal. Bound to the upgrade-prompt state
// in useBilling, which is populated by the 402 interceptor in useApi.
// Lives in app.vue so any page that fires an API call benefits from it
// without per-page wiring.
const billing = useBilling()
const router = useRouter()
const upgrading = ref<'pro' | 'business' | null>(null)
const reasonText = computed(() => {
const p = billing.prompt.value
if (!p) return ''
if (p.reason === 'events_per_month') {
return `You've used ${p.used} of ${p.limit} events this month on the ${labelTier(p.tier)} plan.`
}
if (p.reason === 'guests_per_event') {
return `This event already has ${p.used} of ${p.limit} guests allowed on the ${labelTier(p.tier)} plan.`
}
return p.error
})
function labelTier(t: string): string {
return t.charAt(0).toUpperCase() + t.slice(1)
}
async function quickUpgrade(tier: 'pro' | 'business') {
upgrading.value = tier
try {
await billing.startCheckout(tier)
// startCheckout navigates the page — nothing else to do.
} catch {
// Fallback: send the user to the billing page so they see the error
// in context.
await router.push('/dashboard/billing')
billing.dismissUpgradePrompt()
} finally {
upgrading.value = null
}
}
function viewPlans() {
billing.dismissUpgradePrompt()
router.push('/dashboard/billing')
}
// Esc closes the modal — keeps interaction symmetrical with all the
// other dialogs across the app.
function onKeydown(e: KeyboardEvent) {
if (e.key === 'Escape' && billing.prompt.value) {
billing.dismissUpgradePrompt()
}
}
if (import.meta.client) {
onMounted(() => window.addEventListener('keydown', onKeydown))
onUnmounted(() => window.removeEventListener('keydown', onKeydown))
}
</script>
<template>
<Teleport to="body">
<div
v-if="billing.prompt.value"
class="fixed inset-0 z-50 flex items-center justify-center bg-black/60 p-4 backdrop-blur-sm"
@click.self="billing.dismissUpgradePrompt()"
>
<div
role="alertdialog"
aria-modal="true"
aria-labelledby="upgrade-title"
aria-describedby="upgrade-desc"
class="w-full max-w-md rounded-lg border border-zinc-800 bg-zinc-900 p-5 shadow-2xl"
>
<div class="mb-3 flex items-start gap-3">
<div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-brand-500/15">
<svg class="h-4 w-4 text-brand-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fill-rule="evenodd" d="M5 9V7a5 5 0 0110 0v2a2 2 0 012 2v5a2 2 0 01-2 2H5a2 2 0 01-2-2v-5a2 2 0 012-2zm8-2v2H7V7a3 3 0 016 0z" clip-rule="evenodd" />
</svg>
</div>
<div class="min-w-0">
<h3 id="upgrade-title" class="text-base font-semibold text-zinc-100">Plan limit reached</h3>
<p id="upgrade-desc" class="mt-1 text-sm text-zinc-400">{{ reasonText }}</p>
</div>
</div>
<p class="mb-4 text-sm text-zinc-300">Pick a plan to keep going your work isn't lost.</p>
<div class="space-y-2">
<button
type="button"
class="flex w-full items-center justify-between rounded-md border border-brand-700/60 bg-brand-500/10 px-3 py-3 text-left transition hover:bg-brand-500/15 disabled:opacity-50"
:disabled="upgrading !== null"
@click="quickUpgrade('pro')"
>
<span>
<span class="block text-sm font-medium text-zinc-100">Upgrade to Pro</span>
<span class="block text-xs text-zinc-500">$49 / month · 10 events · 1,000 guests per event</span>
</span>
<span class="text-xs text-zinc-400">{{ upgrading === 'pro' ? 'Opening checkout' : '' }}</span>
</button>
<button
type="button"
class="flex w-full items-center justify-between rounded-md border border-zinc-700 bg-zinc-950 px-3 py-3 text-left transition hover:border-zinc-500 hover:bg-zinc-900 disabled:opacity-50"
:disabled="upgrading !== null"
@click="quickUpgrade('business')"
>
<span>
<span class="block text-sm font-medium text-zinc-100">Upgrade to Business</span>
<span class="block text-xs text-zinc-500">$199 / month · Unlimited events · 5,000 guests per event</span>
</span>
<span class="text-xs text-zinc-400">{{ upgrading === 'business' ? 'Opening checkout' : '' }}</span>
</button>
</div>
<div class="mt-4 flex items-center justify-between text-xs">
<button
type="button"
class="text-zinc-400 hover:text-zinc-200"
@click="viewPlans"
>Compare all plans</button>
<button
type="button"
class="text-zinc-500 hover:text-zinc-300"
@click="billing.dismissUpgradePrompt()"
>Maybe later</button>
</div>
</div>
</div>
</Teleport>
</template>
+44 -3
View File
@@ -1,16 +1,57 @@
// Typed wrapper around $fetch with the configured API base.
// Usage: const events = await useApi<EventList>('/events')
//
// Adds `Authorization: Bearer <access_token>` when the caller is signed in,
// and on a 401 transparently asks `/auth/refresh` for a new token and retries
// once. Failed refresh clears local auth state — pages can rely on the
// returned error to redirect to /login.
export async function useApi<T = unknown>(
path: string,
opts: { method?: string; body?: unknown; query?: Record<string, unknown> } = {},
): Promise<T> {
const config = useRuntimeConfig()
const base = config.public.apiBase as string
const auth = useAuth()
const request = async (token: string | null): Promise<T> => {
const headers: Record<string, string> = {}
// Let the browser set Content-Type (with the multipart boundary) when
// the body is FormData / Blob; otherwise default to JSON.
const isMultipart = typeof FormData !== 'undefined' && opts.body instanceof FormData
if (!isMultipart) headers['Content-Type'] = 'application/json'
if (token) headers.Authorization = `Bearer ${token}`
return await $fetch<T>(path, {
baseURL: base,
method: (opts.method ?? 'GET') as any,
body: opts.body,
body: opts.body as any,
query: opts.query,
headers: { 'Content-Type': 'application/json' },
headers,
credentials: 'include',
})
}
try {
return await request(auth.liveAccessToken())
} catch (err: any) {
const status = err?.response?.status ?? err?.statusCode
// 402 Payment Required — plan limit hit. Surface the backend's
// upgrade payload on a global state slot; the UpgradeModal in
// app.vue reads it and prompts the host to upgrade. We still
// rethrow so the caller can stop its own UI flow if it wants.
if (status === 402) {
const data = err?.data
if (data && data.upgrade_url) {
useBilling().showUpgradePrompt(data)
}
throw err
}
if (status !== 401) throw err
// /auth/* endpoints set the cookie themselves — never retry-refresh them.
if (path.startsWith('/auth/')) throw err
const refreshed = await auth.refresh()
if (!refreshed) throw err
return await request(auth.liveAccessToken())
}
}
+147
View File
@@ -0,0 +1,147 @@
// Auth state for the host-facing app.
//
// The access token lives only in memory (useState — Nuxt's SSR-safe wrapper).
// The refresh token lives in an HttpOnly cookie set by the API at
// `/auth/refresh` scope, so JavaScript here can never read it. On a hard
// reload we lose the access token but the cookie survives, so `bootstrap()`
// calls `/auth/refresh` to mint a fresh access token + reload the user.
interface AuthUser {
id: string
email: string
name: string
email_verified: boolean
}
interface AuthSuccess {
access_token: string
expires_at: string
user: AuthUser
}
interface AuthState {
user: AuthUser | null
accessToken: string | null
expiresAt: number | null // unix ms
bootstrapped: boolean
}
function emptyState(): AuthState {
return { user: null, accessToken: null, expiresAt: null, bootstrapped: false }
}
function apiBase(): string {
return useRuntimeConfig().public.apiBase as string
}
async function postJSON<T>(path: string, body?: unknown): Promise<T> {
return await $fetch<T>(path, {
baseURL: apiBase(),
method: 'POST',
body,
credentials: 'include',
headers: { 'Content-Type': 'application/json' },
})
}
export function useAuth() {
const state = useState<AuthState>('gg-auth', emptyState)
function setSession(s: AuthSuccess) {
state.value = {
user: s.user,
accessToken: s.access_token,
expiresAt: Date.parse(s.expires_at) || (Date.now() + 14 * 60 * 1000),
bootstrapped: true,
}
}
function clearSession() {
state.value = { ...emptyState(), bootstrapped: true }
}
async function signup(email: string, name: string, password: string, acceptTerms = false) {
return await postJSON<{ status: string }>('/auth/signup', {
email, name, password,
accept_terms: acceptTerms,
})
}
async function login(email: string, password: string) {
const s = await postJSON<AuthSuccess>('/auth/login', { email, password })
setSession(s)
return s
}
async function refresh(): Promise<boolean> {
try {
const s = await postJSON<AuthSuccess>('/auth/refresh')
setSession(s)
return true
} catch {
clearSession()
return false
}
}
async function logout() {
try {
await postJSON<void>('/auth/logout')
} catch {
// Best-effort — clear local state regardless.
}
clearSession()
}
async function verifyEmail(token: string) {
return await postJSON<{ status: string }>('/auth/verify-email', { token })
}
async function forgotPassword(email: string) {
return await postJSON<{ status: string }>('/auth/forgot-password', { email })
}
async function resetPassword(token: string, newPassword: string) {
return await postJSON<{ status: string }>('/auth/reset-password', {
token,
new_password: newPassword,
})
}
// Call on app entry / route guards. Returns true if the caller has a valid
// session by the time it resolves.
async function bootstrap(): Promise<boolean> {
if (!import.meta.client) return false
if (state.value.bootstrapped && state.value.user) return true
if (state.value.bootstrapped && !state.value.user) return false
return await refresh()
}
// Hint to useApi: returns the current token if not yet expired.
function liveAccessToken(): string | null {
if (!state.value.accessToken || !state.value.expiresAt) return null
// 5s skew to avoid sending a just-expired token.
if (Date.now() + 5000 >= state.value.expiresAt) return null
return state.value.accessToken
}
const isAuthenticated = computed(() => !!state.value.user)
const user = computed(() => state.value.user)
const bootstrapped = computed(() => state.value.bootstrapped)
return {
user,
isAuthenticated,
bootstrapped,
signup,
login,
refresh,
logout,
verifyEmail,
forgotPassword,
resetPassword,
bootstrap,
liveAccessToken,
clearSession,
}
}
+151
View File
@@ -0,0 +1,151 @@
// Billing client: fetches subscription status, kicks off checkout/portal
// flows, and owns the global "upgrade required" prompt shown by the
// 402 interceptor in useApi. State is shared across components via
// useState so the prompt can be triggered from any handler.
export interface BillingStatus {
tier: 'free' | 'pro' | 'business'
status: string
current_period_end?: string
cancel_at_period_end: boolean
limits: {
events_per_month: number
guests_per_event: number
}
usage: {
events_this_month: number
}
portal_available: boolean
}
// UpgradePrompt mirrors the 402 body the backend returns when a limit is
// hit. Shown globally as a modal until dismissed or acted upon.
export interface UpgradePrompt {
error: string
reason: string
tier: string
used: number
limit: number
upgrade_url: string
}
const FREE_DEFAULT: BillingStatus = {
tier: 'free',
status: 'active',
cancel_at_period_end: false,
limits: { events_per_month: 1, guests_per_event: 50 },
usage: { events_this_month: 0 },
portal_available: false,
}
export function useBilling() {
const status = useState<BillingStatus | null>('gg-billing-status', () => null)
const loading = useState<boolean>('gg-billing-loading', () => false)
const prompt = useState<UpgradePrompt | null>('gg-upgrade-prompt', () => null)
async function fetchStatus(): Promise<BillingStatus> {
loading.value = true
try {
const res = await useApi<BillingStatus>('/billing/status')
status.value = res
return res
} catch (e: any) {
// 401/refresh edge → caller redirected to /login by useApi. If the
// backend has billing wired but the response is malformed, fall
// back to free defaults so the page renders something usable.
status.value = FREE_DEFAULT
return FREE_DEFAULT
} finally {
loading.value = false
}
}
async function startCheckout(tier: 'pro' | 'business'): Promise<void> {
const res = await useApi<{ url: string }>('/billing/checkout-session', {
method: 'POST',
body: { tier },
})
if (import.meta.client) window.location.href = res.url
}
async function openPortal(): Promise<void> {
const res = await useApi<{ url: string }>('/billing/portal', { method: 'POST' })
if (import.meta.client) window.location.href = res.url
}
function showUpgradePrompt(info: UpgradePrompt) {
prompt.value = info
}
function dismissUpgradePrompt() {
prompt.value = null
}
return {
status,
loading,
prompt,
fetchStatus,
startCheckout,
openPortal,
showUpgradePrompt,
dismissUpgradePrompt,
}
}
// Static pricing copy. Keep in sync with internal/billing/tiers.go.
// One source of truth for the marketing-page-style cards in
// /dashboard/billing and the UpgradeModal.
export interface TierCard {
id: 'free' | 'pro' | 'business'
name: string
price: string
priceSubtitle: string
tagline: string
features: string[]
highlight?: boolean
}
export const TIER_CARDS: TierCard[] = [
{
id: 'free',
name: 'Free',
price: '$0',
priceSubtitle: 'forever',
tagline: 'Try GuestGuard with a single event.',
features: [
'1 event per month',
'Up to 50 guests per event',
'Branded email invitations',
'Real-time RSVP dashboard',
],
},
{
id: 'pro',
name: 'Pro',
price: '$49',
priceSubtitle: 'per month',
tagline: 'For active hosts running several events.',
features: [
'10 events per month',
'Up to 1,000 guests per event',
'WhatsApp + email invitations',
'CSV import + bulk send',
'Priority email support',
],
highlight: true,
},
{
id: 'business',
name: 'Business',
price: '$199',
priceSubtitle: 'per month',
tagline: 'For agencies and corporate events teams.',
features: [
'Unlimited events',
'Up to 5,000 guests per event',
'Everything in Pro',
'Signed DPA on request',
'SLA with response targets',
],
},
]
+28
View File
@@ -0,0 +1,28 @@
// Friendly error messages for API failures, with first-class handling of
// 429 (rate-limited) and 403 account-lockout responses.
export function useErrMessage(e: any, fallback = 'Something went wrong'): string {
const status: number | undefined = e?.response?.status ?? e?.statusCode
const data = e?.data
const serverMsg: string | undefined = data?.error
if (status === 429) {
const retry: number | undefined = data?.retry_after
if (typeof retry === 'number' && retry > 0) {
return `You're going too fast — try again in ${formatSeconds(retry)}.`
}
return "You're going too fast — please try again in a moment."
}
if (status === 403 && typeof serverMsg === 'string' && serverMsg.toLowerCase().includes('locked')) {
return 'Your account is locked after too many failed sign-in attempts. Reset your password to unlock it.'
}
if (serverMsg) return serverMsg
return e?.message || fallback
}
function formatSeconds(s: number): string {
if (s < 60) return `${s} second${s === 1 ? '' : 's'}`
const m = Math.ceil(s / 60)
return `${m} minute${m === 1 ? '' : 's'}`
}
+36 -6
View File
@@ -1,7 +1,12 @@
// Subscribes to /ws/events/:id and emits per-message callbacks.
//
// Auto-reconnects with exponential backoff up to 30s. Returns a cleanup
// fn the caller invokes (e.g. inside onUnmounted).
// Authenticates via short-lived ticket: before each connect we POST
// /auth/ws-ticket (bearer-authed) to mint a one-shot ticket, then pass it
// on the WS handshake as `?ticket=…`. Tickets expire ~60s after mint, so we
// always mint fresh — even on reconnects.
//
// Auto-reconnects with exponential backoff up to 30s. Returns a cleanup fn
// the caller invokes (e.g. inside onUnmounted).
interface WSMessage {
type: string
@@ -10,21 +15,46 @@ interface WSMessage {
timestamp: string
}
interface WSTicket {
ticket: string
expires_at: string
}
export function useEventWS(eventId: string, onMessage: (msg: WSMessage) => void) {
if (import.meta.server) return () => {}
const config = useRuntimeConfig()
const base = (config.public.wsBase as string) || ''
const url = `${base}/ws/events/${eventId}`
const wsBase = (config.public.wsBase as string) || ''
let ws: WebSocket | null = null
let attempt = 0
let stopped = false
let reconnectTimer: ReturnType<typeof setTimeout> | null = null
function connect() {
async function mintTicket(): Promise<string | null> {
try {
const t = await useApi<WSTicket>('/auth/ws-ticket', {
method: 'POST',
body: { event_id: eventId },
})
return t.ticket
} catch {
return null
}
}
async function connect() {
if (stopped) return
ws = new WebSocket(url)
const ticket = await mintTicket()
if (stopped) return
if (!ticket) {
// Couldn't get a ticket (likely 401 — session expired). Back off and
// retry; useApi will have already attempted refresh on its own.
const backoff = Math.min(30_000, 500 * Math.pow(2, attempt++))
reconnectTimer = setTimeout(connect, backoff)
return
}
ws = new WebSocket(`${wsBase}/ws/events/${eventId}?ticket=${encodeURIComponent(ticket)}`)
ws.onopen = () => {
attempt = 0
-43
View File
@@ -1,43 +0,0 @@
// Demo-grade host bootstrap. Real auth would replace this entirely; for now
// we upsert by email and stash the host id in localStorage.
interface User {
id: string
email: string
name: string
}
const STORAGE_KEY = 'gg.host'
export function useHost() {
const host = useState<User | null>('gg-host', () => null)
if (import.meta.client && !host.value) {
const raw = window.localStorage.getItem(STORAGE_KEY)
if (raw) {
try {
host.value = JSON.parse(raw)
} catch {
window.localStorage.removeItem(STORAGE_KEY)
}
}
}
async function bootstrap(email: string, name: string) {
const u = await useApi<User>('/users', {
method: 'POST',
body: { email, name },
})
host.value = u
if (import.meta.client) {
window.localStorage.setItem(STORAGE_KEY, JSON.stringify(u))
}
return u
}
function clear() {
host.value = null
if (import.meta.client) window.localStorage.removeItem(STORAGE_KEY)
}
return { host, bootstrap, clear }
}
+14
View File
@@ -0,0 +1,14 @@
// Route guard: bootstraps the session (server -> client cookie roundtrip,
// then optional /auth/refresh), and redirects to /login if no session.
//
// Skip on SSR: auth state lives entirely in the browser. Dashboard pages
// render their own client-side loading state while we resolve.
export default defineNuxtRouteMiddleware(async (to) => {
if (import.meta.server) return
const auth = useAuth()
const ok = await auth.bootstrap()
if (!ok) {
return navigateTo({ path: '/login', query: { redirect: to.fullPath } })
}
})
+15 -4
View File
@@ -5092,12 +5092,14 @@
}
},
"node_modules/commander": {
"version": "10.0.1",
"resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz",
"integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==",
"version": "13.1.0",
"resolved": "https://registry.npmjs.org/commander/-/commander-13.1.0.tgz",
"integrity": "sha512-/rFeCpNJQbhSZjGVwO9RFV3xPqbnERS8MmIQzCtD/zl6gpJuV/bMLuN92oG3F7d8oDEHHRrujSXNUr8fpjntKw==",
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": ">=14"
"node": ">=18"
}
},
"node_modules/common-path-prefix": {
@@ -8500,6 +8502,15 @@
"node": ">=8"
}
},
"node_modules/lambda-local/node_modules/commander": {
"version": "10.0.1",
"resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz",
"integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==",
"license": "MIT",
"engines": {
"node": ">=14"
}
},
"node_modules/lambda-local/node_modules/dotenv": {
"version": "16.6.1",
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.6.1.tgz",
+393
View File
@@ -0,0 +1,393 @@
<script setup lang="ts">
definePageMeta({ middleware: ['auth'] })
const route = useRoute()
const router = useRouter()
const billing = useBilling()
const auth = useAuth()
const config = useRuntimeConfig()
const switching = ref<'pro' | 'business' | null>(null)
const portalLoading = ref(false)
const error = ref<string | null>(null)
const toast = ref<string | null>(null)
let toastTimer: ReturnType<typeof setTimeout> | null = null
// Your data
const exporting = ref(false)
const deleteConfirmOpen = ref(false)
const deleteConfirmation = ref('')
const deleting = ref(false)
const deleteError = ref<string | null>(null)
async function exportData() {
exporting.value = true
try {
const apiBase = config.public.apiBase as string
const token = auth.liveAccessToken()
// Plain fetch (not useApi) so the response is treated as a download.
const res = await fetch(`${apiBase}/me/data-export`, {
headers: token ? { Authorization: `Bearer ${token}` } : {},
credentials: 'include',
})
if (!res.ok) throw new Error(`HTTP ${res.status}`)
const blob = await res.blob()
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = 'guestguard-data-export.json'
a.click()
URL.revokeObjectURL(url)
showToast('Export downloaded.')
} catch (e: any) {
showToast(useErrMessage(e, 'Export failed'))
} finally {
exporting.value = false
}
}
async function confirmDelete() {
deleting.value = true
deleteError.value = null
try {
await useApi('/me', { method: 'DELETE' })
// Soft-delete revoked our refresh token; clear local session and
// bounce to the marketing landing.
auth.clearSession()
await router.push('/')
} catch (e: any) {
deleteError.value = useErrMessage(e, 'Could not delete account')
} finally {
deleting.value = false
}
}
function showToast(text: string) {
toast.value = text
if (toastTimer) clearTimeout(toastTimer)
toastTimer = setTimeout(() => { toast.value = null }, 5000)
}
onMounted(async () => {
await billing.fetchStatus()
// Handle return-from-Stripe query params. ?billing=success means the
// checkout completed; Stripe also fires the webhook server-side so we
// refetch status to pick up the new tier without a hard reload.
const flag = route.query.billing
if (flag === 'success') {
showToast('Subscription updated — welcome aboard!')
// Stripe's webhook may take ~1s to land. Poll a couple of times.
for (let i = 0; i < 3; i++) {
await new Promise((r) => setTimeout(r, 1500))
await billing.fetchStatus()
if (billing.status.value?.tier !== 'free') break
}
} else if (flag === 'cancelled') {
showToast('No worries — your plan is unchanged.')
}
})
async function upgrade(tier: 'pro' | 'business') {
switching.value = tier
error.value = null
try {
await billing.startCheckout(tier)
} catch (e: any) {
if (e?.response?.status === 503) {
error.value = 'Billing isn\'t enabled on this environment yet — contact support.'
} else {
error.value = useErrMessage(e, 'Could not start checkout')
}
} finally {
switching.value = null
}
}
async function manageSubscription() {
portalLoading.value = true
error.value = null
try {
await billing.openPortal()
} catch (e: any) {
if (e?.response?.status === 503) {
error.value = 'Billing isn\'t enabled on this environment yet.'
} else {
error.value = useErrMessage(e, 'Could not open the billing portal')
}
} finally {
portalLoading.value = false
}
}
// Usage bar percentage — clamps to [0, 100] for the progress indicator.
const eventsUsagePct = computed(() => {
const s = billing.status.value
if (!s) return 0
const limit = s.limits.events_per_month
if (limit < 0) return 0 // unlimited — show empty bar
if (limit === 0) return 100
return Math.min(100, Math.round((s.usage.events_this_month / limit) * 100))
})
function formatLimit(n: number): string {
return n < 0 ? 'Unlimited' : n.toLocaleString()
}
function periodEndLabel(iso?: string): string {
if (!iso) return ''
try {
return new Date(iso).toLocaleDateString(undefined, { day: 'numeric', month: 'short', year: 'numeric' })
} catch {
return iso
}
}
</script>
<template>
<section class="space-y-6">
<div>
<NuxtLink to="/dashboard" class="mb-2 inline-block text-sm text-zinc-400 hover:text-zinc-200">
Back to dashboard
</NuxtLink>
<h1 class="text-2xl font-semibold">Billing &amp; plan</h1>
<p class="mt-1 text-sm text-zinc-400">
Change your plan, see your usage, or update your payment method.
</p>
</div>
<ClientOnly>
<!-- Current plan + usage -->
<div class="card">
<div class="flex flex-wrap items-start justify-between gap-4">
<div>
<p class="text-xs font-medium uppercase tracking-wider text-zinc-500">Current plan</p>
<div class="mt-1 flex items-baseline gap-2">
<span class="text-2xl font-semibold capitalize text-zinc-100">{{ billing.status.value?.tier || '—' }}</span>
<span
v-if="billing.status.value && billing.status.value.tier !== 'free'"
class="rounded-full bg-zinc-800 px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide text-zinc-400"
>{{ billing.status.value.status }}</span>
</div>
<p
v-if="billing.status.value?.current_period_end && billing.status.value.tier !== 'free'"
class="mt-1 text-xs text-zinc-500"
>
<template v-if="billing.status.value.cancel_at_period_end">
Cancels on {{ periodEndLabel(billing.status.value.current_period_end) }}.
</template>
<template v-else>
Renews on {{ periodEndLabel(billing.status.value.current_period_end) }}.
</template>
</p>
</div>
<button
v-if="billing.status.value?.portal_available"
type="button"
class="rounded-md border border-zinc-700 px-3 py-1.5 text-sm text-zinc-200 transition hover:border-zinc-500 hover:bg-zinc-800 disabled:opacity-50"
:disabled="portalLoading"
@click="manageSubscription"
>
{{ portalLoading ? 'Opening…' : 'Manage subscription' }}
</button>
</div>
<!-- Usage bar -->
<div class="mt-5">
<div class="mb-1.5 flex items-center justify-between text-xs">
<span class="text-zinc-300">Events this month</span>
<span class="tabular-nums text-zinc-400">
{{ billing.status.value?.usage.events_this_month ?? 0 }}
of
{{ formatLimit(billing.status.value?.limits.events_per_month ?? 0) }}
</span>
</div>
<div class="h-2 w-full overflow-hidden rounded-full bg-zinc-800">
<div
class="h-full rounded-full transition-all"
:class="eventsUsagePct >= 90 ? 'bg-amber-400' : 'bg-brand-500'"
:style="{ width: `${eventsUsagePct}%` }"
></div>
</div>
<p class="mt-1.5 text-xs text-zinc-500">
Guest cap per event: {{ formatLimit(billing.status.value?.limits.guests_per_event ?? 0) }}.
</p>
</div>
</div>
<!-- Pricing cards -->
<div class="grid gap-4 md:grid-cols-3">
<div
v-for="t in TIER_CARDS"
:key="t.id"
class="card relative flex flex-col gap-4"
:class="t.highlight ? 'border-brand-700/60 bg-brand-500/[0.04]' : ''"
>
<span
v-if="t.id === billing.status.value?.tier"
class="absolute right-3 top-3 rounded-full bg-zinc-800 px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide text-zinc-300"
>Current</span>
<span
v-else-if="t.highlight"
class="absolute right-3 top-3 rounded-full bg-brand-500 px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide text-zinc-950"
>Most popular</span>
<div>
<h3 class="text-lg font-semibold capitalize text-zinc-100">{{ t.name }}</h3>
<p class="mt-1 text-xs text-zinc-500">{{ t.tagline }}</p>
</div>
<div>
<span class="text-3xl font-semibold tabular-nums text-zinc-100">{{ t.price }}</span>
<span class="ml-1 text-xs text-zinc-500">{{ t.priceSubtitle }}</span>
</div>
<ul class="space-y-1.5 text-sm text-zinc-300">
<li v-for="f in t.features" :key="f" class="flex items-start gap-2">
<svg class="mt-0.5 h-3.5 w-3.5 shrink-0 text-brand-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
<path fill-rule="evenodd" d="M16.704 5.296a1 1 0 010 1.408l-8 8a1 1 0 01-1.408 0l-4-4a1 1 0 011.408-1.408L8 12.592l7.296-7.296a1 1 0 011.408 0z" clip-rule="evenodd" />
</svg>
<span>{{ f }}</span>
</li>
</ul>
<div class="mt-auto pt-2">
<button
v-if="t.id === billing.status.value?.tier"
type="button"
class="w-full cursor-default rounded-md border border-zinc-800 px-3 py-2 text-sm text-zinc-500"
disabled
>Current plan</button>
<button
v-else-if="t.id === 'free'"
type="button"
class="w-full cursor-default rounded-md border border-zinc-800 px-3 py-2 text-sm text-zinc-500"
disabled
>Downgrade in the billing portal</button>
<button
v-else
type="button"
class="btn-primary w-full disabled:opacity-50"
:disabled="switching === t.id"
@click="upgrade(t.id)"
>
{{ switching === t.id ? 'Opening checkout…' : `Upgrade to ${t.name}` }}
</button>
</div>
</div>
</div>
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
<p class="text-xs text-zinc-500">
Receipts and invoices are emailed automatically by Stripe.
Need to cancel? Use <a href="#" class="text-brand-400 hover:text-brand-300" @click.prevent="manageSubscription">Manage subscription</a> above.
</p>
<!-- ===== Your data ===== -->
<div class="card mt-2">
<h2 class="mb-1 text-lg font-semibold">Your data</h2>
<p class="mb-4 text-xs text-zinc-500">
Export a copy of everything we hold about you, or delete your account.
</p>
<div class="space-y-3">
<button
type="button"
class="flex w-full items-center justify-between rounded-md border border-zinc-700 bg-zinc-950 px-3 py-3 text-left transition hover:border-zinc-500 hover:bg-zinc-900 disabled:opacity-50"
:disabled="exporting"
@click="exportData"
>
<span>
<span class="block text-sm font-medium text-zinc-100">Export my data</span>
<span class="block text-xs text-zinc-500">
Download a JSON file with your events, guests, RSVPs, and account info.
</span>
</span>
<span class="text-xs text-zinc-400">{{ exporting ? '…' : '↓' }}</span>
</button>
<button
type="button"
class="flex w-full items-center justify-between rounded-md border border-red-800/40 bg-red-950/10 px-3 py-3 text-left transition hover:border-red-700 hover:bg-red-950/20"
@click="deleteConfirmOpen = true"
>
<span>
<span class="block text-sm font-medium text-red-300">Delete my account</span>
<span class="block text-xs text-red-400/70">
Soft-deleted immediately, permanently erased after 30 days. You'll be signed out everywhere.
</span>
</span>
<span class="text-xs text-red-400">→</span>
</button>
</div>
</div>
<template #fallback>
<div class="card text-sm text-zinc-500">Loading…</div>
</template>
</ClientOnly>
<!-- Delete-account confirmation -->
<Teleport to="body">
<div
v-if="deleteConfirmOpen"
class="fixed inset-0 z-50 flex items-center justify-center bg-black/60 p-4 backdrop-blur-sm"
@click.self="deleteConfirmOpen = false"
>
<div
role="alertdialog"
aria-modal="true"
aria-labelledby="del-acct-title"
class="w-full max-w-md rounded-lg border border-zinc-800 bg-zinc-900 p-5 shadow-2xl"
>
<h3 id="del-acct-title" class="mb-1 text-base font-semibold">Delete account?</h3>
<p class="mb-3 text-sm text-zinc-400">
Your account will be soft-deleted now and permanently erased
after 30 days. All your events, guests, and RSVP history go
with it. You'll be signed out from every device.
</p>
<p class="mb-3 text-xs text-zinc-500">
Type <code class="rounded bg-zinc-800 px-1 py-0.5 font-mono text-zinc-300">delete</code>
to confirm.
</p>
<input
v-model="deleteConfirmation"
type="text"
placeholder="delete"
class="input mb-3 font-mono"
autocomplete="off"
/>
<div class="flex items-center justify-end gap-2">
<button class="text-sm text-zinc-400 hover:text-zinc-200" :disabled="deleting" @click="deleteConfirmOpen = false">Cancel</button>
<button
class="rounded-md bg-red-500/90 px-3 py-1.5 text-sm font-medium text-white shadow-sm transition hover:bg-red-500 disabled:opacity-40"
:disabled="deleting || deleteConfirmation.trim().toLowerCase() !== 'delete'"
@click="confirmDelete"
>
{{ deleting ? 'Deleting…' : 'Delete forever' }}
</button>
</div>
<p v-if="deleteError" class="mt-3 text-sm text-red-400">{{ deleteError }}</p>
</div>
</div>
</Teleport>
<!-- Toast for return-from-Stripe -->
<Transition
enter-active-class="transition duration-200 ease-out"
enter-from-class="translate-y-2 opacity-0"
enter-to-class="translate-y-0 opacity-100"
leave-active-class="transition duration-200 ease-in"
leave-from-class="translate-y-0 opacity-100"
leave-to-class="translate-y-2 opacity-0"
>
<button
v-if="toast"
type="button"
class="fixed bottom-6 right-6 z-50 max-w-sm rounded-lg border border-brand-700/60 bg-brand-950/90 px-4 py-3 text-left text-sm text-brand-100 shadow-lg backdrop-blur"
@click="toast = null"
>
<span aria-hidden="true" class="mr-2"></span>{{ toast }}
</button>
</Transition>
</section>
</template>
File diff suppressed because it is too large Load Diff
+4 -2
View File
@@ -1,5 +1,8 @@
<script setup lang="ts">
const { host } = useHost()
definePageMeta({ middleware: ['auth'] })
const auth = useAuth()
const host = auth.user
const name = ref('')
const slug = ref('')
@@ -17,7 +20,6 @@ async function submit() {
const created = await useApi<{ id: string }>('/events', {
method: 'POST',
body: {
host_id: host.value.id,
name: name.value,
slug: slug.value,
event_date: new Date(eventDate.value).toISOString(),
+15 -43
View File
@@ -1,4 +1,6 @@
<script setup lang="ts">
definePageMeta({ middleware: ['auth'] })
interface EventSummary {
id: string
name: string
@@ -13,40 +15,28 @@ interface EventsResponse {
events: EventSummary[]
}
const { host, bootstrap } = useHost()
const email = ref('')
const name = ref('')
const bootstrapping = ref(false)
const bootstrapError = ref<string | null>(null)
async function onBootstrap() {
bootstrapError.value = null
bootstrapping.value = true
try {
await bootstrap(email.value, name.value)
} catch (e: any) {
bootstrapError.value = e?.data?.error || e?.message || 'Failed to bootstrap'
} finally {
bootstrapping.value = false
}
}
const auth = useAuth()
const events = ref<EventSummary[]>([])
const loadingEvents = ref(false)
const loadError = ref<string | null>(null)
async function loadEvents() {
if (!host.value) return
if (!auth.user.value) return
loadingEvents.value = true
loadError.value = null
try {
const res = await useApi<EventsResponse>('/events', { query: { host_id: host.value.id } })
// host is derived server-side from the session — no query param needed.
const res = await useApi<EventsResponse>('/events')
events.value = res.events
} catch (e: any) {
loadError.value = e?.data?.error || e?.message || 'Failed to load events'
} finally {
loadingEvents.value = false
}
}
watch(host, loadEvents, { immediate: true })
watch(() => auth.user.value, loadEvents, { immediate: true })
function fmtDate(iso: string) {
try { return new Date(iso).toLocaleString() } catch { return iso }
@@ -55,40 +45,22 @@ function fmtDate(iso: string) {
<template>
<section>
<!--
The dashboard is auth-gated by a localStorage-backed host. Rendering
that conditional on the server (where there's no localStorage) and
then again on the client (where there is) causes a hydration
mismatch that leaves the layout stuck at the bootstrap card's width
after a hard refresh. Skipping SSR for this block fixes both the
flash and the layout shrink.
-->
<ClientOnly>
<div v-if="!host" class="card max-w-md">
<h1 class="mb-2 text-xl font-semibold">Get started</h1>
<p class="mb-4 text-sm text-zinc-400">
Demo bootstrap enter an email + name to provision a host. We don't store passwords.
</p>
<label class="label">Email</label>
<input v-model="email" type="email" class="input mb-3" placeholder="you@example.com" />
<label class="label">Name</label>
<input v-model="name" type="text" class="input mb-4" placeholder="Your name" />
<button class="btn-primary w-full" :disabled="bootstrapping || !email || !name" @click="onBootstrap">
{{ bootstrapping ? 'Setting up' : 'Continue' }}
</button>
<p v-if="bootstrapError" class="mt-3 text-sm text-red-400">{{ bootstrapError }}</p>
<div v-if="!auth.bootstrapped.value || !auth.user.value" class="text-sm text-zinc-500">
Loading dashboard
</div>
<div v-else>
<div class="mb-6 flex items-center justify-between">
<div>
<h1 class="text-2xl font-semibold">Your events</h1>
<p class="text-sm text-zinc-400">Signed in as {{ host.name }} ({{ host.email }})</p>
<p class="text-sm text-zinc-400">Signed in as {{ auth.user.value.name }} ({{ auth.user.value.email }})</p>
</div>
<NuxtLink to="/dashboard/events/new" class="btn-primary">New event</NuxtLink>
</div>
<div v-if="loadingEvents" class="text-sm text-zinc-500">Loading</div>
<div v-else-if="loadError" class="card text-sm text-red-400">{{ loadError }}</div>
<div v-else-if="events.length === 0" class="card text-sm text-zinc-400">
No events yet. Create one to get started.
</div>
+49
View File
@@ -0,0 +1,49 @@
<script setup lang="ts">
const auth = useAuth()
const email = ref('')
const submitting = ref(false)
const sent = ref(false)
const error = ref<string | null>(null)
async function submit() {
error.value = null
submitting.value = true
try {
await auth.forgotPassword(email.value)
sent.value = true
} catch (e: any) {
error.value = useErrMessage(e, 'Request failed')
} finally {
submitting.value = false
}
}
</script>
<template>
<section class="mx-auto max-w-md py-12">
<h1 class="mb-2 text-2xl font-semibold">Forgot your password?</h1>
<p class="mb-6 text-sm text-zinc-400">
Enter your email and we'll send a reset link if there's an account on file.
</p>
<div v-if="sent" class="card text-sm">
<p class="mb-2 font-medium text-brand-300">Check your inbox.</p>
<p class="text-zinc-400">
If <span class="text-zinc-200">{{ email }}</span> is registered, a reset link is on its way.
</p>
<NuxtLink to="/login" class="btn-ghost mt-4 w-full">Back to sign in</NuxtLink>
</div>
<form v-else class="card space-y-4" @submit.prevent="submit">
<div>
<label class="label">Email</label>
<input v-model="email" type="email" class="input" autocomplete="email" required />
</div>
<button class="btn-primary w-full" :disabled="submitting || !email">
{{ submitting ? 'Sending…' : 'Send reset link' }}
</button>
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
</form>
</section>
</template>
+53
View File
@@ -0,0 +1,53 @@
<script setup lang="ts">
const auth = useAuth()
const route = useRoute()
const email = ref('')
const password = ref('')
const submitting = ref(false)
const error = ref<string | null>(null)
async function submit() {
error.value = null
submitting.value = true
try {
await auth.login(email.value, password.value)
const redirect = (route.query.redirect as string) || '/dashboard'
await navigateTo(redirect)
} catch (e: any) {
error.value = useErrMessage(e, 'Login failed')
} finally {
submitting.value = false
}
}
</script>
<template>
<section class="mx-auto max-w-md py-12">
<h1 class="mb-2 text-2xl font-semibold">Sign in</h1>
<p class="mb-6 text-sm text-zinc-400">Welcome back. Sign in to manage your events.</p>
<form class="card space-y-4" @submit.prevent="submit">
<div>
<label class="label">Email</label>
<input v-model="email" type="email" class="input" autocomplete="email" required />
</div>
<div>
<label class="label">Password</label>
<input v-model="password" type="password" class="input" autocomplete="current-password" required />
<div class="mt-1 text-right text-xs">
<NuxtLink to="/forgot-password" class="text-zinc-400 hover:text-zinc-200">Forgot password?</NuxtLink>
</div>
</div>
<button class="btn-primary w-full" :disabled="submitting || !email || !password">
{{ submitting ? 'Signing in…' : 'Sign in' }}
</button>
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
</form>
<p class="mt-6 text-center text-sm text-zinc-400">
Don't have an account?
<NuxtLink to="/signup" class="text-brand-400 hover:text-brand-300">Sign up</NuxtLink>
</p>
</section>
</template>
+51
View File
@@ -0,0 +1,51 @@
<script setup lang="ts">
useHead({ title: 'Privacy policy · GuestGuard' })
</script>
<template>
<article class="prose prose-invert mx-auto max-w-2xl py-8 text-zinc-300">
<h1 class="text-2xl font-semibold text-zinc-100">Privacy policy</h1>
<p class="text-sm text-zinc-500">Last updated: <strong>placeholder pending legal review</strong></p>
<p>
This page is a placeholder while we have proper privacy copy
reviewed by a lawyer. The substance below reflects how the
product actually handles data today the final wording will
replace this page before public launch.
</p>
<h2 class="mt-6 text-lg font-semibold text-zinc-100">What we collect</h2>
<ul class="list-disc pl-6">
<li><strong>Host account</strong>: email, name, hashed password.</li>
<li><strong>Guest list</strong>: names, emails, phone numbers only what
the host enters or imports.</li>
<li><strong>RSVP responses</strong>: the answer the guest sends back.</li>
<li><strong>Access logs</strong>: IP, device fingerprint, and a fraud
risk score, for each invitation-link open. Used to flag suspicious
access (e.g. someone forwarded the link).</li>
<li><strong>Billing</strong>: handled by Stripe we store only the
Stripe customer + subscription IDs locally, never card details.</li>
</ul>
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Your rights</h2>
<ul class="list-disc pl-6">
<li><strong>Export</strong>: download a full JSON dump of your data via
Settings "Export my data".</li>
<li><strong>Delete</strong>: delete your account via Settings "Delete
account". Your row is soft-deleted immediately and hard-deleted 30
days later (kept briefly in case of accidental clicks).</li>
<li><strong>Question</strong>: email
<a href="mailto:privacy@gg.k4scloud.com" class="text-brand-400">privacy@gg.k4scloud.com</a>.</li>
</ul>
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Sub-processors</h2>
<p>We use these third parties to run the service:</p>
<ul class="list-disc pl-6">
<li><strong>Stripe</strong> billing</li>
<li><strong>Resend</strong> transactional email delivery</li>
<li><strong>AWS S3</strong> encrypted database backups</li>
</ul>
<NuxtLink to="/" class="mt-8 inline-block text-sm text-brand-400 hover:text-brand-300"> Back home</NuxtLink>
</article>
</template>
+73
View File
@@ -0,0 +1,73 @@
<script setup lang="ts">
const auth = useAuth()
const route = useRoute()
const password = ref('')
const confirm = ref('')
const submitting = ref(false)
const done = ref(false)
const error = ref<string | null>(null)
const token = computed(() => String(route.params.token || ''))
async function submit() {
error.value = null
if (password.value !== confirm.value) {
error.value = 'Passwords do not match.'
return
}
submitting.value = true
try {
await auth.resetPassword(token.value, password.value)
done.value = true
} catch (e: any) {
error.value = e?.data?.error || e?.message || 'Reset failed'
} finally {
submitting.value = false
}
}
</script>
<template>
<section class="mx-auto max-w-md py-12">
<h1 class="mb-6 text-2xl font-semibold">Choose a new password</h1>
<div v-if="done" class="card text-sm">
<p class="mb-2 font-medium text-brand-300">Password updated.</p>
<p class="mb-4 text-zinc-400">All previous sessions have been signed out. Sign in to continue.</p>
<NuxtLink to="/login" class="btn-primary w-full">Sign in</NuxtLink>
</div>
<form v-else class="card space-y-4" @submit.prevent="submit">
<div>
<label class="label">New password</label>
<input
v-model="password"
type="password"
class="input"
autocomplete="new-password"
minlength="8"
maxlength="72"
required
/>
<p class="mt-1 text-xs text-zinc-500">At least 8 characters.</p>
</div>
<div>
<label class="label">Confirm password</label>
<input
v-model="confirm"
type="password"
class="input"
autocomplete="new-password"
minlength="8"
maxlength="72"
required
/>
</div>
<button class="btn-primary w-full" :disabled="submitting || password.length < 8 || !confirm">
{{ submitting ? 'Updating…' : 'Update password' }}
</button>
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
</form>
</section>
</template>
+90
View File
@@ -0,0 +1,90 @@
<script setup lang="ts">
const auth = useAuth()
const email = ref('')
const name = ref('')
const password = ref('')
const acceptTerms = ref(false)
const submitting = ref(false)
const error = ref<string | null>(null)
const sent = ref(false)
async function submit() {
error.value = null
submitting.value = true
try {
await auth.signup(email.value, name.value, password.value, acceptTerms.value)
sent.value = true
} catch (e: any) {
error.value = useErrMessage(e, 'Sign-up failed')
} finally {
submitting.value = false
}
}
</script>
<template>
<section class="mx-auto max-w-md py-12">
<h1 class="mb-2 text-2xl font-semibold">Create your account</h1>
<p class="mb-6 text-sm text-zinc-400">Start managing your event guest lists in minutes.</p>
<div v-if="sent" class="card text-sm">
<p class="mb-2 font-medium text-brand-300">Check your inbox.</p>
<p class="text-zinc-400">
If <span class="text-zinc-200">{{ email }}</span> is reachable, we've sent a verification link.
Click it to finish setting up your account.
</p>
<NuxtLink to="/login" class="btn-ghost mt-4 w-full">Back to sign in</NuxtLink>
</div>
<form v-else class="card space-y-4" @submit.prevent="submit">
<div>
<label class="label">Name</label>
<input v-model="name" type="text" class="input" autocomplete="name" required />
</div>
<div>
<label class="label">Email</label>
<input v-model="email" type="email" class="input" autocomplete="email" required />
</div>
<div>
<label class="label">Password</label>
<input
v-model="password"
type="password"
class="input"
autocomplete="new-password"
minlength="8"
maxlength="72"
required
/>
<p class="mt-1 text-xs text-zinc-500">At least 8 characters.</p>
</div>
<label class="flex cursor-pointer items-start gap-2 text-xs text-zinc-400">
<input
v-model="acceptTerms"
type="checkbox"
class="mt-0.5 h-4 w-4 cursor-pointer accent-brand-500"
required
/>
<span>
I agree to GuestGuard's
<NuxtLink to="/terms" target="_blank" class="text-brand-400 hover:text-brand-300">Terms of Service</NuxtLink>
and
<NuxtLink to="/privacy" target="_blank" class="text-brand-400 hover:text-brand-300">Privacy Policy</NuxtLink>.
</span>
</label>
<button
class="btn-primary w-full"
:disabled="submitting || !email || !name || password.length < 8 || !acceptTerms"
>
{{ submitting ? 'Creating…' : 'Create account' }}
</button>
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
</form>
<p class="mt-6 text-center text-sm text-zinc-400">
Already have an account?
<NuxtLink to="/login" class="text-brand-400 hover:text-brand-300">Sign in</NuxtLink>
</p>
</section>
</template>
+50
View File
@@ -0,0 +1,50 @@
<script setup lang="ts">
useHead({ title: 'Terms of service · GuestGuard' })
</script>
<template>
<article class="prose prose-invert mx-auto max-w-2xl py-8 text-zinc-300">
<h1 class="text-2xl font-semibold text-zinc-100">Terms of service</h1>
<p class="text-sm text-zinc-500">Last updated: <strong>placeholder pending legal review</strong></p>
<p>
Placeholder copy while a lawyer drafts the real document. The points
below capture the substance of how we'd like the relationship to
work — the binding version will replace this page before public
launch.
</p>
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Summary</h2>
<ul class="list-disc pl-6">
<li>You're responsible for your guest list make sure you have
permission to message the people on it.</li>
<li>We're responsible for keeping the service running, your data
backed up, and our handling of it transparent (see the privacy
page).</li>
<li>Either of us can end the relationship — you by deleting your
account, us with reasonable notice if we have to wind down.</li>
<li>The service is provided "as-is" — we don't promise zero downtime
or that no spam-filter on earth will misfilter your invitations.
We try hard, but we're not insurers.</li>
</ul>
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Acceptable use</h2>
<ul class="list-disc pl-6">
<li>Don't use GuestGuard to send unsolicited bulk mail (spam).</li>
<li>Don't use it to harass, deceive, or impersonate.</li>
<li>Don't poke at the security of the service if you find a bug,
please tell us at
<a href="mailto:security@gg.k4scloud.com" class="text-brand-400">security@gg.k4scloud.com</a>.</li>
</ul>
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Payment</h2>
<p>
Subscriptions auto-renew until cancelled. Cancel any time from
Billing Manage subscription. Refunds: we'll consider them case by
case email
<a href="mailto:support@gg.k4scloud.com" class="text-brand-400">support@gg.k4scloud.com</a>.
</p>
<NuxtLink to="/" class="mt-8 inline-block text-sm text-brand-400 hover:text-brand-300"> Back home</NuxtLink>
</article>
</template>
+69
View File
@@ -0,0 +1,69 @@
<script setup lang="ts">
const route = useRoute()
const config = useRuntimeConfig()
const apiBase = config.public.apiBase as string
const token = computed(() => String(route.params.token || ''))
const status = ref<'loading' | 'ready' | 'done' | 'error'>('loading')
const email = ref('')
const error = ref<string | null>(null)
const submitting = ref(false)
onMounted(async () => {
try {
const r = await $fetch<{ email: string }>(`/unsubscribe/${token.value}`, { baseURL: apiBase })
email.value = r.email
status.value = 'ready'
} catch (e: any) {
status.value = 'error'
error.value = e?.data?.error || 'This link is invalid or has expired.'
}
})
async function confirm() {
submitting.value = true
error.value = null
try {
await $fetch(`/unsubscribe/${token.value}`, { baseURL: apiBase, method: 'POST' })
status.value = 'done'
} catch (e: any) {
error.value = e?.data?.error || 'Something went wrong — please try again.'
} finally {
submitting.value = false
}
}
</script>
<template>
<section class="mx-auto max-w-md py-12">
<h1 class="mb-6 text-2xl font-semibold">Unsubscribe</h1>
<div v-if="status === 'loading'" class="card text-sm text-zinc-400">Loading</div>
<div v-else-if="status === 'ready'" class="card space-y-4 text-sm">
<p>You'll stop receiving GuestGuard emails sent to:</p>
<p class="font-mono text-zinc-200">{{ email }}</p>
<p class="text-zinc-400">
This includes RSVPs, reminders, and host messages. Account-related emails
(security, password resets) will still reach you.
</p>
<button class="btn-primary w-full" :disabled="submitting" @click="confirm">
{{ submitting ? 'Unsubscribing' : 'Unsubscribe me' }}
</button>
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
</div>
<div v-else-if="status === 'done'" class="card text-sm">
<p class="mb-2 font-medium text-brand-300">Done.</p>
<p class="text-zinc-400">
We've added <span class="text-zinc-200">{{ email }}</span> to our suppression list.
Future emails to that address will be silently dropped.
</p>
</div>
<div v-else class="card text-sm">
<p class="mb-2 font-medium text-red-400">Can't unsubscribe</p>
<p class="text-zinc-400">{{ error }}</p>
</div>
</section>
</template>
+44
View File
@@ -0,0 +1,44 @@
<script setup lang="ts">
const auth = useAuth()
const route = useRoute()
type Status = 'pending' | 'success' | 'error'
const status = ref<Status>('pending')
const error = ref<string | null>(null)
onMounted(async () => {
const token = route.query.token
if (typeof token !== 'string' || !token) {
status.value = 'error'
error.value = 'Missing verification token.'
return
}
try {
await auth.verifyEmail(token)
status.value = 'success'
} catch (e: any) {
status.value = 'error'
error.value = e?.data?.error || e?.message || 'Verification failed.'
}
})
</script>
<template>
<section class="mx-auto max-w-md py-12">
<h1 class="mb-6 text-2xl font-semibold">Verify your email</h1>
<div class="card text-sm">
<p v-if="status === 'pending'" class="text-zinc-400">Verifying</p>
<template v-else-if="status === 'success'">
<p class="mb-3 font-medium text-brand-300">Email verified.</p>
<p class="mb-4 text-zinc-400">You can now sign in to your account.</p>
<NuxtLink to="/login" class="btn-primary w-full">Sign in</NuxtLink>
</template>
<template v-else>
<p class="mb-3 font-medium text-red-400">We couldn't verify that link.</p>
<p class="mb-4 text-zinc-400">{{ error }}</p>
<NuxtLink to="/login" class="btn-ghost w-full">Back to sign in</NuxtLink>
</template>
</div>
</section>
</template>
+26 -2
View File
@@ -4,11 +4,13 @@ go 1.26.2
require (
github.com/coder/websocket v1.8.14
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.9.2
github.com/nats-io/nats.go v1.52.0
github.com/testcontainers/testcontainers-go v0.42.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0
golang.org/x/crypto v0.49.0
google.golang.org/grpc v1.81.0
google.golang.org/protobuf v1.36.11
)
@@ -17,6 +19,22 @@ require (
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/alicebob/miniredis/v2 v2.38.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
github.com/aws/smithy-go v1.25.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
@@ -33,6 +51,7 @@ require (
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
@@ -52,24 +71,29 @@ require (
github.com/nats-io/nuid v1.0.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/redis/go-redis/v9 v9.19.0 // indirect
github.com/shirou/gopsutil/v4 v4.26.3 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/stripe/stripe-go/v82 v82.5.1 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/twilio/twilio-go v1.30.9 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
golang.org/x/crypto v0.49.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/text v0.37.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+73
View File
@@ -6,6 +6,38 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw=
github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4 h1:X/PtmuX/EwPivJ9lHCf3Auo8AktdNc4a9ury4zmGPC4=
github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4/go.mod h1:l5cTwZSX9kzxDHz9IpgZC0XIJ/cc43JL6hZzCd0iTwI=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
@@ -44,6 +76,10 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -101,10 +137,14 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k=
github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/shirou/gopsutil/v4 v4.26.3 h1:2ESdQt90yU3oXF/CdOlRCJxrP+Am1aBYubTMTfxJ1qc=
@@ -118,6 +158,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stripe/stripe-go/v82 v82.5.1 h1:05q6ZDKoe8PLMpQV072obF74HCgP4XJeJYoNuRSX2+8=
github.com/stripe/stripe-go/v82 v82.5.1/go.mod h1:majCQX6AfObAvJiHraPi/5udwHi4ojRvJnnxckvHrX8=
github.com/testcontainers/testcontainers-go v0.42.0 h1:He3IhTzTZOygSXLJPMX7n44XtK+qhjat1nI9cneBbUY=
github.com/testcontainers/testcontainers-go v0.42.0/go.mod h1:vZjdY1YmUA1qEForxOIOazfsrdyORJAbhi0bp8plN30=
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 h1:GCbb1ndrF7OTDiIvxXyItaDab4qkzTFJ48LKFdM7EIo=
@@ -126,6 +168,11 @@ github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/twilio/twilio-go v1.30.9 h1:4W4GEV2q0sLQ9xsr1N/97JQlt0c82hZ0ij4qTErstv8=
github.com/twilio/twilio-go v1.30.9/go.mod h1:QbitvbvtkV77Jn4BABAKVmxabYSjMyQG4tHey9gfPqg=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
@@ -142,22 +189,48 @@ go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfC
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
+5 -8
View File
@@ -1,12 +1,10 @@
package api
import (
"errors"
"net/http"
"sort"
"time"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
@@ -42,16 +40,15 @@ type activityItem struct {
// for an event, sorted newest first. Frontends use this on dashboard mount
// to backfill the live monitor with history.
func (h *activityHandler) list(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, err := h.events.Get(r.Context(), eventID); err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load event")
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
+557
View File
@@ -0,0 +1,557 @@
package api
import (
"context"
"encoding/json"
"errors"
"log/slog"
"net/http"
"net/mail"
"net/url"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/auth"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/ratelimit"
"github.com/alchemistkay/guestguard/internal/storage"
)
const refreshCookieName = "gg_refresh"
type authHandler struct {
logger *slog.Logger
users *storage.UserRepo
verifications *storage.EmailVerificationRepo
resets *storage.PasswordResetRepo
refreshes *storage.RefreshTokenRepo
hasher *auth.PasswordHasher
signer *auth.JWTSigner
emails auth.EmailSender
lockout *auth.LockoutTracker
limiter *ratelimit.Limiter
publicBaseURL string
emailVerificationTTL time.Duration
passwordResetTTL time.Duration
refreshTTL time.Duration
cookieDomain string
cookieSecure bool
}
type authHandlerDeps struct {
Logger *slog.Logger
Users *storage.UserRepo
Verifications *storage.EmailVerificationRepo
Resets *storage.PasswordResetRepo
Refreshes *storage.RefreshTokenRepo
Hasher *auth.PasswordHasher
Signer *auth.JWTSigner
Emails auth.EmailSender
Lockout *auth.LockoutTracker
Limiter *ratelimit.Limiter
PublicBaseURL string
EmailVerificationTTL time.Duration
PasswordResetTTL time.Duration
RefreshTTL time.Duration
CookieDomain string
CookieSecure bool
}
func newAuthHandler(d authHandlerDeps) *authHandler {
return &authHandler{
logger: d.Logger,
users: d.Users,
verifications: d.Verifications,
resets: d.Resets,
refreshes: d.Refreshes,
hasher: d.Hasher,
signer: d.Signer,
emails: d.Emails,
lockout: d.Lockout,
limiter: d.Limiter,
publicBaseURL: strings.TrimRight(d.PublicBaseURL, "/"),
emailVerificationTTL: d.EmailVerificationTTL,
passwordResetTTL: d.PasswordResetTTL,
refreshTTL: d.RefreshTTL,
cookieDomain: d.CookieDomain,
cookieSecure: d.CookieSecure,
}
}
// --- request/response types ---
type signupRequest struct {
Email string `json:"email"`
Name string `json:"name"`
Password string `json:"password"`
AcceptTerms bool `json:"accept_terms"`
}
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type verifyEmailRequest struct {
Token string `json:"token"`
}
type forgotPasswordRequest struct {
Email string `json:"email"`
}
type resetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type authSuccess struct {
AccessToken string `json:"access_token"`
ExpiresAt time.Time `json:"expires_at"`
User *domain.User `json:"user"`
}
// --- handlers ---
// POST /auth/signup
func (h *authHandler) signup(w http.ResponseWriter, r *http.Request) {
var req signupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if _, err := mail.ParseAddress(req.Email); err != nil {
writeError(w, http.StatusBadRequest, "email is invalid")
return
}
if strings.TrimSpace(req.Name) == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
hash, err := h.hasher.Hash(req.Password)
if err != nil {
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
writeError(w, http.StatusBadRequest, err.Error())
return
}
h.logger.Error("hash password", "err", err)
writeError(w, http.StatusInternalServerError, "failed to create user")
return
}
u, err := h.users.Create(r.Context(), storage.CreateUserParams{
Email: req.Email,
Name: req.Name,
PasswordHash: hash,
AcceptTerms: req.AcceptTerms,
})
if err != nil {
if errors.Is(err, domain.ErrEmailTaken) {
// Don't leak which addresses are registered. Still return 201 and
// trigger a "if-you-already-have-an-account" email asynchronously
// (skipped for the stub). On real auth this should send a "you
// tried to sign up again, here's a reset link" email.
h.logger.Info("signup attempted with existing email", "email", req.Email)
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
return
}
h.logger.Error("create user", "err", err)
writeError(w, http.StatusInternalServerError, "failed to create user")
return
}
if err := h.sendVerificationEmail(r.Context(), u); err != nil {
h.logger.Error("send verification email", "err", err, "user_id", u.ID)
// Don't fail the signup — user can request a resend.
}
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
}
// POST /auth/login
func (h *authHandler) login(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if req.Email == "" || req.Password == "" {
writeError(w, http.StatusBadRequest, "email and password required")
return
}
// Per-(IP + email) sliding-window — 10 per 5 minutes per the plan.
if !h.checkRate(w, r, "login", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
10, 5*time.Minute) {
return
}
u, err := h.users.GetByEmail(r.Context(), req.Email)
if err != nil || u.PasswordHash == "" {
_, _ = h.lockout.RecordFailure(r.Context(), req.Email, nil)
writeError(w, http.StatusUnauthorized, "invalid email or password")
return
}
// If the account is already locked, reject before doing a bcrypt compare.
locked, _ := h.lockout.IsLocked(r.Context(), u.ID)
if locked {
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
return
}
if err := h.hasher.Verify(u.PasswordHash, req.Password); err != nil {
locked, _ := h.lockout.RecordFailure(r.Context(), req.Email, &u.ID)
if locked {
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
return
}
writeError(w, http.StatusUnauthorized, "invalid email or password")
return
}
if !u.EmailVerified {
writeError(w, http.StatusForbidden, "email not verified")
return
}
h.lockout.ClearOnSuccess(r.Context(), req.Email)
if err := h.issueSession(w, r, u); err != nil {
h.logger.Error("issue session", "err", err, "user_id", u.ID)
writeError(w, http.StatusInternalServerError, "failed to start session")
return
}
}
// checkRate consults the limiter (when one is configured) and writes a 429
// response if the budget is exhausted. Returns false if the caller should
// stop handling the request.
func (h *authHandler) checkRate(w http.ResponseWriter, r *http.Request, name, key string, limit int, window time.Duration) bool {
if h.limiter == nil || key == "" {
return true
}
res, err := h.limiter.Allow(r.Context(), name, key, limit, window)
if err != nil {
h.logger.Warn("ratelimit error (failing open)", "rule", name, "err", err)
return true
}
if !res.Allowed {
retry := int(res.RetryAfter.Round(time.Second).Seconds())
if retry < 1 {
retry = 1
}
w.Header().Set("Retry-After", strconv.Itoa(retry))
writeJSON(w, http.StatusTooManyRequests, map[string]any{
"error": "rate limit exceeded",
"retry_after": retry,
})
return false
}
return true
}
// POST /auth/refresh
func (h *authHandler) refresh(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(refreshCookieName)
if err != nil || cookie.Value == "" {
writeError(w, http.StatusUnauthorized, "missing refresh token")
return
}
oldHash := auth.HashOpaque(cookie.Value)
existing, err := h.refreshes.Get(r.Context(), oldHash)
if err != nil {
if errors.Is(err, domain.ErrAuthTokenNotFound) {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "invalid refresh token")
return
}
h.logger.Error("lookup refresh", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
if existing.RevokedAt != nil {
// Replay of a revoked token. Revoke the family.
_ = h.refreshes.RevokeAllForUser(r.Context(), existing.UserID)
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "refresh token reused")
return
}
if time.Now().After(existing.ExpiresAt) {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "refresh token expired")
return
}
u, err := h.users.GetByID(r.Context(), existing.UserID)
if err != nil {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "user not found")
return
}
newRaw, newHash, err := auth.NewOpaqueToken()
if err != nil {
h.logger.Error("mint refresh", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
exp := time.Now().Add(h.refreshTTL)
if err := h.refreshes.Rotate(r.Context(), oldHash, storage.CreateRefreshTokenParams{
Hash: newHash,
UserID: u.ID,
ExpiresAt: exp,
UserAgent: r.UserAgent(),
IPAddress: clientIP(r),
}); err != nil {
if errors.Is(err, domain.ErrRefreshTokenRevoked) {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "refresh token reused")
return
}
h.logger.Error("rotate refresh", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
if err != nil {
h.logger.Error("sign access", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
h.setRefreshCookie(w, newRaw, exp)
writeJSON(w, http.StatusOK, authSuccess{
AccessToken: access,
ExpiresAt: accessExp,
User: u,
})
}
// POST /auth/logout
func (h *authHandler) logout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(refreshCookieName)
if err == nil && cookie.Value != "" {
_ = h.refreshes.Revoke(r.Context(), auth.HashOpaque(cookie.Value))
}
h.clearRefreshCookie(w)
w.WriteHeader(http.StatusNoContent)
}
// POST /auth/verify-email
func (h *authHandler) verifyEmail(w http.ResponseWriter, r *http.Request) {
var req verifyEmailRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
writeError(w, http.StatusBadRequest, "token required")
return
}
uid, err := h.verifications.Consume(r.Context(), auth.HashOpaque(req.Token))
if err != nil {
switch {
case errors.Is(err, domain.ErrAuthTokenNotFound):
writeError(w, http.StatusBadRequest, "invalid token")
case errors.Is(err, domain.ErrAuthTokenConsumed):
writeError(w, http.StatusBadRequest, "token already used")
case errors.Is(err, domain.ErrAuthTokenExpired):
writeError(w, http.StatusBadRequest, "token expired")
default:
h.logger.Error("consume verification", "err", err)
writeError(w, http.StatusInternalServerError, "verification failed")
}
return
}
if err := h.users.MarkEmailVerified(r.Context(), uid); err != nil {
h.logger.Error("mark verified", "err", err, "user_id", uid)
writeError(w, http.StatusInternalServerError, "verification failed")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "verified"})
}
// POST /auth/forgot-password
func (h *authHandler) forgotPassword(w http.ResponseWriter, r *http.Request) {
var req forgotPasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if !h.checkRate(w, r, "forgot_password", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
3, time.Hour) {
return
}
// Always respond 202 to avoid leaking whether the email exists.
defer func() { writeJSON(w, http.StatusAccepted, map[string]string{"status": "if_known_email_sent"}) }()
u, err := h.users.GetByEmail(r.Context(), req.Email)
if err != nil {
return
}
raw, hash, err := auth.NewOpaqueToken()
if err != nil {
h.logger.Error("mint reset", "err", err)
return
}
exp := time.Now().Add(h.passwordResetTTL)
if err := h.resets.Create(r.Context(), u.ID, hash, exp); err != nil {
h.logger.Error("persist reset", "err", err)
return
}
link := h.publicBaseURL + "/reset-password/" + url.PathEscape(raw)
if err := h.emails.SendPasswordReset(r.Context(), u.Email, u.Name, link); err != nil {
h.logger.Error("send reset email", "err", err)
}
}
// POST /auth/reset-password
func (h *authHandler) resetPassword(w http.ResponseWriter, r *http.Request) {
var req resetPasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
writeError(w, http.StatusBadRequest, "token and new_password required")
return
}
newHash, err := h.hasher.Hash(req.NewPassword)
if err != nil {
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
writeError(w, http.StatusBadRequest, err.Error())
return
}
h.logger.Error("hash password", "err", err)
writeError(w, http.StatusInternalServerError, "reset failed")
return
}
uid, err := h.resets.Consume(r.Context(), auth.HashOpaque(req.Token))
if err != nil {
switch {
case errors.Is(err, domain.ErrAuthTokenNotFound):
writeError(w, http.StatusBadRequest, "invalid token")
case errors.Is(err, domain.ErrAuthTokenConsumed):
writeError(w, http.StatusBadRequest, "token already used")
case errors.Is(err, domain.ErrAuthTokenExpired):
writeError(w, http.StatusBadRequest, "token expired")
default:
h.logger.Error("consume reset", "err", err)
writeError(w, http.StatusInternalServerError, "reset failed")
}
return
}
if err := h.users.UpdatePasswordHash(r.Context(), uid, newHash); err != nil {
h.logger.Error("update password", "err", err, "user_id", uid)
writeError(w, http.StatusInternalServerError, "reset failed")
return
}
// Invalidate all existing sessions.
_ = h.refreshes.RevokeAllForUser(r.Context(), uid)
// Resetting the password is the canonical "unlock" path for the
// account lockout that triggers after repeated bad-credential attempts.
if u, err := h.users.GetByID(r.Context(), uid); err == nil {
_ = h.lockout.ClearForUser(r.Context(), uid, u.Email)
}
writeJSON(w, http.StatusOK, map[string]string{"status": "password_reset"})
}
// --- helpers ---
func (h *authHandler) sendVerificationEmail(ctx context.Context, u *domain.User) error {
raw, hash, err := auth.NewOpaqueToken()
if err != nil {
return err
}
if err := h.verifications.Create(ctx, u.ID, hash, time.Now().Add(h.emailVerificationTTL)); err != nil {
return err
}
link := h.publicBaseURL + "/verify-email?token=" + url.QueryEscape(raw)
return h.emails.SendVerification(ctx, u.Email, u.Name, link)
}
func (h *authHandler) issueSession(w http.ResponseWriter, r *http.Request, u *domain.User) error {
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
if err != nil {
return err
}
raw, hash, err := auth.NewOpaqueToken()
if err != nil {
return err
}
refreshExp := time.Now().Add(h.refreshTTL)
if err := h.refreshes.Create(r.Context(), storage.CreateRefreshTokenParams{
Hash: hash,
UserID: u.ID,
ExpiresAt: refreshExp,
UserAgent: r.UserAgent(),
IPAddress: clientIP(r),
}); err != nil {
return err
}
h.setRefreshCookie(w, raw, refreshExp)
writeJSON(w, http.StatusOK, authSuccess{
AccessToken: access,
ExpiresAt: accessExp,
User: u,
})
return nil
}
func (h *authHandler) setRefreshCookie(w http.ResponseWriter, value string, expires time.Time) {
http.SetCookie(w, &http.Cookie{
Name: refreshCookieName,
Value: value,
Path: "/auth",
Domain: h.cookieDomain,
Expires: expires,
MaxAge: int(time.Until(expires).Seconds()),
HttpOnly: true,
Secure: h.cookieSecure,
SameSite: http.SameSiteLaxMode,
})
}
func (h *authHandler) clearRefreshCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: refreshCookieName,
Value: "",
Path: "/auth",
Domain: h.cookieDomain,
MaxAge: -1,
HttpOnly: true,
Secure: h.cookieSecure,
SameSite: http.SameSiteLaxMode,
})
}
// --- requireAuth middleware ---
type ctxKey int
const userIDCtxKey ctxKey = iota
func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(userIDCtxKey).(uuid.UUID)
return v, ok
}
func requireAuth(signer *auth.JWTSigner) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := r.Header.Get("Authorization")
if !strings.HasPrefix(h, "Bearer ") {
writeError(w, http.StatusUnauthorized, "missing bearer token")
return
}
raw := strings.TrimSpace(strings.TrimPrefix(h, "Bearer "))
claims, err := signer.Parse(raw)
if err != nil {
if errors.Is(err, auth.ErrExpiredJWT) {
writeError(w, http.StatusUnauthorized, "token expired")
return
}
writeError(w, http.StatusUnauthorized, "invalid token")
return
}
ctx := context.WithValue(r.Context(), userIDCtxKey, claims.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
+44
View File
@@ -0,0 +1,44 @@
package api
import (
"errors"
"net/http"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
// hostFromContext returns the authed user's id, or writes 401 and returns
// false. Used by host-facing handlers as the first line in the function.
func hostFromContext(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
uid, ok := UserIDFromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthenticated")
return uuid.Nil, false
}
return uid, true
}
// requireEventOwner fetches the event and confirms the authed user owns it.
// On mismatch (or missing event) it returns 404 — never 403 — so a cross-
// tenant probe cannot tell the difference between "event doesn't exist" and
// "exists but belongs to someone else".
func requireEventOwner(
w http.ResponseWriter,
r *http.Request,
events *storage.EventRepo,
eventID, hostID uuid.UUID,
) (*domain.Event, bool) {
ev, err := events.GetForHost(r.Context(), eventID, hostID)
if err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return nil, false
}
writeError(w, http.StatusInternalServerError, "failed to load event")
return nil, false
}
return ev, true
}
+206
View File
@@ -0,0 +1,206 @@
package api
import (
"encoding/json"
"errors"
"log/slog"
"net/http"
"strings"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/storage"
)
type billingHandler struct {
logger *slog.Logger
stripe *billing.Client
users *storage.UserRepo
subscriptions *storage.SubscriptionRepo
publicBaseURL string
}
type checkoutSessionRequest struct {
Tier string `json:"tier"`
}
type checkoutSessionResponse struct {
URL string `json:"url"`
}
// POST /billing/checkout-session — returns the Stripe Checkout URL the
// frontend redirects the host to. Mints a Stripe customer on first use
// and persists it so repeat calls reuse the same customer.
func (h *billingHandler) checkoutSession(w http.ResponseWriter, r *http.Request) {
if !h.stripeEnabled(w) {
return
}
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
var req checkoutSessionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
tier := billing.Tier(strings.ToLower(req.Tier))
if tier != billing.TierPro && tier != billing.TierBusiness {
writeError(w, http.StatusBadRequest, "tier must be 'pro' or 'business'")
return
}
price, err := h.stripe.PriceFor(tier)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "this tier is not configured yet — contact support")
return
}
user, err := h.users.GetByID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
existingCustomerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load billing record")
return
}
customerID, err := h.stripe.CreateOrGetCustomer(hostID.String(), user.Email, user.Name, existingCustomerID)
if err != nil {
h.logger.Error("stripe customer", "err", err)
writeError(w, http.StatusBadGateway, "stripe customer error")
return
}
if existingCustomerID == "" {
// First time — write a placeholder row so the customer id sticks.
if _, err := h.subscriptions.Upsert(r.Context(), storage.UpsertParams{
UserID: hostID,
StripeCustomerID: customerID,
}); err != nil {
h.logger.Error("upsert sub placeholder", "err", err)
}
}
base := strings.TrimRight(h.publicBaseURL, "/")
url, err := h.stripe.CreateCheckoutSession(billing.CheckoutSessionParams{
CustomerID: customerID,
PriceID: price,
SuccessURL: base + "/dashboard?billing=success",
CancelURL: base + "/dashboard?billing=cancelled",
})
if err != nil {
h.logger.Error("stripe checkout session", "err", err)
writeError(w, http.StatusBadGateway, "stripe checkout error")
return
}
writeJSON(w, http.StatusOK, checkoutSessionResponse{URL: url})
}
type portalSessionResponse struct {
URL string `json:"url"`
}
// POST /billing/portal — returns the customer portal URL so the user
// can manage their payment method, view invoices, or cancel. 404 when
// the user has no Stripe customer yet (they're still on free).
func (h *billingHandler) portalSession(w http.ResponseWriter, r *http.Request) {
if !h.stripeEnabled(w) {
return
}
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
customerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load billing record")
return
}
if customerID == "" {
writeError(w, http.StatusNotFound, "no billing account yet — subscribe first")
return
}
url, err := h.stripe.CreatePortalSession(customerID, strings.TrimRight(h.publicBaseURL, "/")+"/dashboard")
if err != nil {
h.logger.Error("stripe portal", "err", err)
writeError(w, http.StatusBadGateway, "stripe portal error")
return
}
writeJSON(w, http.StatusOK, portalSessionResponse{URL: url})
}
type subscriptionStatusResponse struct {
Tier string `json:"tier"`
Status string `json:"status"`
CurrentPeriodEnd string `json:"current_period_end,omitempty"`
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
Limits struct {
EventsPerMonth int `json:"events_per_month"`
GuestsPerEvent int `json:"guests_per_event"`
} `json:"limits"`
Usage struct {
EventsThisMonth int `json:"events_this_month"`
} `json:"usage"`
PortalAvailable bool `json:"portal_available"`
}
// GET /billing/status — returns the host's current tier + limits +
// usage. The frontend uses this to render the billing page and the
// 402-modal copy ("you used X of Y events this month").
func (h *billingHandler) status(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
tier := billing.TierFree
status := "active"
var periodEnd string
cancelAtPeriodEnd := false
portalAvailable := false
if h.subscriptions != nil {
sub, err := h.subscriptions.GetActiveByUser(r.Context(), hostID)
switch {
case err == nil:
tier = billing.Tier(sub.Tier)
status = sub.Status
cancelAtPeriodEnd = sub.CancelAtPeriodEnd
if sub.CurrentPeriodEnd != nil {
periodEnd = sub.CurrentPeriodEnd.Format("2006-01-02T15:04:05Z")
}
portalAvailable = sub.StripeCustomerID != ""
case errors.Is(err, storage.ErrSubscriptionNotFound):
// Free tier — leave defaults.
default:
writeError(w, http.StatusInternalServerError, "failed to load subscription")
return
}
}
limits := billing.LimitsFor(tier)
events, _ := h.subscriptions.CountEventsInCurrentMonth(r.Context(), hostID)
resp := subscriptionStatusResponse{
Tier: string(tier),
Status: status,
CurrentPeriodEnd: periodEnd,
CancelAtPeriodEnd: cancelAtPeriodEnd,
PortalAvailable: portalAvailable,
}
resp.Limits.EventsPerMonth = limits.EventsPerMonth
resp.Limits.GuestsPerEvent = limits.GuestsPerEvent
resp.Usage.EventsThisMonth = events
writeJSON(w, http.StatusOK, resp)
}
// stripeEnabled returns true if the billing client is configured, else
// writes 503 and returns false. The /billing/status path skips this so
// the frontend can render a "free tier" page in dev environments.
func (h *billingHandler) stripeEnabled(w http.ResponseWriter) bool {
if h.stripe == nil || !h.stripe.Enabled() {
writeError(w, http.StatusServiceUnavailable, "billing is not configured on this instance")
return false
}
return true
}
+157
View File
@@ -0,0 +1,157 @@
package api
import (
"context"
"errors"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/storage"
)
// tierEnforcer wraps the SubscriptionRepo with policy decisions. Lives
// here (not in storage) because the policy is HTTP-shaped: we map the
// outcome to 402 + an upgrade URL.
type tierEnforcer struct {
subs *storage.SubscriptionRepo
publicBaseURL string
}
func newTierEnforcer(subs *storage.SubscriptionRepo, publicBaseURL string) *tierEnforcer {
return &tierEnforcer{subs: subs, publicBaseURL: publicBaseURL}
}
// currentTier returns the host's effective tier. ErrSubscriptionNotFound
// means "no granting subscription on file" → free. Other DB errors
// bubble up.
func (e *tierEnforcer) currentTier(ctx context.Context, hostID uuid.UUID) (billing.Tier, error) {
if e == nil || e.subs == nil {
return billing.TierFree, nil
}
sub, err := e.subs.GetActiveByUser(ctx, hostID)
if err != nil {
if errors.Is(err, storage.ErrSubscriptionNotFound) {
return billing.TierFree, nil
}
return "", err
}
if !billing.StatusGrantsAccess(sub.Status) {
return billing.TierFree, nil
}
tier := billing.Tier(sub.Tier)
if !tier.Valid() {
return billing.TierFree, nil
}
return tier, nil
}
// allowEventCreate verifies the host's monthly event budget. Returns
// true when the request may proceed. On denial it writes a 402 with the
// upgrade hint and returns false.
func (e *tierEnforcer) allowEventCreate(w http.ResponseWriter, r *http.Request, hostID uuid.UUID) bool {
if e == nil || e.subs == nil {
return true
}
tier, err := e.currentTier(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to check plan")
return false
}
limit := billing.LimitsFor(tier).EventsPerMonth
if limit < 0 {
return true
}
used, err := e.subs.CountEventsInCurrentMonth(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to count events")
return false
}
if used >= limit {
e.writePaymentRequired(w, "events_per_month", tier, used, limit,
"You've reached your monthly event limit on the "+strings.ToUpper(string(tier))+" plan.")
return false
}
return true
}
// allowGuestCreate verifies the per-event guest cap. Same shape as
// allowEventCreate.
func (e *tierEnforcer) allowGuestCreate(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID) bool {
if e == nil || e.subs == nil {
return true
}
tier, err := e.currentTier(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to check plan")
return false
}
limit := billing.LimitsFor(tier).GuestsPerEvent
if limit < 0 {
return true
}
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to count guests")
return false
}
if used >= limit {
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
"This event has reached the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
return false
}
return true
}
// allowGuestImport is the CSV-import variant: check the cap against
// existing + incoming row count up-front, before we start the
// transaction. Dedup may shrink the actual insert count later — that's
// OK, we just stay on the safe side.
func (e *tierEnforcer) allowGuestImport(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID, incoming int) bool {
if e == nil || e.subs == nil || incoming == 0 {
return true
}
tier, err := e.currentTier(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to check plan")
return false
}
limit := billing.LimitsFor(tier).GuestsPerEvent
if limit < 0 {
return true
}
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to count guests")
return false
}
if used+incoming > limit {
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
"This import would exceed the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
return false
}
return true
}
type paymentRequiredBody struct {
Error string `json:"error"`
Reason string `json:"reason"`
Tier string `json:"tier"`
Used int `json:"used"`
Limit int `json:"limit"`
UpgradeURL string `json:"upgrade_url"`
}
func (e *tierEnforcer) writePaymentRequired(w http.ResponseWriter, reason string, tier billing.Tier, used, limit int, msg string) {
body := paymentRequiredBody{
Error: msg,
Reason: reason,
Tier: string(tier),
Used: used,
Limit: limit,
UpgradeURL: strings.TrimRight(e.publicBaseURL, "/") + "/dashboard/billing",
}
writeJSON(w, http.StatusPaymentRequired, body)
}
+177
View File
@@ -0,0 +1,177 @@
package api
import (
"errors"
"io"
"net/http"
"github.com/alchemistkay/guestguard/internal/csvimport"
"github.com/alchemistkay/guestguard/internal/storage"
)
const (
// 1 MB cap on uploads. With ~200 bytes per row that's ~5,000 guests —
// matches the row cap in csvimport.DefaultMaxRows.
csvMaxBytes = 1 << 20
)
type csvImportHandler struct {
guests *storage.GuestRepo
events *storage.EventRepo
enforcer *tierEnforcer
}
type importResponse struct {
Added int `json:"added"`
Skipped int `json:"skipped"`
SkippedEmails []string `json:"skipped_emails,omitempty"`
Errors []csvimport.RowError `json:"errors,omitempty"`
TotalCount int `json:"total_count"`
}
type previewResponse struct {
Rows []csvimport.Row `json:"rows"`
Errors []csvimport.RowError `json:"errors,omitempty"`
TotalCount int `json:"total_count"`
}
// POST /events/{id}/guests/import/preview — parse + validate but don't write.
// Used by the frontend to show a "is this what you meant?" table before commit.
func (h *csvImportHandler) preview(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
body, ok := readCSVUpload(w, r)
if !ok {
return
}
defer body.Close()
res, err := csvimport.Parse(body, csvimport.Options{})
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
writeJSON(w, http.StatusOK, previewResponse{
Rows: res.Rows,
Errors: res.Errors,
TotalCount: res.TotalCount,
})
}
// POST /events/{id}/guests/import — parse, validate, and commit valid rows
// in a single transaction. Rows with row-level errors are reported back
// but don't prevent the rest from importing.
func (h *csvImportHandler) commit(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
body, ok := readCSVUpload(w, r)
if !ok {
return
}
defer body.Close()
parsed, err := csvimport.Parse(body, csvimport.Options{})
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
// Plan enforcement: prevent an import that, even if perfectly
// dedup-free, would exceed the per-event guest cap. We check upfront
// (current count + parsed-row count) and reject with 402 if it'd
// overflow. False positives on dedup-heavy CSVs are acceptable —
// host can dedupe and re-upload.
if !h.enforcer.allowGuestImport(w, r, hostID, eventID, len(parsed.Rows)) {
return
}
rows := make([]storage.BulkImportRow, 0, len(parsed.Rows))
for _, r := range parsed.Rows {
rows = append(rows, storage.BulkImportRow{
Name: r.Name,
Email: r.Email,
Phone: r.Phone,
PlusOnes: r.PlusOnes,
})
}
res, err := h.guests.BulkImportGuests(r.Context(), eventID, rows)
if err != nil {
writeError(w, http.StatusInternalServerError, "import failed")
return
}
writeJSON(w, http.StatusOK, importResponse{
Added: res.Added,
Skipped: res.Skipped,
SkippedEmails: res.SkippedEmails,
Errors: parsed.Errors,
TotalCount: parsed.TotalCount,
})
}
// GET /events/{id}/guests/import/template — download a sample CSV. Auth is
// applied at the route level; ownership is verified so an attacker can't
// probe the existence of an event by hitting this endpoint.
func (h *csvImportHandler) template(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-import-template.csv"`)
_, _ = w.Write([]byte(csvimport.TemplateCSV))
}
// readCSVUpload returns the multipart file body (capped at csvMaxBytes) or
// writes an error and returns (nil, false). Accepted shapes:
//
// - multipart/form-data with a "file" field (the drag-drop UI uses this)
// - any other Content-Type — the raw body is treated as the CSV (curl-friendly)
func readCSVUpload(w http.ResponseWriter, r *http.Request) (io.ReadCloser, bool) {
r.Body = http.MaxBytesReader(w, r.Body, csvMaxBytes)
if err := r.ParseMultipartForm(csvMaxBytes); err == nil && r.MultipartForm != nil {
files := r.MultipartForm.File["file"]
if len(files) == 0 {
writeError(w, http.StatusBadRequest, `form field "file" is required`)
return nil, false
}
f, err := files[0].Open()
if err != nil {
writeError(w, http.StatusBadRequest, "cannot read uploaded file")
return nil, false
}
return f, true
} else if err != nil && !errors.Is(err, http.ErrNotMultipart) {
writeError(w, http.StatusBadRequest, "invalid multipart body")
return nil, false
}
// Fall through: raw body as CSV.
return r.Body, true
}
+38 -27
View File
@@ -16,10 +16,10 @@ import (
type eventHandler struct {
repo *storage.EventRepo
enforcer *tierEnforcer
}
type createEventRequest struct {
HostID string `json:"host_id"`
Name string `json:"name"`
Slug string `json:"slug"`
EventDate time.Time `json:"event_date"`
@@ -32,6 +32,14 @@ type createEventRequest struct {
var slugRe = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
if !h.enforcer.allowEventCreate(w, r, hostID) {
return
}
var req createEventRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
@@ -51,12 +59,6 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
return
}
hostID, err := uuid.Parse(req.HostID)
if err != nil {
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
return
}
status := domain.EventStatus(req.Status)
if status == "" {
status = domain.EventStatusDraft
@@ -89,37 +91,31 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
}
func (h *eventHandler) get(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
id, ok := parseIDParam(w, r, "id")
if !ok {
return
}
ev, err := h.repo.Get(r.Context(), id)
if err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load event")
ev, ok := requireEventOwner(w, r, h.repo, id, hostID)
if !ok {
return
}
writeJSON(w, http.StatusOK, ev)
}
func (h *eventHandler) list(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
q := r.URL.Query()
limit := atoiOr(q.Get("limit"), 50)
offset := atoiOr(q.Get("offset"), 0)
var hostID uuid.UUID
if v := q.Get("host_id"); v != "" {
parsed, err := uuid.Parse(v)
if err != nil {
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
return
}
hostID = parsed
}
events, err := h.repo.List(r.Context(), hostID, limit, offset)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to list events")
@@ -146,6 +142,10 @@ type updateEventRequest struct {
}
func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
id, ok := parseIDParam(w, r, "id")
if !ok {
return
@@ -180,7 +180,7 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
params.Status = &s
}
ev, err := h.repo.Update(r.Context(), id, params)
ev, err := h.repo.Update(r.Context(), id, hostID, params)
if err != nil {
switch {
case errors.Is(err, domain.ErrEventNotFound):
@@ -196,11 +196,15 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
}
func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
id, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if err := h.repo.Delete(r.Context(), id); err != nil {
if err := h.repo.Delete(r.Context(), id, hostID); err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
@@ -212,7 +216,14 @@ func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
}
func parseIDParam(w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, bool) {
raw := r.PathValue(name)
return parseRawUUID(w, name, r.PathValue(name))
}
func parseRawUUID(w http.ResponseWriter, name, raw string) (uuid.UUID, bool) {
if raw == "" {
writeError(w, http.StatusBadRequest, name+" is required")
return uuid.Nil, false
}
id, err := uuid.Parse(raw)
if err != nil {
writeError(w, http.StatusBadRequest, name+" must be a valid uuid")
+103 -5
View File
@@ -12,6 +12,7 @@ import (
type guestHandler struct {
guests *storage.GuestRepo
events *storage.EventRepo
enforcer *tierEnforcer
}
type createGuestRequest struct {
@@ -24,16 +25,18 @@ type createGuestRequest struct {
}
func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, err := h.events.Get(r.Context(), eventID); err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
writeError(w, http.StatusInternalServerError, "failed to load event")
if !h.enforcer.allowGuestCreate(w, r, hostID, eventID) {
return
}
@@ -67,11 +70,106 @@ func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, g)
}
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
type updateGuestRequest struct {
Name *string `json:"name"`
Email *string `json:"email"`
Phone *string `json:"phone"`
PlusOnes *int `json:"plus_ones"`
}
// PATCH /events/{id}/guests/{guest_id} — patch a guest's contact info.
// Fields omitted from the body are left untouched. Empty strings for
// email/phone clear those columns to NULL.
func (h *guestHandler) update(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
guestID, ok := parseIDParam(w, r, "guest_id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
var req updateGuestRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if req.Name != nil && *req.Name == "" {
writeError(w, http.StatusBadRequest, "name cannot be empty")
return
}
if req.PlusOnes != nil && *req.PlusOnes < 0 {
writeError(w, http.StatusBadRequest, "plus_ones must be >= 0")
return
}
g, err := h.guests.Update(r.Context(), eventID, guestID, storage.UpdateGuestParams{
Name: req.Name,
Email: req.Email,
Phone: req.Phone,
PlusOnes: req.PlusOnes,
})
if err != nil {
if errors.Is(err, domain.ErrGuestNotFound) {
writeError(w, http.StatusNotFound, "guest not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to update guest")
return
}
writeJSON(w, http.StatusOK, g)
}
// DELETE /events/{id}/guests/{guest_id} — remove a guest from an event.
// Cascade-deletes their token, rsvp, access logs, notifications.
func (h *guestHandler) delete(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
guestID, ok := parseIDParam(w, r, "guest_id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
if err := h.guests.Delete(r.Context(), eventID, guestID); err != nil {
if errors.Is(err, domain.ErrGuestNotFound) {
writeError(w, http.StatusNotFound, "guest not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to delete guest")
return
}
w.WriteHeader(http.StatusNoContent)
}
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
q := r.URL.Query()
limit := atoiOr(q.Get("limit"), 100)
offset := atoiOr(q.Get("offset"), 0)
+32
View File
@@ -0,0 +1,32 @@
package api
import (
"net/http"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
type meHandler struct {
users *storage.UserRepo
}
// GET /me — returns the authenticated user. Used by the frontend to bootstrap
// after a page reload (with a fresh access token from /auth/refresh).
func (h *meHandler) get(w http.ResponseWriter, r *http.Request) {
uid, ok := UserIDFromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthenticated")
return
}
u, err := h.users.GetByID(r.Context(), uid)
if err != nil {
if err == domain.ErrUserNotFound {
writeError(w, http.StatusUnauthorized, "user not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
writeJSON(w, http.StatusOK, u)
}
+255
View File
@@ -0,0 +1,255 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
"time"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
// privacyHandler holds the GDPR-style "your data, your choice" endpoints:
// data export, account deletion, and terms-acceptance recording.
type privacyHandler struct {
logger *slog.Logger
users *storage.UserRepo
events *storage.EventRepo
guests *storage.GuestRepo
tokens *storage.TokenRepo
rsvps *storage.RSVPRepo
access *storage.AccessLogRepo
notifs *storage.DB // raw pool access for the export queries
refresh *storage.RefreshTokenRepo
}
// DataExport is the shape of the JSON the host downloads from
// GET /me/data-export. We don't paginate or stream — for the scale
// GuestGuard hosts have, a single response is reasonable. If a host
// ever has 100k+ access logs we'll switch to async + email-a-link.
type DataExport struct {
ExportedAt time.Time `json:"exported_at"`
Format string `json:"format"`
User *domain.User `json:"user"`
Events []*domain.Event `json:"events"`
Guests []*domain.Guest `json:"guests"`
Tokens []exportedToken `json:"tokens"`
RSVPs []exportedRSVP `json:"rsvps"`
AccessLogs []exportedAccess `json:"access_logs"`
Notifs []exportedNotif `json:"notifications"`
}
type exportedToken struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
type exportedRSVP struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
Response string `json:"response"`
PlusOnes int `json:"plus_ones"`
SubmittedAt time.Time `json:"submitted_at"`
}
type exportedAccess struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
RiskScore *int `json:"risk_score,omitempty"`
Flagged bool `json:"flagged"`
CreatedAt time.Time `json:"created_at"`
}
type exportedNotif struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
Channel string `json:"channel"`
Type string `json:"type"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// GET /me/data-export — returns every record the system holds about the
// authenticated user. The Content-Disposition header makes browsers
// offer a download rather than rendering inline.
func (h *privacyHandler) dataExport(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
user, err := h.users.GetByID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
export := DataExport{
ExportedAt: time.Now().UTC(),
Format: "guestguard.v1",
User: user,
}
// Events the user hosts.
events, err := h.events.List(r.Context(), hostID, 1000, 0)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load events")
return
}
export.Events = events
// For each event, pull guests + tokens + rsvps + access_logs + notifications.
for _, ev := range events {
guests, err := h.guests.ListByEvent(r.Context(), ev.ID, 5000, 0)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load guests")
return
}
export.Guests = append(export.Guests, guests...)
for _, g := range guests {
// Token (at most one per guest, but query as a list for symmetry).
if err := h.appendTokens(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: tokens", "err", err)
}
if err := h.appendRSVPs(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: rsvps", "err", err)
}
if err := h.appendAccess(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: access", "err", err)
}
if err := h.appendNotifs(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: notifs", "err", err)
}
}
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-data-export.json"`)
writeJSON(w, http.StatusOK, export)
}
func (h *privacyHandler) appendTokens(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, expires_at, status, created_at
FROM tokens WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var t exportedToken
if err := rows.Scan(&t.ID, &t.GuestID, &t.ExpiresAt, &t.Status, &t.CreatedAt); err != nil {
return err
}
out.Tokens = append(out.Tokens, t)
}
return rows.Err()
}
func (h *privacyHandler) appendRSVPs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, response::text, plus_ones, submitted_at
FROM rsvps WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var r exportedRSVP
if err := rows.Scan(&r.ID, &r.GuestID, &r.Response, &r.PlusOnes, &r.SubmittedAt); err != nil {
return err
}
out.RSVPs = append(out.RSVPs, r)
}
return rows.Err()
}
func (h *privacyHandler) appendAccess(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, risk_score, flagged, created_at
FROM access_logs WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var a exportedAccess
var rs *int
if err := rows.Scan(&a.ID, &a.GuestID, &rs, &a.Flagged, &a.CreatedAt); err != nil {
return err
}
a.RiskScore = rs
out.AccessLogs = append(out.AccessLogs, a)
}
return rows.Err()
}
func (h *privacyHandler) appendNotifs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, channel::text, type::text, status::text, created_at
FROM notifications WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var n exportedNotif
if err := rows.Scan(&n.ID, &n.GuestID, &n.Channel, &n.Type, &n.Status, &n.CreatedAt); err != nil {
return err
}
out.Notifs = append(out.Notifs, n)
}
return rows.Err()
}
// DELETE /me — soft-deletes the host's account. All sessions are
// revoked immediately. A hard delete happens via a separate cron 30
// days later (TBD ops work). The user is logged out from all devices
// as a side effect of revoking the refresh tokens.
func (h *privacyHandler) deleteMe(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
if err := h.users.SoftDelete(r.Context(), hostID); err != nil {
if errors.Is(err, domain.ErrUserNotFound) {
writeError(w, http.StatusNotFound, "user not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to delete account")
return
}
// Best-effort: revoke refresh tokens so other sessions log out too.
// Failure here is logged but doesn't roll back the soft-delete — the
// access tokens (JWT) will still expire on their own ~15 minute TTL.
if err := h.refresh.RevokeAllForUser(r.Context(), hostID); err != nil {
h.logger.Warn("delete-me: revoke refresh tokens", "err", err, "user_id", hostID)
}
w.WriteHeader(http.StatusNoContent)
}
// POST /me/accept-terms — records that the authenticated user accepts
// the current ToS + privacy policy. Idempotent. Used by both the
// onboarding gate (existing accounts created before T&C were enforced)
// and any future "we updated our terms" re-acceptance flow.
func (h *privacyHandler) acceptTerms(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
if err := h.users.AcceptTerms(r.Context(), hostID); err != nil {
if errors.Is(err, domain.ErrUserNotFound) {
writeError(w, http.StatusNotFound, "user not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to record acceptance")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "accepted"})
}
+37
View File
@@ -0,0 +1,37 @@
package api
import (
"net/http"
)
// ipKey is the rate-limit key for endpoints scoped by source IP only
// (e.g. POST /auth/signup). XFF/X-Real-IP are honoured because in the
// homelab the API sits behind Traefik.
func ipKey(r *http.Request) string {
return clientIP(r)
}
// pathKey returns a path-parameter as the rate-limit key — used for the
// token-scoped endpoints so an attacker brute-forcing a single token is
// limited regardless of the IPs they rotate through.
func pathKey(name string) KeyFunc {
return func(r *http.Request) string {
return r.PathValue(name)
}
}
// userIDKey extracts the authenticated user id from the request context.
// Returns "" when the route isn't behind requireAuth, in which case the
// middleware bypasses (fail-open) — the route's own auth layer handles
// rejection.
func userIDKey(r *http.Request) string {
uid, ok := UserIDFromContext(r.Context())
if !ok {
return ""
}
return uid.String()
}
// KeyFunc mirrors ratelimit.KeyFunc so call sites don't have to import the
// inner package.
type KeyFunc = func(r *http.Request) string
+241 -20
View File
@@ -5,7 +5,12 @@ import (
"net/http"
"time"
"github.com/redis/go-redis/v9"
"github.com/alchemistkay/guestguard/internal/auth"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/notification"
"github.com/alchemistkay/guestguard/internal/ratelimit"
"github.com/alchemistkay/guestguard/internal/storage"
)
@@ -13,14 +18,24 @@ type Server struct {
logger *slog.Logger
db *storage.DB
hub *Hub
users *userHandler
authH *authHandler
me *meHandler
events *eventHandler
guests *guestHandler
tokens *tokenHandler
rsvps *rsvpHandler
activity *activityHandler
ws *wsHandler
wsTicket *wsTicketHandler
health *healthHandler
signer *auth.JWTSigner
limiter *ratelimit.Limiter
unsub *unsubscribeHandler
webhooks *webhookHandler
csv *csvImportHandler
billing *billingHandler
stripeWH *stripeWebhookHandler
privacy *privacyHandler
}
type ServerDeps struct {
@@ -29,39 +44,128 @@ type ServerDeps struct {
Hub *Hub
AccessPublisher accessPublisher
RSVPPublisher rsvpPublisher
InvitationPublisher invitationPublisher
FraudScorer fraudScorer
TokenTTL time.Duration
// Auth
JWTSecret string
JWTIssuer string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
EmailVerificationTTL time.Duration
PasswordResetTTL time.Duration
PublicBaseURL string
RefreshCookieDomain string
RefreshCookieSecure bool
EmailSender auth.EmailSender
WSTicketTTL time.Duration
// Rate limiting / abuse controls
Redis *redis.Client
LoginLockoutMax int // failed attempts before account lockout (default 5)
LoginFailWindow time.Duration // counter TTL (default 15 min)
// Notifications / unsubscribe
NotificationRepo *notification.Repo
SuppressionRepo *notification.SuppressionRepo
UnsubscribeSigner *notification.UnsubscribeSigner
// Billing (Block F). Nil StripeClient leaves billing disabled — the
// system still boots and runs, all users sit on the free tier with
// its limits enforced; /billing/* returns 503.
StripeClient *billing.Client
}
func NewServer(deps ServerDeps) *Server {
func NewServer(deps ServerDeps) (*Server, error) {
eventRepo := storage.NewEventRepo(deps.DB)
guestRepo := storage.NewGuestRepo(deps.DB)
tokenRepo := storage.NewTokenRepo(deps.DB)
rsvpRepo := storage.NewRSVPRepo(deps.DB)
accessRepo := storage.NewAccessLogRepo(deps.DB)
userRepo := storage.NewUserRepo(deps.DB)
verifRepo := storage.NewEmailVerificationRepo(deps.DB)
resetRepo := storage.NewPasswordResetRepo(deps.DB)
refreshRepo := storage.NewRefreshTokenRepo(deps.DB)
subRepo := storage.NewSubscriptionRepo(deps.DB)
enforcer := newTierEnforcer(subRepo, deps.PublicBaseURL)
signer, err := auth.NewJWTSigner(deps.JWTSecret, deps.AccessTokenTTL, deps.JWTIssuer)
if err != nil {
return nil, err
}
hasher := auth.NewPasswordHasher()
emails := deps.EmailSender
if emails == nil {
emails = auth.LogEmailSender{Logger: deps.Logger}
}
hub := deps.Hub
if hub == nil {
hub = NewHub(deps.Logger)
}
wsTicketTTL := deps.WSTicketTTL
if wsTicketTTL <= 0 {
wsTicketTTL = 60 * time.Second
}
wsTickets := newWSTicketStore(wsTicketTTL)
var limiter *ratelimit.Limiter
var lockout *auth.LockoutTracker
if deps.Redis != nil {
limiter = ratelimit.New(deps.Redis, "gg:rl")
lockoutMax := deps.LoginLockoutMax
if lockoutMax <= 0 {
lockoutMax = 5
}
failWindow := deps.LoginFailWindow
if failWindow <= 0 {
failWindow = 15 * time.Minute
}
lockout = auth.NewLockoutTracker(deps.Redis, lockoutMax, failWindow)
}
authH := newAuthHandler(authHandlerDeps{
Logger: deps.Logger,
Users: userRepo,
Verifications: verifRepo,
Resets: resetRepo,
Refreshes: refreshRepo,
Hasher: hasher,
Signer: signer,
Emails: emails,
Lockout: lockout,
Limiter: limiter,
PublicBaseURL: deps.PublicBaseURL,
EmailVerificationTTL: deps.EmailVerificationTTL,
PasswordResetTTL: deps.PasswordResetTTL,
RefreshTTL: deps.RefreshTokenTTL,
CookieDomain: deps.RefreshCookieDomain,
CookieSecure: deps.RefreshCookieSecure,
})
return &Server{
logger: deps.Logger,
db: deps.DB,
hub: hub,
users: &userHandler{repo: userRepo},
events: &eventHandler{repo: eventRepo},
guests: &guestHandler{guests: guestRepo, events: eventRepo},
authH: authH,
me: &meHandler{users: userRepo},
events: &eventHandler{repo: eventRepo, enforcer: enforcer},
guests: &guestHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
tokens: &tokenHandler{
logger: deps.Logger,
guests: guestRepo,
tokens: tokenRepo,
events: eventRepo,
users: userRepo,
accessLogs: accessRepo,
gen: auth.NewGenerator(),
ttl: deps.TokenTTL,
pub: deps.AccessPublisher,
invitations: deps.InvitationPublisher,
publicBaseURL: deps.PublicBaseURL,
},
rsvps: &rsvpHandler{
logger: deps.Logger,
@@ -78,9 +182,46 @@ func NewServer(deps ServerDeps) *Server {
rsvps: rsvpRepo,
accessLogs: accessRepo,
},
ws: &wsHandler{logger: deps.Logger, hub: hub},
ws: &wsHandler{logger: deps.Logger, hub: hub, tickets: wsTickets},
wsTicket: &wsTicketHandler{tickets: wsTickets, events: eventRepo},
health: &healthHandler{pool: deps.DB.Pool},
}
signer: signer,
limiter: limiter,
unsub: &unsubscribeHandler{
logger: deps.Logger,
signer: deps.UnsubscribeSigner,
suppress: deps.SuppressionRepo,
},
webhooks: &webhookHandler{
logger: deps.Logger,
notifs: deps.NotificationRepo,
suppress: deps.SuppressionRepo,
},
csv: &csvImportHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
billing: &billingHandler{
logger: deps.Logger,
stripe: deps.StripeClient,
users: userRepo,
subscriptions: subRepo,
publicBaseURL: deps.PublicBaseURL,
},
stripeWH: &stripeWebhookHandler{
logger: deps.Logger,
stripe: deps.StripeClient,
subs: subRepo,
},
privacy: &privacyHandler{
logger: deps.Logger,
users: userRepo,
events: eventRepo,
guests: guestRepo,
tokens: tokenRepo,
rsvps: rsvpRepo,
access: accessRepo,
notifs: deps.DB,
refresh: refreshRepo,
},
}, nil
}
func (s *Server) Hub() *Hub { return s.hub }
@@ -91,25 +232,104 @@ func (s *Server) Handler() http.Handler {
mux.HandleFunc("GET /health", s.health.live)
mux.HandleFunc("GET /health/ready", s.health.ready)
mux.HandleFunc("POST /users", s.users.upsert)
// Per-route rate limiters (no-op when Redis isn't wired).
authed := requireAuth(s.signer)
rl := func(name string, limit int, window time.Duration, keyFn KeyFunc, h http.Handler) http.Handler {
if s.limiter == nil {
return h
}
return s.limiter.Middleware(
ratelimit.Rule{Name: name, Limit: limit, Window: window},
keyFn,
s.logger,
)(h)
}
mux.HandleFunc("POST /events", s.events.create)
mux.HandleFunc("GET /events", s.events.list)
mux.HandleFunc("GET /events/{id}", s.events.get)
mux.HandleFunc("PATCH /events/{id}", s.events.update)
mux.HandleFunc("DELETE /events/{id}", s.events.delete)
// Anonymous auth endpoints — POST /auth/login + /auth/forgot-password
// rate-limit inside the handler (key includes the email body field).
mux.Handle("POST /auth/signup",
rl("auth_signup", 5, time.Hour, ipKey, http.HandlerFunc(s.authH.signup)))
mux.HandleFunc("POST /auth/login", s.authH.login)
mux.HandleFunc("POST /auth/refresh", s.authH.refresh)
mux.HandleFunc("POST /auth/logout", s.authH.logout)
mux.HandleFunc("POST /auth/verify-email", s.authH.verifyEmail)
mux.HandleFunc("POST /auth/forgot-password", s.authH.forgotPassword)
mux.HandleFunc("POST /auth/reset-password", s.authH.resetPassword)
mux.HandleFunc("POST /events/{id}/guests", s.guests.create)
mux.HandleFunc("GET /events/{id}/guests", s.guests.list)
mux.Handle("GET /me", authed(http.HandlerFunc(s.me.get)))
mux.Handle("POST /auth/ws-ticket", authed(http.HandlerFunc(s.wsTicket.issue)))
mux.HandleFunc("GET /events/{id}/activity", s.activity.list)
// Privacy / GDPR-style endpoints — host can export their data,
// delete their account, and record terms acceptance from the
// onboarding gate.
mux.Handle("GET /me/data-export", authed(http.HandlerFunc(s.privacy.dataExport)))
mux.Handle("DELETE /me", authed(http.HandlerFunc(s.privacy.deleteMe)))
mux.Handle("POST /me/accept-terms", authed(http.HandlerFunc(s.privacy.acceptTerms)))
mux.HandleFunc("POST /events/{id}/guests/{guest_id}/tokens", s.tokens.issue)
mux.HandleFunc("GET /access/{token}", s.tokens.access)
mux.HandleFunc("POST /rsvp/{token}", s.rsvps.submit)
// Host-facing event/guest/token writes are limited by user_id.
mux.Handle("POST /events",
authed(rl("events_create", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.events.create))))
mux.Handle("GET /events", authed(http.HandlerFunc(s.events.list)))
mux.Handle("GET /events/{id}", authed(http.HandlerFunc(s.events.get)))
mux.Handle("PATCH /events/{id}", authed(http.HandlerFunc(s.events.update)))
mux.Handle("DELETE /events/{id}", authed(http.HandlerFunc(s.events.delete)))
mux.Handle("POST /events/{id}/guests",
authed(rl("guests_create", 1000, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.create))))
mux.Handle("GET /events/{id}/guests", authed(http.HandlerFunc(s.guests.list)))
mux.Handle("PATCH /events/{id}/guests/{guest_id}",
authed(rl("guests_update", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.update))))
mux.Handle("DELETE /events/{id}/guests/{guest_id}",
authed(rl("guests_delete", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.delete))))
// CSV import (Block E). Preview is cheap (no DB writes), so we keep
// its budget separate from commit's daily-row-add limit.
mux.Handle("POST /events/{id}/guests/import/preview",
authed(rl("guests_import_preview", 30, time.Hour, userIDKey, http.HandlerFunc(s.csv.preview))))
mux.Handle("POST /events/{id}/guests/import",
authed(rl("guests_import_commit", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.csv.commit))))
mux.Handle("GET /events/{id}/guests/import/template", authed(http.HandlerFunc(s.csv.template)))
mux.Handle("GET /events/{id}/activity", authed(http.HandlerFunc(s.activity.list)))
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens",
authed(rl("tokens_issue", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.issue))))
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens/rotate",
authed(rl("tokens_rotate", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.rotate))))
mux.Handle("POST /events/{id}/guests/invitations/bulk",
authed(rl("tokens_bulk", 10, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.bulkIssue))))
// Guest-facing endpoints — rate-limited by the access token in the URL
// path so an attacker hammering a single invitation is slowed regardless
// of their source IP.
mux.Handle("GET /access/{token}",
rl("access", 60, time.Hour, pathKey("token"), http.HandlerFunc(s.tokens.access)))
mux.Handle("POST /rsvp/{token}",
rl("rsvp", 10, time.Hour, pathKey("token"), http.HandlerFunc(s.rsvps.submit)))
// WebSocket endpoint authenticates via single-use ticket on the query
// string (see POST /auth/ws-ticket).
mux.HandleFunc("GET /ws/events/{id}", s.ws.handle)
// Unsubscribe (signed token, no auth required — links live in emails).
mux.HandleFunc("GET /unsubscribe/{token}", s.unsub.preview)
mux.HandleFunc("POST /unsubscribe/{token}", s.unsub.confirm)
// Provider webhooks. Signature verification is enforced in the handler
// once GG_TWILIO_AUTH_TOKEN / GG_SES_WEBHOOK_SECRET are set.
mux.HandleFunc("POST /webhooks/twilio/status", s.webhooks.twilio)
mux.HandleFunc("POST /webhooks/ses/notifications", s.webhooks.ses)
// Billing (Block F). /billing/status is safe for everyone — returns
// free tier defaults when Stripe is unconfigured or the user has no
// subscription, so the frontend's plan page always has something to
// render. The action endpoints (checkout, portal) return 503 in dev
// without Stripe credentials.
mux.Handle("GET /billing/status", authed(http.HandlerFunc(s.billing.status)))
mux.Handle("POST /billing/checkout-session", authed(http.HandlerFunc(s.billing.checkoutSession)))
mux.Handle("POST /billing/portal", authed(http.HandlerFunc(s.billing.portalSession)))
mux.HandleFunc("POST /webhooks/stripe", s.stripeWH.handle)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, "not found")
})
@@ -128,9 +348,10 @@ func corsMiddleware(next http.Handler) http.Handler {
origin := r.Header.Get("Origin")
if origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Device-Fingerprint")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Device-Fingerprint")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
+163
View File
@@ -0,0 +1,163 @@
package api
import (
"encoding/json"
"io"
"log/slog"
"net/http"
"time"
"github.com/stripe/stripe-go/v82"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/storage"
)
// stripeWebhookHandler accepts and verifies Stripe events, then
// projects subscription lifecycle changes onto the subscriptions table.
// We track only what middleware needs to decide access — tier + status +
// period bounds. Invoice events (payment failed / succeeded) are logged
// for observability; dunning automation lands in Block F3.
type stripeWebhookHandler struct {
logger *slog.Logger
stripe *billing.Client
subs *storage.SubscriptionRepo
}
// POST /webhooks/stripe — signature-verified Stripe event sink.
func (h *stripeWebhookHandler) handle(w http.ResponseWriter, r *http.Request) {
if h.stripe == nil || !h.stripe.Enabled() {
// Not configured on this instance — reject so a misrouted event
// isn't silently swallowed. Stripe will retry which is harmless.
w.WriteHeader(http.StatusServiceUnavailable)
return
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
writeError(w, http.StatusBadRequest, "read body")
return
}
defer r.Body.Close()
event, err := h.stripe.VerifyWebhook(body, r.Header.Get("Stripe-Signature"))
if err != nil {
h.logger.Warn("stripe webhook signature failed", "err", err)
writeError(w, http.StatusBadRequest, "invalid signature")
return
}
switch event.Type {
case "customer.subscription.created", "customer.subscription.updated":
h.applySubscription(r, event)
case "customer.subscription.deleted":
h.applySubscriptionDeleted(r, event)
case "invoice.payment_succeeded":
// Clear past_due if Stripe says payment caught up. Most flows are
// already covered by the subscription.updated event Stripe also
// fires — this is belt-and-braces.
h.applySubscription(r, event)
case "invoice.payment_failed":
h.logger.Warn("stripe invoice payment failed", "event_id", event.ID)
// Subscription.status will flip to past_due via the
// subscription.updated event Stripe fires alongside.
default:
h.logger.Debug("stripe event ignored", "type", event.Type)
}
w.WriteHeader(http.StatusOK)
}
// applySubscription patches the subscriptions row keyed by Stripe
// customer id. Best-effort — failures here are logged but don't NACK
// the webhook (Stripe would retry forever and the row would never
// converge).
func (h *stripeWebhookHandler) applySubscription(r *http.Request, event stripe.Event) {
var sub stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
// Some invoice events carry an Invoice payload — try to extract
// the subscription id from there and short-circuit on status.
h.logger.Debug("stripe webhook: not a subscription payload", "type", event.Type, "err", err)
return
}
if sub.Customer == nil || sub.Customer.ID == "" {
h.logger.Warn("stripe webhook: subscription has no customer", "subscription", sub.ID)
return
}
tier := tierFromSubscription(&sub)
status := string(sub.Status)
cancelAtPeriodEnd := sub.CancelAtPeriodEnd
// As of API 2024-10-28, current_period_end lives on the subscription
// item, not the subscription. We pick the earliest item's end — for
// single-item subscriptions (our case) that's the canonical one.
var periodEnd *time.Time
for _, item := range sub.Items.Data {
if item.CurrentPeriodEnd > 0 {
t := time.Unix(item.CurrentPeriodEnd, 0).UTC()
if periodEnd == nil || t.Before(*periodEnd) {
periodEnd = &t
}
}
}
subID := sub.ID
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
StripeSubscriptionID: &subID,
Tier: stringPtr(string(tier)),
Status: &status,
CurrentPeriodEnd: periodEnd,
CancelAtPeriodEnd: &cancelAtPeriodEnd,
}); err != nil {
h.logger.Error("stripe webhook: update subscription failed", "err", err)
}
}
func (h *stripeWebhookHandler) applySubscriptionDeleted(r *http.Request, event stripe.Event) {
var sub stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
h.logger.Warn("stripe webhook: bad deleted payload", "err", err)
return
}
if sub.Customer == nil {
return
}
status := "canceled"
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
Status: &status,
}); err != nil {
h.logger.Error("stripe webhook: mark canceled failed", "err", err)
}
}
// tierFromSubscription inspects the Stripe price metadata to figure out
// which GuestGuard tier this subscription corresponds to. We read a
// price-level metadata key `gg_tier` (set in the Stripe dashboard when
// you create the Price). Fallback: free.
func tierFromSubscription(sub *stripe.Subscription) billing.Tier {
if sub == nil || len(sub.Items.Data) == 0 {
return billing.TierFree
}
for _, item := range sub.Items.Data {
if item.Price == nil {
continue
}
if v, ok := item.Price.Metadata["gg_tier"]; ok {
t := billing.Tier(v)
if t.Valid() {
return t
}
}
// Heuristic fallback for tests / unconfigured prices: look at the
// recurring interval and amount tier.
if item.Price.Recurring != nil && item.Price.UnitAmount >= 19900 {
return billing.TierBusiness
}
if item.Price.Recurring != nil && item.Price.UnitAmount >= 4900 {
return billing.TierPro
}
}
return billing.TierFree
}
func stringPtr(s string) *string { return &s }
+315
View File
@@ -2,6 +2,7 @@ package api
import (
"context"
"encoding/json"
"errors"
"log/slog"
"net/http"
@@ -20,25 +21,38 @@ type accessPublisher interface {
PublishAccessAttempted(ctx context.Context, evt natspub.AccessAttempted) error
}
type invitationPublisher interface {
PublishInvitationSend(ctx context.Context, evt natspub.InvitationSend) error
}
type tokenHandler struct {
logger *slog.Logger
guests *storage.GuestRepo
tokens *storage.TokenRepo
events *storage.EventRepo
users *storage.UserRepo
accessLogs *storage.AccessLogRepo
gen *auth.Generator
ttl time.Duration
pub accessPublisher
invitations invitationPublisher
publicBaseURL string
}
type issueTokenResponse struct {
Token string `json:"token"`
TokenID uuid.UUID `json:"token_id"`
Meta *domain.Token `json:"meta"`
InvitationQueued bool `json:"invitation_queued"`
InvitationLink string `json:"invitation_link"`
}
// POST /events/{id}/guests/{guest_id}/tokens — issue a token for the guest.
func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
@@ -47,6 +61,10 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
if !ok {
return
}
guest, err := h.guests.Get(r.Context(), guestID)
if err != nil {
@@ -78,10 +96,307 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
return
}
link := h.invitationLink(raw)
invitationQueued := h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
writeJSON(w, http.StatusCreated, issueTokenResponse{
Token: raw,
TokenID: tk.ID,
Meta: tk,
InvitationQueued: invitationQueued,
InvitationLink: link,
})
}
// queueInvitation publishes an invitation.send event so the notifier can
// dispatch a branded email. Best-effort: if any step fails we log and
// return false rather than failing the whole token-issue request — the
// host still has the raw URL in the response and can re-trigger sending.
func (h *tokenHandler) queueInvitation(
ctx context.Context,
event *domain.Event,
guest *domain.Guest,
tk *domain.Token,
hostID uuid.UUID,
rawToken string,
) bool {
if h.invitations == nil {
return false
}
if guest.Email == nil || *guest.Email == "" {
// Phone-only / nameless guests get no email — host shares the link
// manually. Show that on the UI so it's not a silent surprise.
return false
}
hostName := ""
if h.users != nil {
if host, err := h.users.GetByID(ctx, hostID); err == nil && host != nil {
hostName = host.Name
}
}
evt := natspub.InvitationSend{
EventID: event.ID,
GuestID: guest.ID,
TokenID: tk.ID,
GuestName: guest.Name,
GuestEmail: *guest.Email,
HostName: hostName,
EventName: event.Name,
Venue: event.Venue,
EventDate: event.EventDate,
Link: h.invitationLink(rawToken),
IssuedAt: time.Now().UTC(),
}
if err := h.invitations.PublishInvitationSend(ctx, evt); err != nil {
h.logger.Warn("publish invitation.send (continuing)", "err", err, "guest_id", guest.ID)
return false
}
return true
}
// invitationLink renders the public RSVP URL the guest clicks from their
// inbox. publicBaseURL is the externally-reachable host (set via
// GG_PUBLIC_BASE_URL); access via /rsvp/<token> is intentional — the
// frontend page rsvp/[token].vue catches the raw token.
func (h *tokenHandler) invitationLink(raw string) string {
base := h.publicBaseURL
if base == "" {
base = "http://localhost:3000"
}
return base + "/rsvp/" + raw
}
// bulkIssueRequest is the optional JSON body for the bulk-invite call.
// An empty body (or missing GuestIDs) means "every guest on the event
// who doesn't already have a token".
type bulkIssueRequest struct {
GuestIDs []string `json:"guest_ids"`
}
type bulkIssueItemError struct {
GuestID string `json:"guest_id"`
Reason string `json:"reason"`
}
// bulkIssueToken is one minted invitation. The raw token is returned so
// the host's UI can offer a "copy link" affordance after a bulk send
// (especially for guests with no email on file) without making another
// round-trip. Same data the per-guest issue endpoint already exposes,
// scoped to the host who owns the event.
type bulkIssueToken struct {
GuestID string `json:"guest_id"`
Token string `json:"token"`
InvitationQueued bool `json:"invitation_queued"`
InvitationLink string `json:"invitation_link"`
}
type bulkIssueResponse struct {
Issued int `json:"issued"`
Queued int `json:"queued"`
SkippedExisting int `json:"skipped_existing"`
SkippedNoEmail int `json:"skipped_no_email"`
Tokens []bulkIssueToken `json:"tokens,omitempty"`
Errors []bulkIssueItemError `json:"errors,omitempty"`
}
// POST /events/{id}/guests/invitations/bulk — generate tokens for every
// eligible guest (or the explicit subset) on the event and queue an
// invitation email for those with an address. Best-effort: any per-guest
// error is reported in the response and doesn't abort the rest.
func (h *tokenHandler) bulkIssue(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
if !ok {
return
}
var req bulkIssueRequest
if r.ContentLength > 0 {
// Body is optional; only decode when something was actually sent.
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
}
var onlyIDs []uuid.UUID
if len(req.GuestIDs) > 0 {
onlyIDs = make([]uuid.UUID, 0, len(req.GuestIDs))
for _, raw := range req.GuestIDs {
id, err := uuid.Parse(raw)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid guest id: "+raw)
return
}
onlyIDs = append(onlyIDs, id)
}
}
guests, err := h.guests.ListGuestsForInvitation(r.Context(), eventID, onlyIDs)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load guests")
return
}
hostName := ""
if h.users != nil {
if host, err := h.users.GetByID(r.Context(), hostID); err == nil && host != nil {
hostName = host.Name
}
}
resp := bulkIssueResponse{}
for _, g := range guests {
if g.HasToken {
resp.SkippedExisting++
continue
}
raw, hash, err := h.gen.Generate()
if err != nil {
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "mint token failed"})
continue
}
tk, err := h.tokens.Create(r.Context(), storage.CreateTokenParams{
GuestID: g.ID,
TokenHash: hash,
ExpiresAt: time.Now().UTC().Add(h.ttl),
})
if err != nil {
// Likely a race against the unique constraint (someone else
// issued in parallel) — surface but don't fail the batch.
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: err.Error()})
continue
}
resp.Issued++
link := h.invitationLink(raw)
tokenInfo := bulkIssueToken{
GuestID: g.ID.String(),
Token: raw,
InvitationLink: link,
}
if g.Email == "" {
resp.SkippedNoEmail++
resp.Tokens = append(resp.Tokens, tokenInfo)
continue
}
evt := natspub.InvitationSend{
EventID: event.ID,
GuestID: g.ID,
TokenID: tk.ID,
GuestName: g.Name,
GuestEmail: g.Email,
HostName: hostName,
EventName: event.Name,
Venue: event.Venue,
EventDate: event.EventDate,
Link: h.invitationLink(raw),
IssuedAt: time.Now().UTC(),
}
if h.invitations == nil {
resp.Tokens = append(resp.Tokens, tokenInfo)
continue
}
if err := h.invitations.PublishInvitationSend(r.Context(), evt); err != nil {
h.logger.Warn("publish invitation.send (bulk, continuing)", "err", err, "guest_id", g.ID)
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "publish failed"})
resp.Tokens = append(resp.Tokens, tokenInfo)
continue
}
resp.Queued++
tokenInfo.InvitationQueued = true
resp.Tokens = append(resp.Tokens, tokenInfo)
}
writeJSON(w, http.StatusOK, resp)
}
type rotateTokenRequest struct {
// SendEmail asks the notifier to re-deliver the invitation. False
// means "just give me a fresh link" — typical for phone-only guests
// where the host shares the new URL via SMS.
SendEmail bool `json:"send_email"`
}
// POST /events/{id}/guests/{guest_id}/tokens/rotate — invalidate the
// guest's existing invitation link and mint a fresh one. Optionally
// re-publishes invitation.send so the notifier re-delivers via email.
// The old URL stops working immediately.
func (h *tokenHandler) rotate(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
guestID, ok := parseIDParam(w, r, "guest_id")
if !ok {
return
}
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
if !ok {
return
}
guest, err := h.guests.Get(r.Context(), guestID)
if err != nil {
if errors.Is(err, domain.ErrGuestNotFound) {
writeError(w, http.StatusNotFound, "guest not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load guest")
return
}
if guest.EventID != eventID {
writeError(w, http.StatusNotFound, "guest not found in event")
return
}
var req rotateTokenRequest
if r.ContentLength > 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
}
raw, hash, err := h.gen.Generate()
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate token")
return
}
tk, err := h.tokens.RotateForGuest(r.Context(), storage.CreateTokenParams{
GuestID: guestID,
TokenHash: hash,
ExpiresAt: time.Now().UTC().Add(h.ttl),
})
if err != nil {
h.logger.Error("rotate token", "err", err, "guest_id", guestID)
writeError(w, http.StatusInternalServerError, "failed to rotate token")
return
}
link := h.invitationLink(raw)
invitationQueued := false
if req.SendEmail {
invitationQueued = h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
}
writeJSON(w, http.StatusOK, issueTokenResponse{
Token: raw,
TokenID: tk.ID,
Meta: tk,
InvitationQueued: invitationQueued,
InvitationLink: link,
})
}
+50
View File
@@ -0,0 +1,50 @@
package api
import (
"log/slog"
"net/http"
"github.com/alchemistkay/guestguard/internal/notification"
)
type unsubscribeHandler struct {
logger *slog.Logger
signer *notification.UnsubscribeSigner
suppress *notification.SuppressionRepo
}
// GET /unsubscribe/{token} — surface the email address that the token
// belongs to so the frontend can show a confirmation page. Honoured even
// before the user clicks "Confirm" so they see what's being unsubscribed.
func (h *unsubscribeHandler) preview(w http.ResponseWriter, r *http.Request) {
if h.signer == nil {
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
return
}
email, err := h.signer.Verify(r.PathValue("token"))
if err != nil {
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
return
}
writeJSON(w, http.StatusOK, map[string]string{"email": email})
}
// POST /unsubscribe/{token} — add the email to the suppression list.
// Idempotent: clicking the link twice keeps the existing entry.
func (h *unsubscribeHandler) confirm(w http.ResponseWriter, r *http.Request) {
if h.signer == nil || h.suppress == nil {
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
return
}
email, err := h.signer.Verify(r.PathValue("token"))
if err != nil {
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
return
}
if err := h.suppress.Add(r.Context(), email, "user clicked unsubscribe", notification.SuppressionUser); err != nil {
h.logger.Error("add suppression", "err", err, "email", email)
writeError(w, http.StatusInternalServerError, "failed to unsubscribe")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "unsubscribed", "email": email})
}
-55
View File
@@ -1,55 +0,0 @@
package api
import (
"encoding/json"
"errors"
"net/http"
"net/mail"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
type userHandler struct {
repo *storage.UserRepo
}
type upsertUserRequest struct {
Email string `json:"email"`
Name string `json:"name"`
}
// POST /users — idempotent: returns the existing user if the email already
// exists, creates one otherwise. This keeps the demo flow simple without
// requiring real auth.
func (h *userHandler) upsert(w http.ResponseWriter, r *http.Request) {
var req upsertUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if _, err := mail.ParseAddress(req.Email); err != nil {
writeError(w, http.StatusBadRequest, "email is invalid")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
u, err := h.repo.Create(r.Context(), req.Email, req.Name)
if err == nil {
writeJSON(w, http.StatusCreated, u)
return
}
if errors.Is(err, domain.ErrEmailTaken) {
existing, getErr := h.repo.GetByEmail(r.Context(), req.Email)
if getErr != nil {
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
writeJSON(w, http.StatusOK, existing)
return
}
writeError(w, http.StatusInternalServerError, "failed to create user")
}
+145
View File
@@ -0,0 +1,145 @@
package api
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"github.com/alchemistkay/guestguard/internal/notification"
)
// webhookHandler accepts provider status notifications and reflects them
// onto the notifications table + suppression list.
//
// Signature verification is intentionally a TODO until the user provisions
// real Twilio + SES creds — verifying against test fixtures alone would
// give a false sense of security. The endpoint is therefore *not* exposed
// publicly until the deployment is ready.
type webhookHandler struct {
logger *slog.Logger
notifs *notification.Repo
suppress *notification.SuppressionRepo
}
// POST /webhooks/twilio/status — Twilio status callback (form-encoded).
// Fields we care about: MessageSid, MessageStatus (sent|delivered|
// undelivered|failed), ErrorCode, To.
func (h *webhookHandler) twilio(w http.ResponseWriter, r *http.Request) {
if h.notifs == nil {
w.WriteHeader(http.StatusNoContent)
return
}
// TODO(blockD2): verify X-Twilio-Signature with GG_TWILIO_AUTH_TOKEN.
if err := r.ParseForm(); err != nil {
writeError(w, http.StatusBadRequest, "invalid form")
return
}
sid := r.PostForm.Get("MessageSid")
status := r.PostForm.Get("MessageStatus")
if sid == "" || status == "" {
writeError(w, http.StatusBadRequest, "missing MessageSid / MessageStatus")
return
}
ctx := r.Context()
switch status {
case "delivered":
_ = h.notifs.MarkDelivered(ctx, sid)
case "undelivered", "failed":
_ = h.notifs.MarkBounce(ctx, sid, "permanent")
}
h.logger.Info("twilio status callback", "sid", sid, "status", status)
w.WriteHeader(http.StatusNoContent)
}
// POST /webhooks/ses/notifications — SNS-delivered SES notification (JSON).
// Handles the two shapes SES uses: bounce + complaint events. Each event
// carries the messageId we stored in provider_message_id and an array of
// affected recipients.
func (h *webhookHandler) ses(w http.ResponseWriter, r *http.Request) {
if h.notifs == nil {
w.WriteHeader(http.StatusNoContent)
return
}
// TODO(blockD2): verify SNS signature using the cert URL field.
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
writeError(w, http.StatusBadRequest, "read body")
return
}
defer r.Body.Close()
var envelope struct {
Type string `json:"Type"` // "Notification" | "SubscriptionConfirmation"
Message string `json:"Message"` // stringified JSON for Notification
}
if err := json.Unmarshal(body, &envelope); err != nil {
writeError(w, http.StatusBadRequest, "invalid json envelope")
return
}
if envelope.Type == "SubscriptionConfirmation" {
// Confirmed by visiting SubscribeURL — manual op-side step.
h.logger.Info("ses subscription confirmation received (manual confirm required)")
w.WriteHeader(http.StatusNoContent)
return
}
if envelope.Message == "" {
w.WriteHeader(http.StatusNoContent)
return
}
var inner struct {
NotificationType string `json:"notificationType"` // "Bounce" | "Complaint" | "Delivery"
Mail struct {
MessageID string `json:"messageId"`
} `json:"mail"`
Bounce struct {
BounceType string `json:"bounceType"` // "Permanent" | "Transient"
BouncedRecipients []struct {
EmailAddress string `json:"emailAddress"`
} `json:"bouncedRecipients"`
} `json:"bounce"`
Complaint struct {
ComplainedRecipients []struct {
EmailAddress string `json:"emailAddress"`
} `json:"complainedRecipients"`
} `json:"complaint"`
}
if err := json.Unmarshal([]byte(envelope.Message), &inner); err != nil {
writeError(w, http.StatusBadRequest, "invalid inner json")
return
}
ctx := r.Context()
switch inner.NotificationType {
case "Bounce":
bt := "transient"
if inner.Bounce.BounceType == "Permanent" {
bt = "permanent"
}
_ = h.notifs.MarkBounce(ctx, inner.Mail.MessageID, bt)
if h.suppress != nil && bt == "permanent" {
for _, rcp := range inner.Bounce.BouncedRecipients {
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses permanent bounce", notification.SuppressionBounce)
}
}
case "Complaint":
_ = h.notifs.MarkComplaint(ctx, inner.Mail.MessageID)
if h.suppress != nil {
for _, rcp := range inner.Complaint.ComplainedRecipients {
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses complaint", notification.SuppressionComplaint)
}
}
case "Delivery":
_ = h.notifs.MarkDelivered(ctx, inner.Mail.MessageID)
}
h.logger.Info("ses notification", "type", inner.NotificationType, "message_id", inner.Mail.MessageID)
w.WriteHeader(http.StatusNoContent)
}
// Compile-time check that ctx is unused in package — silences linter on
// some Go versions when the file would otherwise import context only for
// the handler signatures.
var _ = context.Background
+51
View File
@@ -0,0 +1,51 @@
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/alchemistkay/guestguard/internal/storage"
)
type wsTicketHandler struct {
tickets *wsTicketStore
events *storage.EventRepo
}
type wsTicketResponse struct {
Ticket string `json:"ticket"`
ExpiresAt time.Time `json:"expires_at"`
}
// POST /auth/ws-ticket — requireAuth-protected; body { "event_id": "<uuid>" }.
// Returns a single-use ticket valid for ~60 seconds. The frontend appends it
// as `?ticket=…` on the WebSocket URL.
func (h *wsTicketHandler) issue(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
var req struct {
EventID string `json:"event_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
eventID, ok := parseRawUUID(w, "event_id", req.EventID)
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
tok, exp, err := h.tickets.Mint(hostID, eventID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to mint ticket")
return
}
writeJSON(w, http.StatusOK, wsTicketResponse{Ticket: tok, ExpiresAt: exp})
}
+81
View File
@@ -0,0 +1,81 @@
package api
import (
"crypto/rand"
"encoding/base64"
"sync"
"time"
"github.com/google/uuid"
)
// wsTicketStore mints short-lived single-use tickets that authorise a
// WebSocket handshake. The plan calls this option 3 in Block B: cookies
// don't reach the WS handshake on cross-origin setups and a JWT in the URL
// would leak to logs; a one-shot ticket sidesteps both.
//
// Block B keeps this in-process. When the API runs more than one replica
// this needs to move to Redis (Block C territory).
type wsTicketStore struct {
mu sync.Mutex
entries map[string]wsTicketEntry
ttl time.Duration
now func() time.Time
}
type wsTicketEntry struct {
userID uuid.UUID
eventID uuid.UUID
expiresAt time.Time
}
func newWSTicketStore(ttl time.Duration) *wsTicketStore {
return &wsTicketStore{
entries: make(map[string]wsTicketEntry),
ttl: ttl,
now: time.Now,
}
}
// Mint returns a fresh URL-safe ticket bound to userID + eventID.
func (s *wsTicketStore) Mint(userID, eventID uuid.UUID) (string, time.Time, error) {
buf := make([]byte, 24)
if _, err := rand.Read(buf); err != nil {
return "", time.Time{}, err
}
tok := base64.RawURLEncoding.EncodeToString(buf)
exp := s.now().Add(s.ttl)
s.mu.Lock()
defer s.mu.Unlock()
s.sweepLocked()
s.entries[tok] = wsTicketEntry{userID: userID, eventID: eventID, expiresAt: exp}
return tok, exp, nil
}
// Consume removes the ticket and returns the bound (userID, eventID) if it
// was valid. A ticket is single-use — replaying it fails.
func (s *wsTicketStore) Consume(token string) (uuid.UUID, uuid.UUID, bool) {
s.mu.Lock()
defer s.mu.Unlock()
entry, ok := s.entries[token]
if !ok {
return uuid.Nil, uuid.Nil, false
}
delete(s.entries, token)
if s.now().After(entry.expiresAt) {
return uuid.Nil, uuid.Nil, false
}
return entry.userID, entry.eventID, true
}
// sweepLocked drops expired entries opportunistically. Cheap because we
// usually only hold dozens of tickets at a time.
func (s *wsTicketStore) sweepLocked() {
now := s.now()
for k, v := range s.entries {
if now.After(v.expiresAt) {
delete(s.entries, k)
}
}
}
+21 -1
View File
@@ -91,14 +91,34 @@ func (h *Hub) remove(eventID uuid.UUID, s *subscriber) {
type wsHandler struct {
logger *slog.Logger
hub *Hub
tickets *wsTicketStore
}
// GET /ws/events/{id} — dashboard live feed for one event.
// GET /ws/events/{id}?ticket=... — dashboard live feed for one event.
//
// The handshake is authorised by a single-use ticket minted via
// POST /auth/ws-ticket (option 3 from the Block B plan). The ticket binds
// the connecting user to a specific event_id; we reject if either is
// missing or doesn't match the URL path.
func (h *wsHandler) handle(w http.ResponseWriter, r *http.Request) {
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
rawTicket := r.URL.Query().Get("ticket")
if rawTicket == "" {
writeError(w, http.StatusUnauthorized, "missing ticket")
return
}
_, ticketEventID, valid := h.tickets.Consume(rawTicket)
if !valid {
writeError(w, http.StatusUnauthorized, "invalid or expired ticket")
return
}
if ticketEventID != eventID {
writeError(w, http.StatusForbidden, "ticket does not match event")
return
}
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
// In dev the frontend runs on a different origin (localhost:3000 → localhost:8080).
+32
View File
@@ -0,0 +1,32 @@
package auth
import (
"context"
"log/slog"
)
// EmailSender delivers transactional auth emails (verification, reset).
// Block A ships LogSender so dev environments work without Twilio/SES.
// Block D replaces this with a real SES-backed sender.
type EmailSender interface {
SendVerification(ctx context.Context, to, name, link string) error
SendPasswordReset(ctx context.Context, to, name, link string) error
}
type LogEmailSender struct {
Logger *slog.Logger
}
func (l LogEmailSender) SendVerification(_ context.Context, to, name, link string) error {
l.Logger.Info("auth email (stub): verification",
"to", to, "name", name, "link", link,
)
return nil
}
func (l LogEmailSender) SendPasswordReset(_ context.Context, to, name, link string) error {
l.Logger.Info("auth email (stub): password reset",
"to", to, "name", name, "link", link,
)
return nil
}
+95
View File
@@ -0,0 +1,95 @@
package auth
import (
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
var (
ErrInvalidJWT = errors.New("invalid token")
ErrExpiredJWT = errors.New("token expired")
)
type AccessClaims struct {
UserID uuid.UUID `json:"sub_uuid"`
jwt.RegisteredClaims
}
type JWTSigner struct {
secret []byte
ttl time.Duration
issuer string
parser *jwt.Parser
}
func NewJWTSigner(secret string, ttl time.Duration, issuer string) (*JWTSigner, error) {
if len(secret) < 32 {
return nil, fmt.Errorf("jwt secret must be at least 32 bytes")
}
if ttl <= 0 {
return nil, fmt.Errorf("jwt ttl must be positive")
}
return &JWTSigner{
secret: []byte(secret),
ttl: ttl,
issuer: issuer,
parser: jwt.NewParser(
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
jwt.WithIssuer(issuer),
jwt.WithExpirationRequired(),
),
}, nil
}
func (s *JWTSigner) Issue(userID uuid.UUID, now time.Time) (string, time.Time, error) {
exp := now.Add(s.ttl)
claims := AccessClaims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: s.issuer,
Subject: userID.String(),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)),
ExpiresAt: jwt.NewNumericDate(exp),
ID: uuid.NewString(),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := tok.SignedString(s.secret)
if err != nil {
return "", time.Time{}, err
}
return signed, exp, nil
}
func (s *JWTSigner) Parse(raw string) (*AccessClaims, error) {
claims := &AccessClaims{}
tok, err := s.parser.ParseWithClaims(raw, claims, func(t *jwt.Token) (any, error) {
return s.secret, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrExpiredJWT
}
return nil, ErrInvalidJWT
}
if !tok.Valid {
return nil, ErrInvalidJWT
}
if claims.UserID == uuid.Nil {
// Fallback for tokens that only carry Subject (defensive — we always
// set UserID on issue).
parsed, perr := uuid.Parse(claims.Subject)
if perr != nil {
return nil, ErrInvalidJWT
}
claims.UserID = parsed
}
return claims, nil
}
func (s *JWTSigner) TTL() time.Duration { return s.ttl }
+82
View File
@@ -0,0 +1,82 @@
package auth
import (
"errors"
"testing"
"time"
"github.com/google/uuid"
)
const testSecret = "test-secret-must-be-at-least-32-bytes-long-xx"
func TestJWTRoundTrip(t *testing.T) {
s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test")
if err != nil {
t.Fatalf("signer: %v", err)
}
uid := uuid.New()
tok, exp, err := s.Issue(uid, time.Now())
if err != nil {
t.Fatalf("issue: %v", err)
}
if tok == "" {
t.Fatal("empty token")
}
if time.Until(exp) <= 0 {
t.Fatalf("expiry in past: %v", exp)
}
claims, err := s.Parse(tok)
if err != nil {
t.Fatalf("parse: %v", err)
}
if claims.UserID != uid {
t.Fatalf("user mismatch: got %s want %s", claims.UserID, uid)
}
}
func TestJWTExpired(t *testing.T) {
s, err := NewJWTSigner(testSecret, 1*time.Second, "guestguard-test")
if err != nil {
t.Fatalf("signer: %v", err)
}
tok, _, err := s.Issue(uuid.New(), time.Now().Add(-1*time.Hour))
if err != nil {
t.Fatalf("issue: %v", err)
}
if _, err := s.Parse(tok); !errors.Is(err, ErrExpiredJWT) {
t.Fatalf("expected ErrExpiredJWT, got %v", err)
}
}
func TestJWTTamper(t *testing.T) {
s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test")
if err != nil {
t.Fatalf("signer: %v", err)
}
tok, _, _ := s.Issue(uuid.New(), time.Now())
// Flip a character in the signature segment.
tampered := tok[:len(tok)-1] + "a"
if tampered == tok {
tampered = tok[:len(tok)-1] + "b"
}
if _, err := s.Parse(tampered); !errors.Is(err, ErrInvalidJWT) {
t.Fatalf("expected ErrInvalidJWT, got %v", err)
}
}
func TestJWTSecretTooShort(t *testing.T) {
if _, err := NewJWTSigner("short", time.Minute, "x"); err == nil {
t.Fatal("expected error for short secret")
}
}
func TestOpaqueTokenHashStable(t *testing.T) {
raw, hash, err := NewOpaqueToken()
if err != nil {
t.Fatalf("mint: %v", err)
}
if got := HashOpaque(raw); got != hash {
t.Fatalf("hash mismatch: got %s want %s", got, hash)
}
}
+107
View File
@@ -0,0 +1,107 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// LockoutTracker counts consecutive failed logins per email and trips a
// lockout flag after a threshold. The lock is keyed by user_id (once we
// know it from the email) so that resetting the password — which we do via
// /auth/reset-password — can clear it cleanly.
//
// Why two keys? The failure counter must work even when the email maps to
// no user (otherwise an attacker probing addresses just gets unlimited
// tries). The lock flag only exists once we've matched an actual account.
type LockoutTracker struct {
client *redis.Client
threshold int
window time.Duration // how long failures linger before counters reset
prefix string
}
func NewLockoutTracker(client *redis.Client, threshold int, window time.Duration) *LockoutTracker {
return &LockoutTracker{
client: client,
threshold: threshold,
window: window,
prefix: "auth",
}
}
func (t *LockoutTracker) failKey(email string) string {
h := sha256.Sum256([]byte(strings.ToLower(strings.TrimSpace(email))))
return fmt.Sprintf("%s:login_fail:%s", t.prefix, hex.EncodeToString(h[:]))
}
func (t *LockoutTracker) lockKey(uid uuid.UUID) string {
return fmt.Sprintf("%s:locked:%s", t.prefix, uid.String())
}
// IsLocked reports whether the given user's account is currently locked.
func (t *LockoutTracker) IsLocked(ctx context.Context, uid uuid.UUID) (bool, error) {
if t == nil || t.client == nil {
return false, nil
}
v, err := t.client.Exists(ctx, t.lockKey(uid)).Result()
if err != nil {
return false, err
}
return v > 0, nil
}
// RecordFailure increments the failure counter for the email and, if it
// crosses the threshold, sets the lock flag for the given user id.
// Returns (locked, error). A nil userID is fine — the counter still ticks
// up so probing nonexistent accounts is also rate-limited.
func (t *LockoutTracker) RecordFailure(ctx context.Context, email string, userID *uuid.UUID) (bool, error) {
if t == nil || t.client == nil {
return false, nil
}
key := t.failKey(email)
n, err := t.client.Incr(ctx, key).Result()
if err != nil {
return false, err
}
if n == 1 {
_ = t.client.Expire(ctx, key, t.window).Err()
}
if int(n) >= t.threshold && userID != nil {
// Keep the lock until password reset clears it. 7-day fallback TTL
// so a permanently abandoned account doesn't pile up forever.
if err := t.client.Set(ctx, t.lockKey(*userID), "1", 7*24*time.Hour).Err(); err != nil {
return false, err
}
return true, nil
}
return false, nil
}
// ClearForUser drops both the lock flag and any in-flight failure counter
// for the user's email. Called from /auth/reset-password.
func (t *LockoutTracker) ClearForUser(ctx context.Context, uid uuid.UUID, email string) error {
if t == nil || t.client == nil {
return nil
}
pipe := t.client.Pipeline()
pipe.Del(ctx, t.lockKey(uid))
pipe.Del(ctx, t.failKey(email))
_, err := pipe.Exec(ctx)
return err
}
// ClearOnSuccess drops only the failure counter — used after a successful
// login to forgive prior typos.
func (t *LockoutTracker) ClearOnSuccess(ctx context.Context, email string) {
if t == nil || t.client == nil {
return
}
_ = t.client.Del(ctx, t.failKey(email)).Err()
}
+58
View File
@@ -0,0 +1,58 @@
package auth
import (
"errors"
"golang.org/x/crypto/bcrypt"
)
const bcryptCost = 12
var (
ErrPasswordMismatch = errors.New("password mismatch")
ErrPasswordTooShort = errors.New("password must be at least 8 characters")
ErrPasswordTooLong = errors.New("password must be at most 72 characters")
)
type PasswordHasher struct {
cost int
}
func NewPasswordHasher() *PasswordHasher {
return &PasswordHasher{cost: bcryptCost}
}
func (h *PasswordHasher) Hash(raw string) (string, error) {
if err := ValidatePassword(raw); err != nil {
return "", err
}
b, err := bcrypt.GenerateFromPassword([]byte(raw), h.cost)
if err != nil {
return "", err
}
return string(b), nil
}
func (h *PasswordHasher) Verify(hash, raw string) error {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(raw)); err != nil {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return ErrPasswordMismatch
}
return err
}
return nil
}
// ValidatePassword enforces a minimum length and the bcrypt-imposed maximum.
// bcrypt silently truncates inputs over 72 bytes, which would let a user set
// a 100-character password and successfully log in with the first 72; reject
// at the boundary instead.
func ValidatePassword(raw string) error {
if len(raw) < 8 {
return ErrPasswordTooShort
}
if len(raw) > 72 {
return ErrPasswordTooLong
}
return nil
}
+44
View File
@@ -0,0 +1,44 @@
package auth
import (
"errors"
"strings"
"testing"
)
func TestPasswordHasherRoundTrip(t *testing.T) {
h := NewPasswordHasher()
hash, err := h.Hash("correct-horse-battery-staple")
if err != nil {
t.Fatalf("hash: %v", err)
}
if hash == "" {
t.Fatal("empty hash")
}
if err := h.Verify(hash, "correct-horse-battery-staple"); err != nil {
t.Fatalf("verify correct: %v", err)
}
if err := h.Verify(hash, "wrong"); !errors.Is(err, ErrPasswordMismatch) {
t.Fatalf("verify wrong: got %v want ErrPasswordMismatch", err)
}
}
func TestPasswordValidation(t *testing.T) {
tests := []struct {
name string
pw string
want error
}{
{"too short", "1234567", ErrPasswordTooShort},
{"min length ok", "12345678", nil},
{"too long", strings.Repeat("a", 73), ErrPasswordTooLong},
{"max length ok", strings.Repeat("a", 72), nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := ValidatePassword(tc.pw); !errors.Is(got, tc.want) {
t.Fatalf("got %v want %v", got, tc.want)
}
})
}
}
+26
View File
@@ -0,0 +1,26 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
)
// NewOpaqueToken returns a 32-byte URL-safe random token plus its SHA-256 hex
// digest. The raw value is shown once (in a link); only the digest is stored.
func NewOpaqueToken() (raw, hash string, err error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", "", err
}
raw = base64.RawURLEncoding.EncodeToString(buf)
sum := sha256.Sum256([]byte(raw))
hash = hex.EncodeToString(sum[:])
return raw, hash, nil
}
func HashOpaque(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
+157
View File
@@ -0,0 +1,157 @@
package billing
import (
"errors"
"fmt"
"github.com/stripe/stripe-go/v82"
"github.com/stripe/stripe-go/v82/billingportal/session"
csession "github.com/stripe/stripe-go/v82/checkout/session"
"github.com/stripe/stripe-go/v82/customer"
"github.com/stripe/stripe-go/v82/webhook"
)
// Client wraps the Stripe SDK with the subset of calls the API needs.
// Concentrates env-var reads + price-ID lookups in one place so handlers
// stay focused on HTTP, not Stripe plumbing.
type Client struct {
secretKey string
webhookSecret string
prices map[Tier]string
}
// Config is the env-derived configuration the Client needs.
type Config struct {
SecretKey string
WebhookSecret string
PriceProMonthly string
PriceBusiness string
}
// NewClient validates required fields and returns a configured client.
// Returns (nil, nil) when SecretKey is empty — callers treat that as
// "billing disabled" and degrade gracefully (free tier for everyone, no
// /billing/* endpoints exposed).
func NewClient(cfg Config) (*Client, error) {
if cfg.SecretKey == "" {
return nil, nil
}
stripe.Key = cfg.SecretKey
c := &Client{
secretKey: cfg.SecretKey,
webhookSecret: cfg.WebhookSecret,
prices: map[Tier]string{
TierPro: cfg.PriceProMonthly,
TierBusiness: cfg.PriceBusiness,
},
}
return c, nil
}
// Enabled reports whether the client was constructed with a Stripe key.
func (c *Client) Enabled() bool { return c != nil && c.secretKey != "" }
// PriceFor returns the Stripe Price ID for a tier or an error if it
// hasn't been configured — checkout will fail loudly rather than
// silently send the customer to an empty checkout page.
func (c *Client) PriceFor(tier Tier) (string, error) {
id, ok := c.prices[tier]
if !ok || id == "" {
return "", fmt.Errorf("billing: no Stripe price configured for tier %q", tier)
}
return id, nil
}
// CreateOrGetCustomer returns the Stripe customer id for the given
// (user_id, email). If `existingID` is non-empty we trust it; otherwise
// we create a new Stripe customer and let the caller persist the id.
func (c *Client) CreateOrGetCustomer(userID, email, name, existingID string) (string, error) {
if !c.Enabled() {
return "", errors.New("billing: disabled")
}
if existingID != "" {
return existingID, nil
}
params := &stripe.CustomerParams{
Email: stripe.String(email),
Name: stripe.String(name),
Metadata: map[string]string{"gg_user_id": userID},
}
cust, err := customer.New(params)
if err != nil {
return "", fmt.Errorf("stripe customer create: %w", err)
}
return cust.ID, nil
}
// CheckoutSessionParams collects the inputs CreateCheckoutSession needs.
// Keeping them in a struct so future fields (coupon codes, trial periods,
// referral metadata) drop in without breaking callers.
type CheckoutSessionParams struct {
CustomerID string
PriceID string
SuccessURL string
CancelURL string
}
// CreateCheckoutSession returns the URL the frontend redirects the user
// to. Subscription mode — recurring billing for Pro/Business.
func (c *Client) CreateCheckoutSession(p CheckoutSessionParams) (string, error) {
if !c.Enabled() {
return "", errors.New("billing: disabled")
}
params := &stripe.CheckoutSessionParams{
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
Customer: stripe.String(p.CustomerID),
SuccessURL: stripe.String(p.SuccessURL),
CancelURL: stripe.String(p.CancelURL),
LineItems: []*stripe.CheckoutSessionLineItemParams{{
Price: stripe.String(p.PriceID),
Quantity: stripe.Int64(1),
}},
AllowPromotionCodes: stripe.Bool(true),
}
sess, err := csession.New(params)
if err != nil {
return "", fmt.Errorf("stripe checkout session: %w", err)
}
return sess.URL, nil
}
// CreatePortalSession returns a URL to the Stripe-hosted customer
// portal so the user can manage payment methods, cancel, view invoices.
func (c *Client) CreatePortalSession(customerID, returnURL string) (string, error) {
if !c.Enabled() {
return "", errors.New("billing: disabled")
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(customerID),
ReturnURL: stripe.String(returnURL),
}
sess, err := session.New(params)
if err != nil {
return "", fmt.Errorf("stripe portal session: %w", err)
}
return sess.URL, nil
}
// VerifyWebhook validates the Stripe signature header and returns the
// parsed event. Refuses to verify when no webhook secret is configured —
// no shared secret means anyone can POST forged events, so the route
// should reject everything in that case (the caller does that check).
//
// We pass IgnoreAPIVersionMismatch because Stripe accounts can be on a
// newer API version than the SDK we're built against. Event payloads
// are designed to be forward-compatible — the SDK warns about the skew
// but the deserialised event is still safe to use. Strict matching
// would mean we'd have to upgrade the SDK in lockstep with whatever
// Stripe rolls out, defeating the point of having an SDK.
func (c *Client) VerifyWebhook(body []byte, sigHeader string) (stripe.Event, error) {
if c.webhookSecret == "" {
return stripe.Event{}, errors.New("billing: no webhook secret configured")
}
return webhook.ConstructEventWithOptions(body, sigHeader, c.webhookSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
}
+73
View File
@@ -0,0 +1,73 @@
// Package billing models GuestGuard's subscription tiers and Stripe
// integration. Plan limits live here so handler + middleware layers don't
// hard-code numbers — change a value, restart the API, and the cap moves.
package billing
import "fmt"
// Tier is the user's current subscription tier. Stored as text in the
// subscriptions table.
type Tier string
const (
TierFree Tier = "free"
TierPro Tier = "pro"
TierBusiness Tier = "business"
)
func (t Tier) Valid() bool {
switch t {
case TierFree, TierPro, TierBusiness:
return true
}
return false
}
// Limits enforces what a tier may do. -1 means unlimited.
type Limits struct {
EventsPerMonth int
GuestsPerEvent int
}
// TierLimits is the canonical plan-limits table. Matches docs/TIER1_PLAN.md
// Block F pricing (placeholder until market validation).
var TierLimits = map[Tier]Limits{
TierFree: {EventsPerMonth: 1, GuestsPerEvent: 50},
TierPro: {EventsPerMonth: 10, GuestsPerEvent: 1000},
TierBusiness: {EventsPerMonth: -1, GuestsPerEvent: 5000},
}
// LimitsFor returns the limits for a tier, defaulting to Free for unknown
// strings — defensive, so a typo or future-tier in the DB never grants
// unlimited access.
func LimitsFor(t Tier) Limits {
if l, ok := TierLimits[t]; ok {
return l
}
return TierLimits[TierFree]
}
// StatusGrantsAccess returns true if a subscription with this status
// should be treated as paid. Mirrors the unique-index predicate in
// 0005_billing.up.sql.
func StatusGrantsAccess(status string) bool {
switch status {
case "active", "past_due", "trialing":
return true
}
return false
}
// LimitError describes a denied-by-policy outcome. The handler layer
// turns this into a 402 with a JSON body the frontend uses to render
// the upgrade modal.
type LimitError struct {
Reason string
Tier Tier
Limit int
Used int
}
func (e *LimitError) Error() string {
return fmt.Sprintf("billing: %s (tier=%s used=%d limit=%d)", e.Reason, e.Tier, e.Used, e.Limit)
}
+127
View File
@@ -12,11 +12,56 @@ type Config struct {
HTTPAddr string
DatabaseURL string
NATSURL string
RedisAddr string
FraudGRPCAddr string
FraudGRPCTimeout time.Duration
ShutdownTimeout time.Duration
TokenSecret string
TokenTTL time.Duration
// Auth
JWTSecret string
JWTIssuer string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
EmailVerificationTTL time.Duration
PasswordResetTTL time.Duration
PublicBaseURL string
RefreshCookieDomain string
RefreshCookieSecure bool
// Notifications (Block D). Empty values leave the log-stub adapter in
// place so the service still boots without AWS / Twilio creds.
SESRegion string
SESFromEmail string
SESFromName string
SESConfigurationSet string
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFromEmail string
SMTPFromName string
SMTPTLS string // "starttls" | "implicit" | "none" (default starttls)
ResendAPIKey string
ResendFromEmail string
ResendFromName string
TwilioAccountSID string
TwilioAuthToken string
TwilioFromNumber string
// Billing (Block F). Empty StripeSecretKey leaves billing disabled —
// all users get free-tier limits, /billing/* returns 503. Lets the
// service boot in dev without Stripe creds.
StripeSecretKey string
StripeWebhookSecret string
StripePricePro string // Stripe Price ID for Pro monthly
StripePriceBusiness string // Stripe Price ID for Business monthly
UnsubscribeSecret string // HMAC key for signing unsubscribe links
}
func Load() (*Config, error) {
@@ -25,11 +70,50 @@ func Load() (*Config, error) {
HTTPAddr: getenv("GG_HTTP_ADDR", ":8080"),
DatabaseURL: getenv("GG_DATABASE_URL", "postgres://guestguard:guestguard@localhost:5432/guestguard?sslmode=disable"),
NATSURL: getenv("GG_NATS_URL", "nats://localhost:4222"),
RedisAddr: getenv("GG_REDIS_ADDR", "localhost:6379"),
FraudGRPCAddr: getenv("GG_FRAUD_GRPC_ADDR", "fraud-engine:9091"),
FraudGRPCTimeout: getenvDuration("GG_FRAUD_GRPC_TIMEOUT", 250*time.Millisecond),
ShutdownTimeout: getenvDuration("GG_SHUTDOWN_TIMEOUT", 15*time.Second),
TokenSecret: os.Getenv("GG_TOKEN_SECRET"),
TokenTTL: getenvDuration("GG_TOKEN_TTL", 30*24*time.Hour),
JWTSecret: os.Getenv("GG_JWT_SECRET"),
JWTIssuer: getenv("GG_JWT_ISSUER", "guestguard"),
AccessTokenTTL: getenvDuration("GG_ACCESS_TOKEN_TTL", 15*time.Minute),
RefreshTokenTTL: getenvDuration("GG_REFRESH_TOKEN_TTL", 30*24*time.Hour),
EmailVerificationTTL: getenvDuration("GG_EMAIL_VERIFICATION_TTL", 24*time.Hour),
PasswordResetTTL: getenvDuration("GG_PASSWORD_RESET_TTL", 1*time.Hour),
PublicBaseURL: getenv("GG_PUBLIC_BASE_URL", "http://localhost:3000"),
RefreshCookieDomain: os.Getenv("GG_REFRESH_COOKIE_DOMAIN"),
RefreshCookieSecure: getenvBool("GG_REFRESH_COOKIE_SECURE", false),
SESRegion: getenv("GG_SES_REGION", "us-east-1"),
SESFromEmail: os.Getenv("GG_SES_FROM_EMAIL"),
SESFromName: getenv("GG_SES_FROM_NAME", "GuestGuard"),
SESConfigurationSet: os.Getenv("GG_SES_CONFIGURATION_SET"),
SMTPHost: os.Getenv("GG_SMTP_HOST"),
SMTPPort: getenvInt("GG_SMTP_PORT", 587),
SMTPUsername: os.Getenv("GG_SMTP_USERNAME"),
SMTPPassword: os.Getenv("GG_SMTP_PASSWORD"),
SMTPFromEmail: os.Getenv("GG_SMTP_FROM_EMAIL"),
SMTPFromName: getenv("GG_SMTP_FROM_NAME", "GuestGuard"),
SMTPTLS: getenv("GG_SMTP_TLS", "starttls"),
ResendAPIKey: os.Getenv("GG_RESEND_API_KEY"),
ResendFromEmail: os.Getenv("GG_RESEND_FROM_EMAIL"),
ResendFromName: getenv("GG_RESEND_FROM_NAME", "GuestGuard"),
TwilioAccountSID: os.Getenv("GG_TWILIO_ACCOUNT_SID"),
TwilioAuthToken: os.Getenv("GG_TWILIO_AUTH_TOKEN"),
TwilioFromNumber: os.Getenv("GG_TWILIO_FROM_NUMBER"),
StripeSecretKey: os.Getenv("GG_STRIPE_SECRET_KEY"),
StripeWebhookSecret: os.Getenv("GG_STRIPE_WEBHOOK_SECRET"),
StripePricePro: os.Getenv("GG_STRIPE_PRICE_PRO"),
StripePriceBusiness: os.Getenv("GG_STRIPE_PRICE_BUSINESS"),
UnsubscribeSecret: os.Getenv("GG_UNSUBSCRIBE_SECRET"),
}
if cfg.Env == "production" && cfg.TokenSecret == "" {
@@ -39,9 +123,52 @@ func Load() (*Config, error) {
cfg.TokenSecret = "dev-only-insecure-secret-change-me"
}
if cfg.Env == "production" && cfg.JWTSecret == "" {
return nil, fmt.Errorf("GG_JWT_SECRET is required in production")
}
if cfg.JWTSecret == "" {
cfg.JWTSecret = "dev-only-insecure-jwt-secret-change-me-32+bytes"
}
if len(cfg.JWTSecret) < 32 {
return nil, fmt.Errorf("GG_JWT_SECRET must be at least 32 bytes")
}
if cfg.UnsubscribeSecret == "" {
// Same dev fallback shape as the other secrets — production refuses
// to boot without it.
if cfg.Env == "production" {
return nil, fmt.Errorf("GG_UNSUBSCRIBE_SECRET is required in production")
}
cfg.UnsubscribeSecret = "dev-only-insecure-unsubscribe-secret-change-me"
}
return cfg, nil
}
func getenvInt(key string, fallback int) int {
v, ok := os.LookupEnv(key)
if !ok || v == "" {
return fallback
}
n, err := strconv.Atoi(v)
if err != nil {
return fallback
}
return n
}
func getenvBool(key string, fallback bool) bool {
v, ok := os.LookupEnv(key)
if !ok || v == "" {
return fallback
}
b, err := strconv.ParseBool(v)
if err != nil {
return fallback
}
return b
}
func getenv(key, fallback string) string {
if v, ok := os.LookupEnv(key); ok && v != "" {
return v
+250
View File
@@ -0,0 +1,250 @@
// Package csvimport parses a guest-list CSV into structured rows, with
// tolerant header detection (Excel, Numbers, Google Sheets variants) and
// per-row validation. Streaming-friendly so a 5,000-row import doesn't
// load the entire file into a slice before we know if column 1 is junk.
package csvimport
import (
"bufio"
"encoding/csv"
"errors"
"fmt"
"io"
"net/mail"
"regexp"
"strconv"
"strings"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
)
// Row is a single validated guest. Empty Email / Phone are allowed (a
// phone-only or name-only guest is valid per the plan).
type Row struct {
Name string
Email string
Phone string
PlusOnes int
}
// RowError flags one row with the human-readable reason it can't be
// imported. The line number is 1-based and matches the source CSV
// (header counts as line 1, first data row is line 2) so the frontend
// can highlight the offending row.
type RowError struct {
Row int `json:"row"`
Reason string `json:"reason"`
}
// Result is the outcome of one parse pass.
type Result struct {
Rows []Row `json:"rows,omitempty"`
Errors []RowError `json:"errors,omitempty"`
TotalCount int `json:"total_count"` // total data rows seen (excluding header)
}
// Options tune limits + behaviour.
type Options struct {
MaxRows int // hard cap; rows beyond MaxRows return an error instead of being silently dropped
}
const DefaultMaxRows = 5000
// Strict E.164: optional leading +, then a non-zero leading digit (country
// codes never start with 0), followed by 614 more digits — total 715
// significant digits. Spaces / dashes / parens are tolerated by stripping
// before validation, but local-format numbers like "0244…" or "07700…"
// are rejected here so the host fixes them at upload time rather than at
// WhatsApp-send time.
var phoneRe = regexp.MustCompile(`^\+?[1-9][0-9]{6,14}$`)
// Parse reads a CSV from r and returns the parsed result. Encoding is
// auto-detected: UTF-8 with or without BOM, plus UTF-16 LE/BE BOMs
// (commonly produced by Mac Numbers exports).
func Parse(r io.Reader, opt Options) (*Result, error) {
max := opt.MaxRows
if max <= 0 {
max = DefaultMaxRows
}
rd, err := decodingReader(r)
if err != nil {
return nil, err
}
csvr := csv.NewReader(rd)
csvr.FieldsPerRecord = -1 // tolerate ragged rows; we re-validate column count ourselves
csvr.TrimLeadingSpace = true
header, err := csvr.Read()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("csv is empty")
}
return nil, fmt.Errorf("read header: %w", err)
}
cols, err := detectColumns(header)
if err != nil {
return nil, err
}
out := &Result{Rows: make([]Row, 0, 64)}
lineNo := 1 // header was line 1
for {
rec, err := csvr.Read()
if err == io.EOF {
break
}
lineNo++
if err != nil {
out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: fmt.Sprintf("malformed csv: %v", err)})
continue
}
out.TotalCount++
if out.TotalCount > max {
return nil, fmt.Errorf("import exceeds maximum of %d rows", max)
}
// Skip fully-empty rows silently — these appear at the end of
// Excel exports a lot.
if rowEmpty(rec) {
out.TotalCount-- // don't count it
continue
}
row, rerr := buildRow(rec, cols)
if rerr != "" {
out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: rerr})
continue
}
out.Rows = append(out.Rows, row)
}
return out, nil
}
func rowEmpty(rec []string) bool {
for _, v := range rec {
if strings.TrimSpace(v) != "" {
return false
}
}
return true
}
// decodingReader strips a UTF-8 BOM and decodes UTF-16 LE/BE when their
// BOM is present, returning a UTF-8 reader. Other byte orders fall through
// as raw UTF-8.
func decodingReader(r io.Reader) (*bufio.Reader, error) {
br := bufio.NewReader(r)
bom, err := br.Peek(3)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
switch {
case len(bom) >= 3 && bom[0] == 0xEF && bom[1] == 0xBB && bom[2] == 0xBF:
_, _ = br.Discard(3)
return br, nil
case len(bom) >= 2 && bom[0] == 0xFF && bom[1] == 0xFE:
_, _ = br.Discard(2)
dec := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()
return bufio.NewReader(transform.NewReader(br, dec)), nil
case len(bom) >= 2 && bom[0] == 0xFE && bom[1] == 0xFF:
_, _ = br.Discard(2)
dec := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder()
return bufio.NewReader(transform.NewReader(br, dec)), nil
}
return br, nil
}
// columnSet records which column index each known field lives in. -1 means
// the column was not supplied; only Name is mandatory.
type columnSet struct {
name, email, phone, plusOnes int
}
func detectColumns(header []string) (columnSet, error) {
cs := columnSet{name: -1, email: -1, phone: -1, plusOnes: -1}
for i, raw := range header {
key := normaliseHeader(raw)
switch key {
case "name", "guestname", "fullname":
cs.name = i
case "email", "emailaddress", "e-mail":
cs.email = i
case "phone", "telephone", "mobile", "phonenumber":
cs.phone = i
case "plusones", "plus1", "plus-one", "plus-ones", "+1", "guests", "additionalguests":
cs.plusOnes = i
}
}
if cs.name < 0 {
return cs, fmt.Errorf("required column 'name' not found in header: %v", header)
}
return cs, nil
}
func normaliseHeader(s string) string {
s = strings.ToLower(strings.TrimSpace(s))
// Drop spaces + underscores. Keep `+`, `-` so "+1" / "plus-one" still
// match exactly.
return strings.NewReplacer(" ", "", "_", "").Replace(s)
}
func buildRow(rec []string, cs columnSet) (Row, string) {
get := func(i int) string {
if i < 0 || i >= len(rec) {
return ""
}
return strings.TrimSpace(rec[i])
}
row := Row{
Name: get(cs.name),
Email: strings.ToLower(get(cs.email)),
Phone: get(cs.phone),
}
if row.Name == "" {
return row, "name is required"
}
if row.Email != "" {
if _, err := mail.ParseAddress(row.Email); err != nil {
return row, "invalid email"
}
}
if row.Phone != "" {
stripped := stripPhone(row.Phone)
if !phoneRe.MatchString(stripped) {
return row, "phone must be in international format with country code (e.g. +447700900123) — local numbers starting with 0 won't work for SMS or WhatsApp"
}
// Normalise: ensure stored form always starts with "+".
if !strings.HasPrefix(stripped, "+") {
stripped = "+" + stripped
}
row.Phone = stripped
}
if raw := get(cs.plusOnes); raw != "" {
n, err := strconv.Atoi(raw)
if err != nil || n < 0 {
return row, "plus_ones must be a non-negative integer"
}
row.PlusOnes = n
}
return row, ""
}
var phoneStripper = strings.NewReplacer(" ", "", "-", "", "(", "", ")", "", " ", "")
func stripPhone(s string) string {
return phoneStripper.Replace(s)
}
// TemplateCSV is the sample file served at /events/{id}/guests/import/template.
// Phone numbers MUST include the country code (e.g. +44 for UK, +233 for
// Ghana). Local-format numbers like "0244..." or "07700..." will be
// rejected at upload — the sample below shows the expected shape.
const TemplateCSV = "name,email,phone,plus_ones\n" +
"Alex Doe,alex@example.com,+447700900123,1\n" +
"Sam Patel,sam@example.com,,0\n" +
"Jordan Lee,,+15551234567,2\n" +
"Mira Patel,mira@example.com,+233244123456,0\n"
+157
View File
@@ -0,0 +1,157 @@
package csvimport
import (
"strings"
"testing"
)
func TestParseHappyPath(t *testing.T) {
in := `name,email,phone,plus_ones
Alex Doe,alex@example.com,+447700900123,1
Sam Patel,SAM@example.com,,0
Jordan Lee,,+1 (555) 123-4567,2
`
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if got, want := len(r.Rows), 3; got != want {
t.Fatalf("rows: got %d want %d (errors=%+v)", got, want, r.Errors)
}
if r.Rows[1].Email != "sam@example.com" {
t.Errorf("email not lowercased: %q", r.Rows[1].Email)
}
if r.Rows[2].Phone != "+15551234567" {
t.Errorf("phone not stripped: %q", r.Rows[2].Phone)
}
if r.Rows[0].Phone != "+447700900123" {
t.Errorf("phone should keep leading +: %q", r.Rows[0].Phone)
}
}
func TestParsePhoneNormalisedToPlus(t *testing.T) {
// E.164 without explicit "+" is accepted and normalised to include one.
in := "name,phone\nAlex,447700900123\n"
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(r.Rows) != 1 || r.Rows[0].Phone != "+447700900123" {
t.Fatalf("expected normalised phone, got: rows=%+v errors=%+v", r.Rows, r.Errors)
}
}
func TestParsePhoneRejectsLocalFormat(t *testing.T) {
// Local UK / GH style numbers (leading 0, no country code) must be
// rejected — they break SMS routing and WhatsApp click-to-chat.
in := `name,phone
UK Local,07700900123
GH Local,0244123456
With Plus And Zero,+0244123456
`
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(r.Errors) != 3 {
t.Fatalf("expected 3 errors for local-format phones, got %d: %+v", len(r.Errors), r.Errors)
}
}
func TestParseHeaderVariants(t *testing.T) {
cases := []string{
"Name,Email,Phone,Plus Ones\nMira,m@x.com,,1\n",
"Guest Name,E-Mail,Telephone,+1\nMira,m@x.com,,1\n",
"full_name,email_address,mobile,plusones\nMira,m@x.com,,1\n",
}
for i, in := range cases {
t.Run(string(rune('a'+i)), func(t *testing.T) {
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(r.Rows) != 1 {
t.Fatalf("rows=%d errors=%+v", len(r.Rows), r.Errors)
}
if r.Rows[0].PlusOnes != 1 {
t.Errorf("plusones not detected: %+v", r.Rows[0])
}
})
}
}
func TestParseRowValidation(t *testing.T) {
in := `name,email,phone,plus_ones
,a@x.com,,1
Valid Guest,not-an-email,,0
Phone Person,,abc,0
Negative,,,-1
Email Only,e@x.com,,
`
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(r.Errors) != 4 {
t.Fatalf("expected 4 errors, got %d: %+v", len(r.Errors), r.Errors)
}
// "Email Only" row is valid: name present, email parses, phone+plus blank.
if len(r.Rows) != 1 || r.Rows[0].Name != "Email Only" {
t.Fatalf("expected 1 valid row, got %+v", r.Rows)
}
}
func TestParseUTF8BOM(t *testing.T) {
in := "\xEF\xBB\xBFname,email\nMira,m@x.com\n"
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" {
t.Fatalf("BOM not stripped: %+v", r.Rows)
}
}
func TestParseUTF16LE(t *testing.T) {
// "name,email\nMira,m@x.com\n" in UTF-16 LE with BOM.
in := []byte{0xFF, 0xFE}
for _, r := range "name,email\nMira,m@x.com\n" {
in = append(in, byte(r), byte(r>>8))
}
r, err := Parse(strings.NewReader(string(in)), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" {
t.Fatalf("UTF-16 LE not decoded: %+v", r.Rows)
}
}
func TestParseEmptyTrailingRows(t *testing.T) {
in := "name,email\nMira,m@x.com\n,,\n\n,\n"
r, err := Parse(strings.NewReader(in), Options{})
if err != nil {
t.Fatalf("parse: %v", err)
}
if r.TotalCount != 1 {
t.Fatalf("trailing blanks counted: TotalCount=%d", r.TotalCount)
}
}
func TestParseMissingNameHeader(t *testing.T) {
in := "email,phone\na@x.com,\n"
if _, err := Parse(strings.NewReader(in), Options{}); err == nil {
t.Fatal("expected error for missing name column")
}
}
func TestParseMaxRows(t *testing.T) {
var b strings.Builder
b.WriteString("name\n")
for i := 0; i < 11; i++ {
b.WriteString("X\n")
}
if _, err := Parse(strings.NewReader(b.String()), Options{MaxRows: 10}); err == nil {
t.Fatal("expected error when exceeding MaxRows")
}
}
+19
View File
@@ -11,11 +11,30 @@ type User struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
PasswordHash string `json:"-"`
EmailVerified bool `json:"email_verified"`
EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
DeletedAt *time.Time `json:"-"`
TermsAcceptedAt *time.Time `json:"terms_accepted_at,omitempty"`
PrivacyPolicyAcceptedAt *time.Time `json:"privacy_policy_accepted_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TermsAccepted reports whether the user has accepted both the terms
// of service and the privacy policy. Both must be present for the user
// to use the dashboard once enforcement is enabled.
func (u *User) TermsAccepted() bool {
return u != nil && u.TermsAcceptedAt != nil && u.PrivacyPolicyAcceptedAt != nil
}
var (
ErrUserNotFound = errors.New("user not found")
ErrEmailTaken = errors.New("email already in use")
ErrEmailNotVerified = errors.New("email not verified")
ErrAuthTokenNotFound = errors.New("auth token not found")
ErrAuthTokenConsumed = errors.New("auth token already used")
ErrAuthTokenExpired = errors.New("auth token expired")
ErrRefreshTokenRevoked = errors.New("refresh token revoked")
ErrAccountLocked = errors.New("account locked due to too many failed login attempts")
)
+9
View File
@@ -92,6 +92,15 @@ func (c *Client) PublishRSVPConfirmed(ctx context.Context, evt RSVPConfirmed) er
return c.publishJSON(ctx, SubjectRSVPConfirmed, evt, evt.RSVPID)
}
func (c *Client) PublishInvitationSend(ctx context.Context, evt InvitationSend) error {
if evt.IssuedAt.IsZero() {
evt.IssuedAt = time.Now().UTC()
}
// Dedup by token id — re-issuing a token (currently disallowed by the
// unique constraint, but defensive) won't double-send the email.
return c.publishJSON(ctx, SubjectInvitationSend, evt, evt.TokenID)
}
func (c *Client) publishJSON(ctx context.Context, subject string, payload any, dedupeKey uuid.UUID) error {
body, err := json.Marshal(payload)
if err != nil {
+17
View File
@@ -38,3 +38,20 @@ type RSVPConfirmed struct {
RiskScore *int `json:"risk_score,omitempty"`
SubmittedAt time.Time `json:"submitted_at"`
}
// InvitationSend asks the notifier to dispatch a guest invitation email.
// Carries everything the email template needs so the worker doesn't have
// to re-fetch event/guest details from Postgres on every send.
type InvitationSend struct {
EventID uuid.UUID `json:"event_id"`
GuestID uuid.UUID `json:"guest_id"`
TokenID uuid.UUID `json:"token_id"`
GuestName string `json:"guest_name"`
GuestEmail string `json:"guest_email"`
HostName string `json:"host_name"`
EventName string `json:"event_name"`
Venue string `json:"venue,omitempty"`
EventDate time.Time `json:"event_date"`
Link string `json:"link"`
IssuedAt time.Time `json:"issued_at"`
}
+64
View File
@@ -0,0 +1,64 @@
package natspub
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/nats-io/nats.go/jetstream"
)
type InvitationSendHandler func(ctx context.Context, evt InvitationSend) error
type InvitationSendSubscriber struct {
logger *slog.Logger
consumer jetstream.Consumer
handler InvitationSendHandler
}
func NewInvitationSendSubscriber(
ctx context.Context,
c *Client,
durable string,
handler InvitationSendHandler,
logger *slog.Logger,
) (*InvitationSendSubscriber, error) {
cons, err := c.js.CreateOrUpdateConsumer(ctx, StreamName, jetstream.ConsumerConfig{
Durable: durable,
Name: durable,
FilterSubject: SubjectInvitationSend,
AckPolicy: jetstream.AckExplicitPolicy,
DeliverPolicy: jetstream.DeliverAllPolicy,
MaxDeliver: 5,
AckWait: 30 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("create consumer %s: %w", durable, err)
}
return &InvitationSendSubscriber{logger: logger, consumer: cons, handler: handler}, nil
}
func (s *InvitationSendSubscriber) Start(ctx context.Context) (jetstream.ConsumeContext, error) {
cc, err := s.consumer.Consume(func(msg jetstream.Msg) {
var evt InvitationSend
if err := json.Unmarshal(msg.Data(), &evt); err != nil {
s.logger.Error("decode invitation.send", "err", err)
_ = msg.Term()
return
}
hctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := s.handler(hctx, evt); err != nil {
s.logger.Error("handle invitation.send", "err", err)
_ = msg.NakWithDelay(5 * time.Second)
return
}
_ = msg.Ack()
})
if err != nil {
return nil, fmt.Errorf("consume: %w", err)
}
return cc, nil
}
+137
View File
@@ -0,0 +1,137 @@
package notification
import (
"context"
"fmt"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sesv2"
"github.com/aws/aws-sdk-go-v2/service/sesv2/types"
)
// SESConfig is the surface area for picking an SES sender. ConfigurationSet
// is optional but recommended in production — it's where bounce + complaint
// SNS topics get wired so the webhook handler has something to consume.
type SESConfig struct {
Region string
FromEmail string
FromName string
ConfigurationSet string
PublicBaseURL string // for unsubscribe links in templates
}
// SESEmailSender sends transactional emails (verification + reset for the
// auth flows, plus invitation/confirmation/reminder for guests) via Amazon
// SESv2. The same client serves both audiences so callers don't end up
// with two SES configurations to maintain.
type SESEmailSender struct {
client *sesv2.Client
tpls *Templates
from string
configSet *string
baseURL string
}
// NewSESEmailSender returns a configured SES sender, or an error if the
// AWS SDK can't bootstrap. The caller typically constructs this once at
// startup and reuses it for the lifetime of the process.
func NewSESEmailSender(ctx context.Context, cfg SESConfig, tpls *Templates) (*SESEmailSender, error) {
if cfg.FromEmail == "" {
return nil, fmt.Errorf("ses: FromEmail required")
}
if cfg.Region == "" {
cfg.Region = "us-east-1"
}
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(cfg.Region))
if err != nil {
return nil, fmt.Errorf("ses: load aws config: %w", err)
}
client := sesv2.NewFromConfig(awsCfg)
from := cfg.FromEmail
if cfg.FromName != "" {
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
}
var cs *string
if cfg.ConfigurationSet != "" {
cs = aws.String(cfg.ConfigurationSet)
}
return &SESEmailSender{
client: client,
tpls: tpls,
from: from,
configSet: cs,
baseURL: strings.TrimRight(cfg.PublicBaseURL, "/"),
}, nil
}
// --- auth.EmailSender implementation ---
// SendVerification renders the verification template and posts it to SES.
func (s *SESEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
return s.sendTemplated(ctx, to, "Verify your GuestGuard email",
TmplVerification, map[string]any{
"Name": name,
"Link": link,
})
}
// SendPasswordReset renders the reset template and posts it to SES.
func (s *SESEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
return s.sendTemplated(ctx, to, "Reset your GuestGuard password",
TmplPasswordReset, map[string]any{
"Name": name,
"Link": link,
"ExpiryHumane": "1 hour",
})
}
// SendGuest is used by the notifier worker for invitation / confirmation /
// reminder emails — anything addressed at a guest.
func (s *SESEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error) {
return s.sendTemplatedReturnID(ctx, to, subject, name, data)
}
// --- internals ---
func (s *SESEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error {
_, err := s.sendTemplatedReturnID(ctx, to, subject, name, data)
return err
}
func (s *SESEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
if data == nil {
data = map[string]any{}
}
data["Subject"] = subject
html, text, err := s.tpls.Render(name, data)
if err != nil {
return "", err
}
out, err := s.client.SendEmail(ctx, &sesv2.SendEmailInput{
FromEmailAddress: aws.String(s.from),
Destination: &types.Destination{ToAddresses: []string{to}},
ConfigurationSetName: s.configSet,
Content: &types.EmailContent{
Simple: &types.Message{
Subject: &types.Content{Data: aws.String(subject), Charset: aws.String("UTF-8")},
Body: &types.Body{
Html: &types.Content{Data: aws.String(html), Charset: aws.String("UTF-8")},
Text: &types.Content{Data: aws.String(text), Charset: aws.String("UTF-8")},
},
},
},
})
if err != nil {
return "", fmt.Errorf("ses: send: %w", err)
}
if out.MessageId == nil {
return "", nil
}
return *out.MessageId, nil
}
+97
View File
@@ -0,0 +1,97 @@
package notification
import (
"context"
"log/slog"
)
// EmailBackend names the chosen email delivery channel for telemetry +
// startup logging. Mostly a debugging aid — code paths don't branch on
// this value.
type EmailBackend string
const (
BackendResend EmailBackend = "resend"
BackendSMTP EmailBackend = "smtp"
BackendSES EmailBackend = "ses"
BackendLog EmailBackend = "log"
)
// EmailSenderConfig collects every email-related env var so the picker
// has a single, ordered place to decide which backend wins. Priority is
// Resend > SMTP > SES > Log — the first one with non-empty creds is used.
type EmailSenderConfig struct {
Resend ResendConfig
SMTP SMTPConfig
SES SESConfig
}
// CombinedEmailSender satisfies both the auth.EmailSender interface (for
// verification + reset emails) and GuestEmailDispatcher (for invitation,
// confirmation, reminder). One concrete value handles both audiences so
// callers don't end up with two configurations.
type CombinedEmailSender interface {
SendVerification(ctx context.Context, to, name, link string) error
SendPasswordReset(ctx context.Context, to, name, link string) error
SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error)
}
// PickEmailSender returns the configured email sender + which backend was
// chosen. Falls back to a logger stub if nothing is configured, so the
// service stays bootable in stripped-down dev environments.
func PickEmailSender(ctx context.Context, cfg EmailSenderConfig, tpls *Templates, logger *slog.Logger) (CombinedEmailSender, EmailBackend, error) {
switch {
case cfg.Resend.APIKey != "":
s, err := NewResendEmailSender(cfg.Resend, tpls)
if err != nil {
return nil, "", err
}
return s, BackendResend, nil
case cfg.SMTP.Host != "":
s, err := NewSMTPEmailSender(cfg.SMTP, tpls)
if err != nil {
return nil, "", err
}
return s, BackendSMTP, nil
case cfg.SES.FromEmail != "":
s, err := NewSESEmailSender(ctx, cfg.SES, tpls)
if err != nil {
return nil, "", err
}
return s, BackendSES, nil
}
return &logCombinedSender{logger: logger, tpls: tpls}, BackendLog, nil
}
// logCombinedSender is the dev fallback. Verification + reset emails come
// through as structured log lines (preserving the Block A behaviour);
// guest emails get rendered + dumped so engineers can eyeball the output.
type logCombinedSender struct {
logger *slog.Logger
tpls *Templates
}
func (l *logCombinedSender) SendVerification(_ context.Context, to, name, link string) error {
l.logger.Info("auth email (stub): verification", "to", to, "name", name, "link", link)
return nil
}
func (l *logCombinedSender) SendPasswordReset(_ context.Context, to, name, link string) error {
l.logger.Info("auth email (stub): password reset", "to", to, "name", name, "link", link)
return nil
}
func (l *logCombinedSender) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
if data == nil {
data = map[string]any{}
}
data["Subject"] = subject
_, text, err := l.tpls.Render(name, data)
if err != nil {
return "", err
}
l.logger.Info("guest email (stub)",
"to", to, "subject", subject, "template", string(name), "text_body", text,
)
return "log:" + string(name), nil
}
+39 -4
View File
@@ -68,7 +68,8 @@ type RecordParams struct {
Channel Channel
Type Type
Status Status
ProviderID string
ProviderID string // human-friendly id (e.g. "log:xyz")
ProviderMessageID string // provider's message id (Twilio SID, SES MessageId)
Error string
}
@@ -77,6 +78,10 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
if p.ProviderID != "" {
providerID = &p.ProviderID
}
var providerMsgID *string
if p.ProviderMessageID != "" {
providerMsgID = &p.ProviderMessageID
}
var errStr *string
if p.Error != "" {
errStr = &p.Error
@@ -90,14 +95,15 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
const q = `
INSERT INTO notifications (guest_id, channel, type, status, provider_id,
attempts, last_attempt, delivered_at, error)
VALUES ($1, $2, $3, $4, $5, 1, now(), $6, $7)
provider_message_id, attempts, last_attempt,
delivered_at, error)
VALUES ($1, $2, $3, $4, $5, $6, 1, now(), $7, $8)
RETURNING id
`
var id uuid.UUID
err := r.pool.QueryRow(ctx, q,
p.GuestID, string(p.Channel), string(p.Type), string(p.Status),
providerID, deliveredAt, errStr,
providerID, providerMsgID, deliveredAt, errStr,
).Scan(&id)
if err != nil {
return uuid.Nil, fmt.Errorf("record notification: %w", err)
@@ -105,6 +111,35 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
return id, nil
}
// MarkBounce records a bounce on the notification row identified by the
// provider's message id. Called from webhook handlers.
func (r *Repo) MarkBounce(ctx context.Context, providerMessageID, bounceType string) error {
_, err := r.pool.Exec(ctx, `
UPDATE notifications
SET status = 'bounced', bounce_type = $2, error = COALESCE(error, '')
WHERE provider_message_id = $1
`, providerMessageID, bounceType)
return err
}
// MarkComplaint records a complaint (spam report) for the same row.
func (r *Repo) MarkComplaint(ctx context.Context, providerMessageID string) error {
_, err := r.pool.Exec(ctx, `
UPDATE notifications SET complained = TRUE WHERE provider_message_id = $1
`, providerMessageID)
return err
}
// MarkDelivered moves a row from 'sent' to 'delivered' when the provider's
// delivery status webhook fires.
func (r *Repo) MarkDelivered(ctx context.Context, providerMessageID string) error {
_, err := r.pool.Exec(ctx, `
UPDATE notifications SET status = 'delivered', delivered_at = now()
WHERE provider_message_id = $1 AND status NOT IN ('bounced','failed')
`, providerMessageID)
return err
}
// LogSender pretends to send and just logs. Useful for Phase 3 demos and
// tests; concrete providers (Twilio/SES) plug in later.
type LogSender struct{}
+134
View File
@@ -0,0 +1,134 @@
package notification
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
)
// ResendConfig configures the Resend HTTP sender. APIKey is required;
// FromEmail is the sending address on a domain you've verified in the
// Resend dashboard.
type ResendConfig struct {
APIKey string
FromEmail string
FromName string
// HTTPClient overrides the default client (mostly so tests can point
// at httptest.Server). Leave nil for production.
HTTPClient *http.Client
BaseURL string // overrideable for tests; defaults to https://api.resend.com
}
// ResendEmailSender posts emails through https://api.resend.com/emails.
// Implements auth.EmailSender + GuestEmailDispatcher.
type ResendEmailSender struct {
cfg ResendConfig
tpls *Templates
from string
client *http.Client
url string
}
func NewResendEmailSender(cfg ResendConfig, tpls *Templates) (*ResendEmailSender, error) {
if cfg.APIKey == "" {
return nil, errors.New("resend: APIKey required")
}
if cfg.FromEmail == "" {
return nil, errors.New("resend: FromEmail required")
}
from := cfg.FromEmail
if cfg.FromName != "" {
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
}
cli := cfg.HTTPClient
if cli == nil {
cli = &http.Client{Timeout: 15 * time.Second}
}
base := cfg.BaseURL
if base == "" {
base = "https://api.resend.com"
}
return &ResendEmailSender{cfg: cfg, tpls: tpls, from: from, client: cli, url: base + "/emails"}, nil
}
// --- auth.EmailSender ---
func (s *ResendEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
_, err := s.sendTemplated(ctx, to, "Verify your GuestGuard email",
TmplVerification, map[string]any{"Name": name, "Link": link})
return err
}
func (s *ResendEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
_, err := s.sendTemplated(ctx, to, "Reset your GuestGuard password",
TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"})
return err
}
// --- GuestEmailDispatcher ---
func (s *ResendEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
return s.sendTemplated(ctx, to, subject, name, data)
}
// --- internals ---
type resendRequest struct {
From string `json:"from"`
To []string `json:"to"`
Subject string `json:"subject"`
HTML string `json:"html"`
Text string `json:"text"`
}
type resendResponse struct {
ID string `json:"id"`
Message string `json:"message,omitempty"`
}
func (s *ResendEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
if data == nil {
data = map[string]any{}
}
data["Subject"] = subject
html, text, err := s.tpls.Render(name, data)
if err != nil {
return "", err
}
body, _ := json.Marshal(resendRequest{
From: s.from,
To: []string{to},
Subject: subject,
HTML: html,
Text: text,
})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+s.cfg.APIKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return "", fmt.Errorf("resend: do: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if resp.StatusCode >= 300 {
return "", fmt.Errorf("resend: status %d: %s", resp.StatusCode, string(respBody))
}
var parsed resendResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return "", fmt.Errorf("resend: parse: %w", err)
}
return parsed.ID, nil
}
@@ -0,0 +1,89 @@
package notification
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestResendSenderHappyPath(t *testing.T) {
tpls, err := NewTemplates()
if err != nil {
t.Fatalf("NewTemplates: %v", err)
}
var gotPath, gotAuth, gotContentType string
var gotBody map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuth = r.Header.Get("Authorization")
gotContentType = r.Header.Get("Content-Type")
b, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(b, &gotBody)
_ = json.NewEncoder(w).Encode(map[string]string{"id": "resend-abc-123"})
}))
t.Cleanup(srv.Close)
s, err := NewResendEmailSender(ResendConfig{
APIKey: "secret-key",
FromEmail: "no-reply@example.test",
FromName: "GuestGuard",
BaseURL: srv.URL,
}, tpls)
if err != nil {
t.Fatalf("NewResendEmailSender: %v", err)
}
id, err := s.SendGuest(context.Background(), "to@example.test", "You're invited",
TmplInvitation, map[string]any{
"GuestName": "Mira", "HostName": "Kay", "EventName": "Beach Day",
"Link": "https://example.test/rsvp/x",
})
if err != nil {
t.Fatalf("SendGuest: %v", err)
}
if id != "resend-abc-123" {
t.Fatalf("provider id: got %q want %q", id, "resend-abc-123")
}
if gotPath != "/emails" {
t.Errorf("path: got %q want /emails", gotPath)
}
if gotAuth != "Bearer secret-key" {
t.Errorf("auth: got %q", gotAuth)
}
if gotContentType != "application/json" {
t.Errorf("content-type: got %q", gotContentType)
}
if gotBody["from"] != "GuestGuard <no-reply@example.test>" {
t.Errorf("from: got %v", gotBody["from"])
}
if !strings.Contains(gotBody["html"].(string), "Beach Day") {
t.Errorf("html missing event name: %v", gotBody["html"])
}
if !strings.Contains(gotBody["text"].(string), "Mira") {
t.Errorf("text missing guest name: %v", gotBody["text"])
}
}
func TestResendSenderErrorPropagates(t *testing.T) {
tpls, _ := NewTemplates()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"invalid api key"}`))
}))
t.Cleanup(srv.Close)
s, err := NewResendEmailSender(ResendConfig{
APIKey: "bad", FromEmail: "x@y.test", BaseURL: srv.URL,
}, tpls)
if err != nil {
t.Fatalf("NewResendEmailSender: %v", err)
}
if err := s.SendVerification(context.Background(), "z@y.test", "Z", "https://link"); err == nil {
t.Fatal("expected error on 401")
}
}
+70
View File
@@ -0,0 +1,70 @@
package notification
import (
"context"
"fmt"
"log/slog"
"github.com/google/uuid"
)
// Router dispatches an OutboundMessage to the channel-appropriate sender.
// The notifier worker holds one Router and stays oblivious to whether the
// concrete backend is SES, Twilio, or a logger stub.
type Router struct {
email Sender
sms Sender
}
func NewRouter(email, sms Sender) *Router {
return &Router{email: email, sms: sms}
}
func (r *Router) Send(ctx context.Context, msg OutboundMessage) (string, error) {
switch msg.Channel {
case ChannelEmail:
if r.email == nil {
return "", fmt.Errorf("no email sender configured")
}
return r.email.Send(ctx, msg)
case ChannelSMS:
if r.sms == nil {
return "", fmt.Errorf("no sms sender configured")
}
return r.sms.Send(ctx, msg)
}
return "", fmt.Errorf("router: unknown channel %q", msg.Channel)
}
// LogGuestEmailDispatcher is the dev-mode dispatcher that renders the
// templated email and logs both bodies. Useful for local docker-compose
// before SES is configured — gives engineers the rendered output without
// needing a real inbox.
type LogGuestEmailDispatcher struct {
logger *slog.Logger
tpls *Templates
}
func NewLogGuestEmailDispatcher(logger *slog.Logger, tpls *Templates) *LogGuestEmailDispatcher {
return &LogGuestEmailDispatcher{logger: logger, tpls: tpls}
}
func (d *LogGuestEmailDispatcher) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
if data == nil {
data = map[string]any{}
}
data["Subject"] = subject
_, text, err := d.tpls.Render(name, data)
if err != nil {
return "", err
}
id := "log:" + uuid.NewString()
d.logger.Info("guest email (stub)",
"to", to,
"subject", subject,
"template", string(name),
"provider_message_id", id,
"text_body", text,
)
return id, nil
}
+104
View File
@@ -0,0 +1,104 @@
package notification
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
)
// GuestEmailDispatcher is the abstraction the notifier uses to send a
// templated email to a single guest. Both SES (production) and Log (dev)
// satisfy it.
type GuestEmailDispatcher interface {
SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error)
}
// EmailSender is the notification.Sender for ChannelEmail. It maps the
// generic OutboundMessage shape used by the notifier worker onto our
// templated SES / Log path, and honours the suppression list (a guest who
// unsubscribed must never receive another email).
type EmailSender struct {
dispatcher GuestEmailDispatcher
suppression *SuppressionRepo
}
func NewEmailSender(d GuestEmailDispatcher, s *SuppressionRepo) *EmailSender {
return &EmailSender{dispatcher: d, suppression: s}
}
func (e *EmailSender) Send(ctx context.Context, msg OutboundMessage) (string, error) {
if msg.GuestID == uuid.Nil {
return "", errors.New("missing guest id")
}
if msg.Channel != ChannelEmail {
return "", fmt.Errorf("EmailSender does not handle channel %q", msg.Channel)
}
to, _ := msg.Metadata["to"].(string)
if to == "" {
return "", errors.New("email recipient missing from metadata.to")
}
if e.suppression != nil {
yep, err := e.suppression.IsSuppressed(ctx, to)
if err != nil {
// On lookup error, fail safe by NOT sending — better than
// emailing an unsubscribed address.
return "", fmt.Errorf("suppression lookup: %w", err)
}
if yep {
return "suppressed:" + to, nil
}
}
tmpl := templateForType(msg.Type)
if tmpl == "" {
return "", fmt.Errorf("no template for type %q", msg.Type)
}
subject := msg.Subject
if subject == "" {
subject = defaultSubject(msg.Type, msg.Metadata)
}
data := map[string]any{}
for k, v := range msg.Metadata {
data[k] = v
}
return e.dispatcher.SendGuest(ctx, to, subject, tmpl, data)
}
func templateForType(t Type) TemplateName {
switch t {
case TypeInvitation:
return TmplInvitation
case TypeConfirmation:
return TmplConfirmation
case TypeReminder:
return TmplReminder
case TypeVerification:
return TmplVerification
}
return ""
}
func defaultSubject(t Type, meta map[string]any) string {
event, _ := meta["EventName"].(string)
switch t {
case TypeInvitation:
if event != "" {
return "You're invited — " + event
}
return "You're invited"
case TypeConfirmation:
if event != "" {
return "RSVP confirmed — " + event
}
return "RSVP confirmed"
case TypeReminder:
if event != "" {
return "Reminder: " + event
}
return "Reminder"
}
return "GuestGuard"
}
+108
View File
@@ -0,0 +1,108 @@
package notification
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/twilio/twilio-go"
openapi "github.com/twilio/twilio-go/rest/api/v2010"
)
// TwilioConfig configures the SMS sender. AccountSID + AuthToken pair
// authenticate the REST client; FromNumber is the verified Twilio number.
type TwilioConfig struct {
AccountSID string
AuthToken string
FromNumber string
MaxAttempts int
}
// TwilioSender implements notification.Sender for ChannelSMS. Retries on
// transient errors with exponential backoff (1s, 5s, 30s, 5m, 30m). Twilio
// surfaces a numeric ErrorCode for permanent failures (e.g. 21610
// unsubscribed, 21408 disabled region) — those return immediately.
type TwilioSender struct {
client *twilio.RestClient
from string
maxAttempt int
}
func NewTwilioSender(cfg TwilioConfig) (*TwilioSender, error) {
if cfg.AccountSID == "" || cfg.AuthToken == "" || cfg.FromNumber == "" {
return nil, errors.New("twilio: AccountSID / AuthToken / FromNumber are required")
}
cli := twilio.NewRestClientWithParams(twilio.ClientParams{
Username: cfg.AccountSID,
Password: cfg.AuthToken,
})
max := cfg.MaxAttempts
if max <= 0 {
max = 5
}
return &TwilioSender{client: cli, from: cfg.FromNumber, maxAttempt: max}, nil
}
func (t *TwilioSender) Send(ctx context.Context, msg OutboundMessage) (string, error) {
if msg.GuestID == uuid.Nil {
return "", errors.New("missing guest id")
}
if msg.Channel != ChannelSMS {
return "", fmt.Errorf("TwilioSender does not handle channel %q", msg.Channel)
}
to, _ := msg.Metadata["phone"].(string)
if to == "" {
return "", errors.New("sms recipient missing from metadata.phone")
}
body := msg.Body
if body == "" {
body = msg.Subject
}
if body == "" {
return "", errors.New("sms body is empty")
}
params := &openapi.CreateMessageParams{}
params.SetTo(to)
params.SetFrom(t.from)
params.SetBody(body)
backoff := []time.Duration{0, time.Second, 5 * time.Second, 30 * time.Second, 5 * time.Minute, 30 * time.Minute}
var lastErr error
for attempt := 0; attempt < t.maxAttempt; attempt++ {
if attempt < len(backoff) && backoff[attempt] > 0 {
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(backoff[attempt]):
}
}
resp, err := t.client.Api.CreateMessage(params)
if err == nil && resp != nil && resp.Sid != nil {
return *resp.Sid, nil
}
lastErr = err
if !isTwilioRetryable(err) {
return "", fmt.Errorf("twilio: send: %w", err)
}
}
return "", fmt.Errorf("twilio: send (after %d attempts): %w", t.maxAttempt, lastErr)
}
// isTwilioRetryable returns true for transient failures (network, 5xx).
// Twilio's permanent error codes (21xxx range) are not retried.
func isTwilioRetryable(err error) bool {
if err == nil {
return false
}
msg := err.Error()
// Cheap heuristic: permanent codes are in the 21xxx range; everything
// else (timeouts, 503s, DNS hiccups) is fair game.
if strings.Contains(msg, "21610") || strings.Contains(msg, "21408") || strings.Contains(msg, "21211") {
return false
}
return true
}
+238
View File
@@ -0,0 +1,238 @@
package notification
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/hex"
"errors"
"fmt"
"net"
"net/smtp"
"strconv"
"strings"
"time"
)
// SMTPConfig describes one SMTP relay. Username/Password are optional —
// local relays like Mailpit accept anonymous SMTP. TLS modes:
//
// - "starttls" upgrade after EHLO (most relays, default)
// - "implicit" TLS handshake before SMTP (port 465)
// - "none" plain socket — only acceptable on a trusted LAN
type SMTPConfig struct {
Host string
Port int
Username string
Password string
FromEmail string
FromName string
TLS string
}
// SMTPEmailSender implements auth.EmailSender (verification/reset) AND
// GuestEmailDispatcher (invitation/confirmation/reminder) on top of any
// SMTP relay. Used for Mailpit in dev; works against Gmail, Fastmail, etc.
// in production if the user prefers plain SMTP over an HTTP API.
type SMTPEmailSender struct {
cfg SMTPConfig
tpls *Templates
from string
}
func NewSMTPEmailSender(cfg SMTPConfig, tpls *Templates) (*SMTPEmailSender, error) {
if cfg.Host == "" {
return nil, errors.New("smtp: Host required")
}
if cfg.Port <= 0 {
cfg.Port = 587
}
if cfg.FromEmail == "" {
return nil, errors.New("smtp: FromEmail required")
}
if cfg.TLS == "" {
cfg.TLS = "starttls"
}
from := cfg.FromEmail
if cfg.FromName != "" {
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
}
return &SMTPEmailSender{cfg: cfg, tpls: tpls, from: from}, nil
}
// --- auth.EmailSender ---
func (s *SMTPEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
return s.sendTemplated(ctx, to, "Verify your GuestGuard email",
TmplVerification, map[string]any{"Name": name, "Link": link})
}
func (s *SMTPEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
return s.sendTemplated(ctx, to, "Reset your GuestGuard password",
TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"})
}
// --- GuestEmailDispatcher ---
func (s *SMTPEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
return s.sendTemplatedReturnID(ctx, to, subject, name, data)
}
// --- internals ---
func (s *SMTPEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error {
_, err := s.sendTemplatedReturnID(ctx, to, subject, name, data)
return err
}
func (s *SMTPEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
if data == nil {
data = map[string]any{}
}
data["Subject"] = subject
html, text, err := s.tpls.Render(name, data)
if err != nil {
return "", err
}
msgID := generateMessageID(s.cfg.FromEmail)
body := buildMIMEMessage(mimeMessage{
MessageID: msgID,
From: s.from,
To: to,
Subject: subject,
Text: text,
HTML: html,
})
if err := s.dial(ctx, []string{to}, body); err != nil {
return "", err
}
return msgID, nil
}
func (s *SMTPEmailSender) dial(ctx context.Context, to []string, body []byte) error {
addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(s.cfg.Port))
d := &net.Dialer{Timeout: 10 * time.Second}
deadline, ok := ctx.Deadline()
if ok {
d.Deadline = deadline
}
var (
conn net.Conn
err error
)
switch strings.ToLower(s.cfg.TLS) {
case "implicit":
conn, err = tls.DialWithDialer(d, "tcp", addr, &tls.Config{ServerName: s.cfg.Host})
default:
conn, err = d.DialContext(ctx, "tcp", addr)
}
if err != nil {
return fmt.Errorf("smtp: dial: %w", err)
}
c, err := smtp.NewClient(conn, s.cfg.Host)
if err != nil {
conn.Close()
return fmt.Errorf("smtp: new client: %w", err)
}
defer c.Close()
if strings.ToLower(s.cfg.TLS) == "starttls" {
if ok, _ := c.Extension("STARTTLS"); ok {
if err := c.StartTLS(&tls.Config{ServerName: s.cfg.Host}); err != nil {
return fmt.Errorf("smtp: starttls: %w", err)
}
}
}
if s.cfg.Username != "" {
auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host)
if err := c.Auth(auth); err != nil {
return fmt.Errorf("smtp: auth: %w", err)
}
}
if err := c.Mail(s.cfg.FromEmail); err != nil {
return fmt.Errorf("smtp: MAIL FROM: %w", err)
}
for _, rcpt := range to {
if err := c.Rcpt(rcpt); err != nil {
return fmt.Errorf("smtp: RCPT TO %s: %w", rcpt, err)
}
}
w, err := c.Data()
if err != nil {
return fmt.Errorf("smtp: DATA: %w", err)
}
if _, err := w.Write(body); err != nil {
_ = w.Close()
return fmt.Errorf("smtp: write body: %w", err)
}
if err := w.Close(); err != nil {
return fmt.Errorf("smtp: close body: %w", err)
}
return c.Quit()
}
type mimeMessage struct {
MessageID string
From string
To string
Subject string
Text string
HTML string
}
// buildMIMEMessage assembles an RFC 5322 message with a multipart/alternative
// body so receiving clients pick HTML when they can and fall back to text
// otherwise.
func buildMIMEMessage(m mimeMessage) []byte {
boundary := randomBoundary()
var b strings.Builder
b.WriteString("Message-ID: <" + m.MessageID + ">\r\n")
b.WriteString("Date: " + time.Now().UTC().Format(time.RFC1123Z) + "\r\n")
b.WriteString("From: " + m.From + "\r\n")
b.WriteString("To: " + m.To + "\r\n")
b.WriteString("Subject: " + m.Subject + "\r\n")
b.WriteString("MIME-Version: 1.0\r\n")
b.WriteString("Content-Type: multipart/alternative; boundary=" + boundary + "\r\n")
b.WriteString("\r\n")
b.WriteString("--" + boundary + "\r\n")
b.WriteString("Content-Type: text/plain; charset=UTF-8\r\n")
b.WriteString("Content-Transfer-Encoding: 8bit\r\n")
b.WriteString("\r\n")
b.WriteString(m.Text)
if !strings.HasSuffix(m.Text, "\n") {
b.WriteString("\r\n")
}
b.WriteString("--" + boundary + "\r\n")
b.WriteString("Content-Type: text/html; charset=UTF-8\r\n")
b.WriteString("Content-Transfer-Encoding: 8bit\r\n")
b.WriteString("\r\n")
b.WriteString(m.HTML)
if !strings.HasSuffix(m.HTML, "\n") {
b.WriteString("\r\n")
}
b.WriteString("--" + boundary + "--\r\n")
return []byte(b.String())
}
func randomBoundary() string {
var buf [16]byte
_, _ = rand.Read(buf[:])
return "gg=" + hex.EncodeToString(buf[:])
}
func generateMessageID(from string) string {
var buf [12]byte
_, _ = rand.Read(buf[:])
domain := "guestguard.local"
if at := strings.LastIndexByte(from, '@'); at >= 0 && at+1 < len(from) {
domain = from[at+1:]
}
return hex.EncodeToString(buf[:]) + "@" + domain
}
+49
View File
@@ -0,0 +1,49 @@
package notification
import (
"strings"
"testing"
)
func TestBuildMIMEMessageStructure(t *testing.T) {
body := buildMIMEMessage(mimeMessage{
MessageID: "abc@example.test",
From: "GuestGuard <no-reply@example.test>",
To: "to@example.test",
Subject: "Verify your GuestGuard email",
Text: "Hi Mira, please verify.",
HTML: "<p>Hi Mira, please verify.</p>",
})
s := string(body)
checks := []string{
"Message-ID: <abc@example.test>",
"From: GuestGuard <no-reply@example.test>",
"To: to@example.test",
"Subject: Verify your GuestGuard email",
"MIME-Version: 1.0",
"Content-Type: multipart/alternative; boundary=",
"Content-Type: text/plain; charset=UTF-8",
"Content-Type: text/html; charset=UTF-8",
"Hi Mira, please verify.",
"<p>Hi Mira, please verify.</p>",
}
for _, want := range checks {
if !strings.Contains(s, want) {
t.Errorf("MIME body missing %q\n---\n%s", want, s)
}
}
}
func TestGenerateMessageIDIncludesDomain(t *testing.T) {
id := generateMessageID("no-reply@example.test")
if !strings.HasSuffix(id, "@example.test") {
t.Fatalf("message id has wrong domain: %s", id)
}
}
func TestGenerateMessageIDFallback(t *testing.T) {
id := generateMessageID("not-an-email")
if !strings.HasSuffix(id, "@guestguard.local") {
t.Fatalf("expected fallback domain: %s", id)
}
}
+95
View File
@@ -0,0 +1,95 @@
package notification
import (
"context"
"errors"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/alchemistkay/guestguard/internal/storage"
)
// SuppressionSource categorises why an address landed on the suppression
// list. Bounces + complaints come from provider webhooks; "user" is set
// when a guest clicks an unsubscribe link.
type SuppressionSource string
const (
SuppressionBounce SuppressionSource = "bounce"
SuppressionComplaint SuppressionSource = "complaint"
SuppressionManual SuppressionSource = "manual"
SuppressionUser SuppressionSource = "user"
)
// SuppressionRepo manages the unsubscribes table — a flat suppression list
// of email addresses that must never receive another email regardless of
// notification type.
type SuppressionRepo struct {
pool *pgxpool.Pool
}
func NewSuppressionRepo(db *storage.DB) *SuppressionRepo {
return &SuppressionRepo{pool: db.Pool}
}
// IsSuppressed returns true if `email` is on the suppression list.
// Empty / unparseable addresses are treated as not-suppressed; the caller
// is expected to validate before sending.
func (r *SuppressionRepo) IsSuppressed(ctx context.Context, email string) (bool, error) {
email = normaliseEmail(email)
if email == "" {
return false, nil
}
var exists bool
err := r.pool.QueryRow(ctx,
`SELECT EXISTS (SELECT 1 FROM unsubscribes WHERE email = $1)`,
email,
).Scan(&exists)
if err != nil {
return false, err
}
return exists, nil
}
// Add records `email` on the suppression list. Idempotent — repeated calls
// keep the earliest entry's timestamp.
func (r *SuppressionRepo) Add(ctx context.Context, email, reason string, src SuppressionSource) error {
email = normaliseEmail(email)
if email == "" {
return errors.New("empty email")
}
_, err := r.pool.Exec(ctx, `
INSERT INTO unsubscribes (email, reason, source)
VALUES ($1, NULLIF($2, ''), $3)
ON CONFLICT (email) DO NOTHING
`, email, reason, string(src))
return err
}
// Get returns the suppression record for `email`, or pgx.ErrNoRows if not
// found. Mostly used by tests / admin tooling.
func (r *SuppressionRepo) Get(ctx context.Context, email string) (string, SuppressionSource, error) {
var reason *string
var source string
err := r.pool.QueryRow(ctx,
`SELECT reason, source FROM unsubscribes WHERE email = $1`,
normaliseEmail(email),
).Scan(&reason, &source)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", "", err
}
return "", "", err
}
r2 := ""
if reason != nil {
r2 = *reason
}
return r2, SuppressionSource(source), nil
}
func normaliseEmail(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
+116
View File
@@ -0,0 +1,116 @@
package notification
import (
"bytes"
"embed"
"fmt"
htmltemplate "html/template"
"io/fs"
"path"
"strings"
texttemplate "text/template"
)
//go:embed templates
var templatesFS embed.FS
// TemplateName names one of the email templates. There must be a
// matching `<name>.html` and `<name>.txt` under templates/.
type TemplateName string
const (
TmplVerification TemplateName = "verification"
TmplPasswordReset TemplateName = "reset"
TmplInvitation TemplateName = "invitation"
TmplConfirmation TemplateName = "confirmation"
TmplReminder TemplateName = "reminder"
)
// Templates renders branded transactional emails for both HTML and
// plaintext bodies. Loaded once at construction; thread-safe afterwards.
//
// Each page-level HTML template gets its own *html/template.Template
// holding a copy of the `base.html` partial. This avoids html/template's
// per-template contextual-escape pass interfering between pages that all
// define a sibling named "body".
type Templates struct {
html map[string]*htmltemplate.Template // page-name (no ext) -> root
text *texttemplate.Template
}
func NewTemplates() (*Templates, error) {
root, err := fs.Sub(templatesFS, "templates")
if err != nil {
return nil, fmt.Errorf("templates fs: %w", err)
}
baseHTML, err := fs.ReadFile(root, "base.html")
if err != nil {
return nil, fmt.Errorf("read base.html: %w", err)
}
out := &Templates{
html: make(map[string]*htmltemplate.Template),
text: texttemplate.New("__root__"),
}
walk := func(p string, d fs.DirEntry, _ error) error {
if d == nil || d.IsDir() {
return nil
}
base := path.Base(p)
if base == "base.html" {
return nil // partial — folded into each page template below
}
b, err := fs.ReadFile(root, p)
if err != nil {
return err
}
switch {
case strings.HasSuffix(p, ".html"):
name := strings.TrimSuffix(base, ".html")
t := htmltemplate.New(name)
if _, err := t.Parse(string(baseHTML)); err != nil {
return fmt.Errorf("parse _base for %s: %w", p, err)
}
if _, err := t.Parse(string(b)); err != nil {
return fmt.Errorf("parse %s: %w", p, err)
}
out.html[name] = t
case strings.HasSuffix(p, ".txt"):
if _, err := out.text.New(base).Parse(string(b)); err != nil {
return fmt.Errorf("parse %s: %w", p, err)
}
}
return nil
}
if err := fs.WalkDir(root, ".", walk); err != nil {
return nil, err
}
return out, nil
}
// Render returns (htmlBody, textBody) for the named template using data.
func (t *Templates) Render(name TemplateName, data map[string]any) (htmlBody, textBody string, err error) {
if data == nil {
data = map[string]any{}
}
if _, ok := data["Subject"]; !ok {
data["Subject"] = "GuestGuard"
}
htmlTpl, ok := t.html[string(name)]
if !ok {
return "", "", fmt.Errorf("unknown html template %q", name)
}
var hBuf, tBuf bytes.Buffer
// Each page-root template's entry point is "base" (defined by base.html).
if err := htmlTpl.ExecuteTemplate(&hBuf, "base", data); err != nil {
return "", "", fmt.Errorf("render html %s: %w", name, err)
}
if err := t.text.ExecuteTemplate(&tBuf, string(name)+".txt", data); err != nil {
return "", "", fmt.Errorf("render text %s: %w", name, err)
}
return hBuf.String(), tBuf.String(), nil
}
+36
View File
@@ -0,0 +1,36 @@
{{define "base"}}<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>{{.Subject}}</title>
</head>
<body style="margin:0;padding:0;background:#f7f7f8;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,Helvetica,Arial,sans-serif;color:#0f172a;">
<table role="presentation" width="100%" cellpadding="0" cellspacing="0" style="background:#f7f7f8;padding:32px 16px;">
<tr><td align="center">
<table role="presentation" width="560" cellpadding="0" cellspacing="0" style="background:#ffffff;border-radius:12px;overflow:hidden;box-shadow:0 1px 2px rgba(15,23,42,0.05);">
<tr>
<td style="background:#0a0a0a;padding:20px 28px;">
<span style="display:inline-block;width:10px;height:10px;background:#22c55e;border-radius:50%;vertical-align:middle;"></span>
<span style="color:#fafafa;font-weight:600;font-size:16px;margin-left:8px;vertical-align:middle;">GuestGuard</span>
</td>
</tr>
<tr>
<td style="padding:32px 28px 24px;font-size:15px;line-height:1.55;color:#0f172a;">
{{block "body" .}}{{end}}
</td>
</tr>
<tr>
<td style="padding:0 28px 28px;">
<hr style="border:none;border-top:1px solid #e5e7eb;margin:0 0 16px;">
<p style="font-size:12px;color:#64748b;margin:0;">
You're receiving this because of activity on your GuestGuard account.
{{if .UnsubscribeLink}}<br>If you'd rather not get emails like this, <a href="{{.UnsubscribeLink}}" style="color:#16a34a;">unsubscribe here</a>.{{end}}
</p>
</td>
</tr>
</table>
</td></tr>
</table>
</body>
</html>{{end}}
@@ -0,0 +1,11 @@
{{define "body"}}
<h1 style="font-size:20px;margin:0 0 12px;">RSVP received</h1>
<p style="margin:0 0 16px;">Hi {{.GuestName}},</p>
<p style="margin:0 0 16px;">Thanks for letting {{.HostName}} know — your RSVP for <strong>{{.EventName}}</strong> is confirmed as <strong>{{.Response}}</strong>{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.</p>
{{if or .Venue .EventDate}}
<p style="margin:0 0 16px;color:#64748b;font-size:14px;">
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
</p>
{{end}}
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">You'll get a reminder closer to the date. If your plans change, use the same invitation link to update your reply.</p>
{{end}}
@@ -0,0 +1,10 @@
Hi {{.GuestName}},
Your RSVP for "{{.EventName}}" is confirmed as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
You'll get a reminder closer to the date. If your plans change, use the
same invitation link to update your reply.
— GuestGuard
@@ -0,0 +1,15 @@
{{define "body"}}
<p style="font-size:13px;letter-spacing:0.2em;text-transform:uppercase;color:#16a34a;margin:0 0 16px;">✦ You're invited</p>
<h1 style="font-size:24px;margin:0 0 4px;color:#0a0a0a;">{{.EventName}}</h1>
{{if or .Venue .EventDate}}
<p style="margin:0 0 24px;color:#64748b;font-size:14px;">
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
</p>
{{end}}
<p style="margin:0 0 20px;">Hi {{.GuestName}}, {{.HostName}} would love to know if you can make it. Use the personal link below to RSVP.</p>
<p style="margin:0 0 24px;text-align:center;">
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">RSVP now</a>
</p>
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
{{end}}
@@ -0,0 +1,11 @@
You're invited — {{.EventName}}
Hi {{.GuestName}},
{{.HostName}} would love to know if you can make it{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.
RSVP here:
{{.Link}}
— GuestGuard
@@ -0,0 +1,7 @@
{{define "body"}}
<h1 style="font-size:20px;margin:0 0 12px;">Reminder: {{.EventName}}</h1>
<p style="margin:0 0 16px;">Hi {{.GuestName}},</p>
<p style="margin:0 0 16px;">Just a quick reminder that <strong>{{.EventName}}</strong> is coming up{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.</p>
{{if .Response}}<p style="margin:0 0 16px;">You're down as <strong>{{.Response}}</strong>{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.</p>{{end}}
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">Need to change your plans? Use your invitation link.</p>
{{end}}
@@ -0,0 +1,7 @@
Hi {{.GuestName}},
Reminder: {{.EventName}}{{if .EventDate}} — {{.EventDate}}{{end}}{{if .Venue}} at {{.Venue}}{{end}}.
{{if .Response}}You're down as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.{{end}}
— GuestGuard
@@ -0,0 +1,11 @@
{{define "body"}}
<h1 style="font-size:20px;margin:0 0 12px;">Reset your password</h1>
<p style="margin:0 0 16px;">Hi {{.Name}},</p>
<p style="margin:0 0 20px;">We received a request to reset the password on your GuestGuard account. Tap the button to choose a new one — the link is valid for {{.ExpiryHumane}}.</p>
<p style="margin:0 0 24px;text-align:center;">
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">Choose a new password</a>
</p>
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">If you didn't ask to reset your password, you can ignore this email — your current password is unchanged.</p>
{{end}}
+11
View File
@@ -0,0 +1,11 @@
Hi {{.Name}},
We received a request to reset your GuestGuard password. The link is valid
for {{.ExpiryHumane}}:
{{.Link}}
If you didn't ask to reset your password, you can ignore this email — your
current password is unchanged.
— GuestGuard
@@ -0,0 +1,11 @@
{{define "body"}}
<h1 style="font-size:20px;margin:0 0 12px;">Verify your email</h1>
<p style="margin:0 0 16px;">Hi {{.Name}}, welcome to GuestGuard.</p>
<p style="margin:0 0 20px;">To finish setting up your account, please confirm this is your email address.</p>
<p style="margin:0 0 24px;text-align:center;">
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">Verify email</a>
</p>
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">If you didn't sign up for GuestGuard, you can ignore this email.</p>
{{end}}
@@ -0,0 +1,9 @@
Hi {{.Name}}, welcome to GuestGuard.
Please verify your email by visiting:
{{.Link}}
If you didn't sign up for GuestGuard, you can ignore this email.
— GuestGuard
+101
View File
@@ -0,0 +1,101 @@
package notification
import (
"strings"
"testing"
)
func TestRenderAllTemplates(t *testing.T) {
tpls, err := NewTemplates()
if err != nil {
t.Fatalf("NewTemplates: %v", err)
}
cases := []struct {
name TemplateName
data map[string]any
wantHTML []string // substrings expected in HTML body
wantText []string // substrings expected in text body
}{
{
name: TmplVerification,
data: map[string]any{
"Name": "Kay",
"Link": "https://example.test/verify-email?token=x",
"Subject": "Verify your GuestGuard email",
"UnsubscribeLink": "https://example.test/unsubscribe/abc",
},
wantHTML: []string{"Verify your email", "Kay", "Verify email", "https://example.test/verify-email?token=x", "unsubscribe here"},
wantText: []string{"Hi Kay", "https://example.test/verify-email?token=x"},
},
{
name: TmplPasswordReset,
data: map[string]any{
"Name": "Kay",
"Link": "https://example.test/reset-password/abc",
"ExpiryHumane": "1 hour",
},
wantHTML: []string{"Reset your password", "1 hour", "https://example.test/reset-password/abc"},
wantText: []string{"reset your GuestGuard password", "1 hour"},
},
{
name: TmplInvitation,
data: map[string]any{
"GuestName": "Mira",
"HostName": "Kay",
"EventName": "Beach Day",
"Venue": "Ocean Park",
"EventDate": "Sat 14 Jun, 4pm",
"Link": "https://example.test/rsvp/tok_x",
},
wantHTML: []string{"You're invited", "Beach Day", "Mira", "Ocean Park", "RSVP now", "https://example.test/rsvp/tok_x"},
wantText: []string{"Beach Day", "Mira", "RSVP here", "https://example.test/rsvp/tok_x"},
},
{
name: TmplConfirmation,
data: map[string]any{
"GuestName": "Mira",
"HostName": "Kay",
"EventName": "Beach Day",
"Venue": "Ocean Park",
"EventDate": "Sat 14 Jun, 4pm",
"Response": "attending",
"PlusOnes": 2,
},
wantHTML: []string{"RSVP received", "Beach Day", "attending", "2 plus-ones"},
wantText: []string{"Beach Day", "attending", "+2"},
},
{
name: TmplReminder,
data: map[string]any{
"GuestName": "Mira",
"EventName": "Beach Day",
"EventDate": "Sat 14 Jun, 4pm",
"Venue": "Ocean Park",
"Response": "attending",
"PlusOnes": 1,
},
wantHTML: []string{"Reminder", "Beach Day", "Ocean Park", "1 plus-one"},
wantText: []string{"Reminder", "Beach Day", "(+1)"},
},
}
for _, tc := range cases {
t.Run(string(tc.name), func(t *testing.T) {
html, text, err := tpls.Render(tc.name, tc.data)
if err != nil {
t.Fatalf("render: %v", err)
}
for _, s := range tc.wantHTML {
if !strings.Contains(html, s) {
t.Errorf("html missing %q\n---\n%s", s, html)
}
}
for _, s := range tc.wantText {
if !strings.Contains(text, s) {
t.Errorf("text missing %q\n---\n%s", s, text)
}
}
})
}
}
+60
View File
@@ -0,0 +1,60 @@
package notification
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"strings"
)
// UnsubscribeSigner mints + verifies tamper-proof unsubscribe tokens. Each
// token encodes the email address and an HMAC-SHA256 of the address under
// a server-side secret. The token has no TTL — unsubscribe links should
// keep working forever.
//
// Token shape: base64url(email) + "." + base64url(hmac)
type UnsubscribeSigner struct {
secret []byte
}
func NewUnsubscribeSigner(secret string) *UnsubscribeSigner {
return &UnsubscribeSigner{secret: []byte(secret)}
}
// Sign returns a URL-safe token for the email address. Empty input → empty
// token (caller should validate input first).
func (s *UnsubscribeSigner) Sign(email string) string {
if email == "" {
return ""
}
email = normaliseEmail(email)
mac := hmac.New(sha256.New, s.secret)
mac.Write([]byte(email))
return base64.RawURLEncoding.EncodeToString([]byte(email)) +
"." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
// Verify decodes the token and confirms the HMAC matches, returning the
// owning email address.
func (s *UnsubscribeSigner) Verify(token string) (string, error) {
dot := strings.IndexByte(token, '.')
if dot < 0 {
return "", errors.New("malformed unsubscribe token")
}
emailB, err := base64.RawURLEncoding.DecodeString(token[:dot])
if err != nil {
return "", err
}
sigB, err := base64.RawURLEncoding.DecodeString(token[dot+1:])
if err != nil {
return "", err
}
mac := hmac.New(sha256.New, s.secret)
mac.Write(emailB)
want := mac.Sum(nil)
if !hmac.Equal(sigB, want) {
return "", errors.New("unsubscribe signature mismatch")
}
return string(emailB), nil
}
+69
View File
@@ -0,0 +1,69 @@
package ratelimit
import (
"encoding/json"
"log/slog"
"net/http"
"strconv"
"time"
)
// KeyFunc derives the rate-limit key from a request — e.g. IP, IP+email,
// authenticated user id, path token, etc. An empty return value bypasses
// the limiter (handy when a key isn't available yet, like an email that
// hasn't been parsed out of the body).
type KeyFunc func(r *http.Request) string
// Rule names a single sliding-window budget.
type Rule struct {
Name string
Limit int
Window time.Duration
}
// Middleware returns an http middleware that applies `rule` to incoming
// requests using `keyFn`. On limit, it writes 429 with a Retry-After header
// and a JSON body. If the limiter itself errors, requests are allowed
// (fail-open) — degraded protection is better than total outage.
func (l *Limiter) Middleware(rule Rule, keyFn KeyFunc, logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := keyFn(r)
if key == "" {
next.ServeHTTP(w, r)
return
}
res, err := l.Allow(r.Context(), rule.Name, key, rule.Limit, rule.Window)
if err != nil {
if logger != nil {
logger.Warn("ratelimit error (failing open)",
"rule", rule.Name, "err", err)
}
next.ServeHTTP(w, r)
return
}
if !res.Allowed {
retrySecs := int(res.RetryAfter.Round(time.Second).Seconds())
if retrySecs < 1 {
retrySecs = 1
}
w.Header().Set("Retry-After", strconv.Itoa(retrySecs))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusTooManyRequests)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": "rate limit exceeded",
"retry_after": retrySecs,
})
if logger != nil {
logger.Info("ratelimit blocked",
"rule", rule.Name,
"key", key,
"retry_after", retrySecs,
)
}
return
}
next.ServeHTTP(w, r)
})
}
}
+134
View File
@@ -0,0 +1,134 @@
// Package ratelimit implements a sliding-window rate limiter backed by
// Redis sorted sets. Each call is atomic via a Lua script: it sweeps
// entries older than the window, returns the current count, and (when
// under the limit) records the new hit. Block C's "INCR + EXPIRE or a Lua
// script for atomicity" requirement.
package ratelimit
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// Result is the outcome of one Allow check.
type Result struct {
Allowed bool
Count int // current count within the window (post-increment when allowed)
Limit int
RetryAfter time.Duration // populated when Allowed=false
}
// Limiter checks rate-limit budgets against Redis.
type Limiter struct {
client *redis.Client
script *redis.Script
prefix string
now func() time.Time
}
// New builds a Limiter against the given Redis client. The prefix namespaces
// all keys (defaults to "rl").
func New(client *redis.Client, prefix string) *Limiter {
if prefix == "" {
prefix = "rl"
}
return &Limiter{
client: client,
script: redis.NewScript(slidingWindowScript),
prefix: prefix,
now: time.Now,
}
}
// Allow consumes one unit of budget under (name, key) against `limit` events
// per `window`. Returns Allowed=true and the new count, or Allowed=false
// with RetryAfter set to roughly the duration until the oldest hit ages out.
func (l *Limiter) Allow(ctx context.Context, name, key string, limit int, window time.Duration) (Result, error) {
if limit <= 0 {
return Result{}, errors.New("ratelimit: limit must be positive")
}
if window <= 0 {
return Result{}, errors.New("ratelimit: window must be positive")
}
member, err := randomMember()
if err != nil {
return Result{}, err
}
now := l.now().UnixMilli()
windowMS := window.Milliseconds()
redisKey := fmt.Sprintf("%s:%s:%s", l.prefix, name, key)
out, err := l.script.Run(ctx, l.client,
[]string{redisKey},
now, windowMS, limit, member,
).Int64Slice()
if err != nil {
return Result{}, fmt.Errorf("ratelimit: redis: %w", err)
}
if len(out) != 3 {
return Result{}, fmt.Errorf("ratelimit: bad lua reply: %v", out)
}
r := Result{
Count: int(out[1]),
Limit: limit,
}
if out[0] == 0 {
r.Allowed = true
} else {
r.RetryAfter = time.Duration(out[2]) * time.Millisecond
if r.RetryAfter <= 0 {
r.RetryAfter = time.Second
}
}
return r, nil
}
func randomMember() (string, error) {
var buf [12]byte
if _, err := rand.Read(buf[:]); err != nil {
return "", err
}
return hex.EncodeToString(buf[:]), nil
}
// Sliding-window check + record, atomic in Redis.
//
// KEYS[1] = bucket key
// ARGV[1] = now (unix ms)
// ARGV[2] = window (ms)
// ARGV[3] = limit
// ARGV[4] = unique member to insert when allowed
// returns { blocked, count, retryAfterMs }
const slidingWindowScript = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local member = ARGV[4]
local cutoff = now - window
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
local count = redis.call('ZCARD', key)
if count >= limit then
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local retry = window
if oldest[2] then
retry = (tonumber(oldest[2]) + window) - now
if retry < 1 then retry = 1 end
end
return {1, count, retry}
end
redis.call('ZADD', key, now, member)
redis.call('PEXPIRE', key, window)
return {0, count + 1, 0}
`
+110
View File
@@ -0,0 +1,110 @@
package ratelimit
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
)
func newTestLimiter(t *testing.T) (*Limiter, *miniredis.Miniredis) {
t.Helper()
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("miniredis: %v", err)
}
t.Cleanup(mr.Close)
cli := redis.NewClient(&redis.Options{Addr: mr.Addr()})
t.Cleanup(func() { _ = cli.Close() })
return New(cli, "test"), mr
}
func TestLimiterAllowsBelowLimit(t *testing.T) {
l, _ := newTestLimiter(t)
ctx := context.Background()
for i := 1; i <= 3; i++ {
r, err := l.Allow(ctx, "signup", "1.2.3.4", 3, time.Minute)
if err != nil {
t.Fatalf("allow #%d: %v", i, err)
}
if !r.Allowed {
t.Fatalf("hit %d should be allowed, got %+v", i, r)
}
if r.Count != i {
t.Fatalf("hit %d count: got %d want %d", i, r.Count, i)
}
}
}
func TestLimiterBlocksAtLimit(t *testing.T) {
l, _ := newTestLimiter(t)
ctx := context.Background()
for i := 0; i < 3; i++ {
if _, err := l.Allow(ctx, "signup", "ip", 3, time.Minute); err != nil {
t.Fatal(err)
}
}
r, err := l.Allow(ctx, "signup", "ip", 3, time.Minute)
if err != nil {
t.Fatal(err)
}
if r.Allowed {
t.Fatalf("4th hit should be blocked: %+v", r)
}
if r.RetryAfter <= 0 || r.RetryAfter > time.Minute {
t.Fatalf("retry-after out of range: %v", r.RetryAfter)
}
}
func TestLimiterWindowSlides(t *testing.T) {
l, mr := newTestLimiter(t)
ctx := context.Background()
// Inject a controllable clock so we can advance time in miniredis +
// the limiter consistently.
base := time.Unix(1_700_000_000, 0)
l.now = func() time.Time { return base }
for i := 0; i < 3; i++ {
if _, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute); err != nil {
t.Fatal(err)
}
}
// Slide past the window. miniredis honours TTLs we already set so
// FastForward is the trustworthy primitive here.
l.now = func() time.Time { return base.Add(2 * time.Minute) }
mr.FastForward(2 * time.Minute)
r, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute)
if err != nil {
t.Fatal(err)
}
if !r.Allowed {
t.Fatalf("expected allow after window: %+v", r)
}
}
func TestLimiterIsolatesKeys(t *testing.T) {
l, _ := newTestLimiter(t)
ctx := context.Background()
// Exhaust budget for one key — others should not be affected.
for i := 0; i < 2; i++ {
if _, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute); err != nil {
t.Fatal(err)
}
}
blocked, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute)
if err != nil {
t.Fatal(err)
}
if blocked.Allowed {
t.Fatal("key a should be blocked")
}
other, err := l.Allow(ctx, "login", "b@x.com", 2, time.Minute)
if err != nil {
t.Fatal(err)
}
if !other.Allowed {
t.Fatalf("unrelated key b should still be allowed: %+v", other)
}
}
+267
View File
@@ -0,0 +1,267 @@
package storage
import (
"context"
"errors"
"net/netip"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/alchemistkay/guestguard/internal/domain"
)
// EmailVerificationRepo manages single-use email verification tokens.
type EmailVerificationRepo struct {
pool *pgxpool.Pool
}
func NewEmailVerificationRepo(db *DB) *EmailVerificationRepo {
return &EmailVerificationRepo{pool: db.Pool}
}
func (r *EmailVerificationRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error {
_, err := r.pool.Exec(ctx, `
INSERT INTO email_verification_tokens (token_hash, user_id, expires_at)
VALUES ($1, $2, $3)
`, hash, userID, expiresAt)
return err
}
// Consume atomically marks the token as used and returns the owning user_id.
// Returns ErrAuthTokenNotFound / ErrAuthTokenConsumed / ErrAuthTokenExpired.
func (r *EmailVerificationRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) {
const q = `
UPDATE email_verification_tokens
SET consumed_at = now()
WHERE token_hash = $1
AND consumed_at IS NULL
AND expires_at > now()
RETURNING user_id
`
var uid uuid.UUID
if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool,
"SELECT consumed_at, expires_at FROM email_verification_tokens WHERE token_hash=$1",
hash)
}
return uuid.Nil, err
}
return uid, nil
}
// PasswordResetRepo manages single-use password-reset tokens.
type PasswordResetRepo struct {
pool *pgxpool.Pool
}
func NewPasswordResetRepo(db *DB) *PasswordResetRepo {
return &PasswordResetRepo{pool: db.Pool}
}
func (r *PasswordResetRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error {
_, err := r.pool.Exec(ctx, `
INSERT INTO password_reset_tokens (token_hash, user_id, expires_at)
VALUES ($1, $2, $3)
`, hash, userID, expiresAt)
return err
}
func (r *PasswordResetRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) {
const q = `
UPDATE password_reset_tokens
SET consumed_at = now()
WHERE token_hash = $1
AND consumed_at IS NULL
AND expires_at > now()
RETURNING user_id
`
var uid uuid.UUID
if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool,
"SELECT consumed_at, expires_at FROM password_reset_tokens WHERE token_hash=$1",
hash)
}
return uuid.Nil, err
}
return uid, nil
}
// RefreshTokenRepo manages refresh-token rows. Refresh tokens are rotated:
// every refresh issues a new token and revokes the old one, recording the
// chain in `replaced_by` so we can detect replay (a revoked token being
// presented again triggers a family-wide revocation).
type RefreshTokenRepo struct {
pool *pgxpool.Pool
}
func NewRefreshTokenRepo(db *DB) *RefreshTokenRepo {
return &RefreshTokenRepo{pool: db.Pool}
}
type RefreshToken struct {
Hash string
UserID uuid.UUID
ExpiresAt time.Time
RevokedAt *time.Time
ReplacedBy *string
UserAgent string
IPAddress *netip.Addr
CreatedAt time.Time
}
type CreateRefreshTokenParams struct {
Hash string
UserID uuid.UUID
ExpiresAt time.Time
UserAgent string
IPAddress string
}
func (r *RefreshTokenRepo) Create(ctx context.Context, p CreateRefreshTokenParams) error {
ip := parseIP(p.IPAddress)
_, err := r.pool.Exec(ctx, `
INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, NULLIF($4, ''), $5)
`, p.Hash, p.UserID, p.ExpiresAt, p.UserAgent, ip)
return err
}
func (r *RefreshTokenRepo) Get(ctx context.Context, hash string) (*RefreshToken, error) {
const q = `
SELECT token_hash, user_id, expires_at, revoked_at, replaced_by,
COALESCE(user_agent, ''), host(ip_address), created_at
FROM refresh_tokens WHERE token_hash = $1
`
var rt RefreshToken
var ipText *string
if err := r.pool.QueryRow(ctx, q, hash).Scan(
&rt.Hash, &rt.UserID, &rt.ExpiresAt, &rt.RevokedAt, &rt.ReplacedBy,
&rt.UserAgent, &ipText, &rt.CreatedAt,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, domain.ErrAuthTokenNotFound
}
return nil, err
}
if ipText != nil && *ipText != "" {
if addr, err := netip.ParseAddr(*ipText); err == nil {
rt.IPAddress = &addr
}
}
return &rt, nil
}
// Rotate atomically (in a transaction) marks the old token revoked and
// inserts the new one with replaced_by set. Returns ErrAuthTokenNotFound or
// ErrRefreshTokenRevoked if the old token is missing or already revoked.
func (r *RefreshTokenRepo) Rotate(ctx context.Context, oldHash string, next CreateRefreshTokenParams) error {
tx, err := r.pool.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
var revokedAt *time.Time
var userID uuid.UUID
var expiresAt time.Time
err = tx.QueryRow(ctx, `
SELECT user_id, expires_at, revoked_at
FROM refresh_tokens WHERE token_hash = $1 FOR UPDATE
`, oldHash).Scan(&userID, &expiresAt, &revokedAt)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return domain.ErrAuthTokenNotFound
}
return err
}
if revokedAt != nil {
// Replay of a revoked refresh token — revoke the entire family.
if _, err := tx.Exec(ctx, `
UPDATE refresh_tokens SET revoked_at = now()
WHERE user_id = $1 AND revoked_at IS NULL
`, userID); err != nil {
return err
}
if err := tx.Commit(ctx); err != nil {
return err
}
return domain.ErrRefreshTokenRevoked
}
if time.Now().After(expiresAt) {
return domain.ErrAuthTokenExpired
}
if next.UserID != userID {
return errors.New("refresh token user mismatch")
}
ip := parseIP(next.IPAddress)
if _, err := tx.Exec(ctx, `
INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, NULLIF($4, ''), $5)
`, next.Hash, next.UserID, next.ExpiresAt, next.UserAgent, ip); err != nil {
return err
}
if _, err := tx.Exec(ctx, `
UPDATE refresh_tokens SET revoked_at = now(), replaced_by = $2
WHERE token_hash = $1
`, oldHash, next.Hash); err != nil {
return err
}
return tx.Commit(ctx)
}
func (r *RefreshTokenRepo) Revoke(ctx context.Context, hash string) error {
tag, err := r.pool.Exec(ctx, `
UPDATE refresh_tokens SET revoked_at = now()
WHERE token_hash = $1 AND revoked_at IS NULL
`, hash)
if err != nil {
return err
}
if tag.RowsAffected() == 0 {
return domain.ErrAuthTokenNotFound
}
return nil
}
func (r *RefreshTokenRepo) RevokeAllForUser(ctx context.Context, userID uuid.UUID) error {
_, err := r.pool.Exec(ctx, `
UPDATE refresh_tokens SET revoked_at = now()
WHERE user_id = $1 AND revoked_at IS NULL
`, userID)
return err
}
func parseIP(s string) any {
if s == "" {
return nil
}
addr, err := netip.ParseAddr(s)
if err != nil {
return nil
}
return addr.String()
}
func classifyAuthTokenLookup(ctx context.Context, pool *pgxpool.Pool, q, hash string) error {
var consumedAt *time.Time
var expiresAt time.Time
if err := pool.QueryRow(ctx, q, hash).Scan(&consumedAt, &expiresAt); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return domain.ErrAuthTokenNotFound
}
return err
}
if consumedAt != nil {
return domain.ErrAuthTokenConsumed
}
if time.Now().After(expiresAt) {
return domain.ErrAuthTokenExpired
}
return domain.ErrAuthTokenNotFound
}
+30 -12
View File
@@ -78,6 +78,24 @@ func (r *EventRepo) Get(ctx context.Context, id uuid.UUID) (*domain.Event, error
return ev, nil
}
// GetForHost is the authz-aware variant of Get. It returns ErrEventNotFound
// when the event either doesn't exist or doesn't belong to the host — by
// merging both cases we avoid leaking existence on cross-tenant lookups.
func (r *EventRepo) GetForHost(ctx context.Context, id, hostID uuid.UUID) (*domain.Event, error) {
const q = `
SELECT id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at
FROM events WHERE id = $1 AND host_id = $2
`
ev, err := scanEvent(r.pool.QueryRow(ctx, q, id, hostID))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, domain.ErrEventNotFound
}
return nil, err
}
return ev, nil
}
func (r *EventRepo) List(ctx context.Context, hostID uuid.UUID, limit, offset int) ([]*domain.Event, error) {
if limit <= 0 || limit > 200 {
limit = 50
@@ -132,18 +150,18 @@ type UpdateEventParams struct {
Status *domain.EventStatus
}
func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParams) (*domain.Event, error) {
func (r *EventRepo) Update(ctx context.Context, id, hostID uuid.UUID, p UpdateEventParams) (*domain.Event, error) {
const q = `
UPDATE events SET
name = COALESCE($2, name),
slug = COALESCE($3, slug),
event_date = COALESCE($4, event_date),
venue = COALESCE($5, venue),
max_capacity = COALESCE($6, max_capacity),
settings = COALESCE($7, settings),
status = COALESCE($8, status),
name = COALESCE($3, name),
slug = COALESCE($4, slug),
event_date = COALESCE($5, event_date),
venue = COALESCE($6, venue),
max_capacity = COALESCE($7, max_capacity),
settings = COALESCE($8, settings),
status = COALESCE($9, status),
updated_at = now()
WHERE id = $1
WHERE id = $1 AND host_id = $2
RETURNING id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at
`
@@ -156,7 +174,7 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam
settingsJSON = b
}
row := r.pool.QueryRow(ctx, q, id,
row := r.pool.QueryRow(ctx, q, id, hostID,
p.Name, p.Slug, p.EventDate, p.Venue, p.MaxCapacity, settingsJSON, p.Status,
)
ev, err := scanEvent(row)
@@ -173,8 +191,8 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam
return ev, nil
}
func (r *EventRepo) Delete(ctx context.Context, id uuid.UUID) error {
tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1`, id)
func (r *EventRepo) Delete(ctx context.Context, id, hostID uuid.UUID) error {
tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1 AND host_id = $2`, id, hostID)
if err != nil {
return err
}
+212
View File
@@ -3,6 +3,8 @@ package storage
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
@@ -88,6 +90,87 @@ func (r *GuestRepo) ListByEvent(ctx context.Context, eventID uuid.UUID, limit, o
return out, rows.Err()
}
// BulkImportRow is one normalised guest in a CSV import batch.
type BulkImportRow struct {
Name string
Email string // empty if absent
Phone string // empty if absent
PlusOnes int
}
// BulkImportResult reports the outcome of a single import call. The
// SkippedEmails slice records the addresses we silently dropped because a
// guest already exists on the event — useful for the success summary.
type BulkImportResult struct {
Added int
Skipped int
SkippedEmails []string
}
// BulkImportGuests inserts up to len(rows) guest rows into the event in a
// single transaction. Rows whose email matches an existing guest on the
// event are skipped (idempotent re-imports). Within the batch, duplicate
// emails after the first are also skipped. Either the entire batch
// commits or none of it does.
//
// Empty email is treated as "no email" and not deduped — those rows
// always insert (the host might be entering phone-only guests).
func (r *GuestRepo) BulkImportGuests(ctx context.Context, eventID uuid.UUID, rows []BulkImportRow) (BulkImportResult, error) {
res := BulkImportResult{}
if len(rows) == 0 {
return res, nil
}
tx, err := r.pool.Begin(ctx)
if err != nil {
return res, err
}
defer tx.Rollback(ctx)
// Fetch existing emails on the event into a set for O(1) dedup.
existing := map[string]struct{}{}
exRows, err := tx.Query(ctx,
`SELECT lower(email) FROM guests WHERE event_id = $1 AND email IS NOT NULL AND email <> ''`,
eventID)
if err != nil {
return res, fmt.Errorf("load existing emails: %w", err)
}
for exRows.Next() {
var e string
if err := exRows.Scan(&e); err != nil {
exRows.Close()
return res, err
}
existing[e] = struct{}{}
}
exRows.Close()
const ins = `
INSERT INTO guests (event_id, name, email, phone, plus_ones)
VALUES ($1, $2, NULLIF($3, ''), NULLIF($4, ''), $5)
`
for _, row := range rows {
email := strings.ToLower(strings.TrimSpace(row.Email))
if email != "" {
if _, dup := existing[email]; dup {
res.Skipped++
res.SkippedEmails = append(res.SkippedEmails, email)
continue
}
existing[email] = struct{}{}
}
if _, err := tx.Exec(ctx, ins, eventID, row.Name, email, row.Phone, row.PlusOnes); err != nil {
return BulkImportResult{}, fmt.Errorf("insert guest %q: %w", row.Name, err)
}
res.Added++
}
if err := tx.Commit(ctx); err != nil {
return BulkImportResult{}, err
}
return res, nil
}
func scanGuest(s rowScanner) (*domain.Guest, error) {
var g domain.Guest
err := s.Scan(
@@ -100,6 +183,135 @@ func scanGuest(s rowScanner) (*domain.Guest, error) {
return &g, nil
}
// UpdateGuestParams patches a guest. Nil fields are left untouched.
// An empty string for Email / Phone clears the column to NULL, matching
// the frontend "clear this field" UX.
type UpdateGuestParams struct {
Name *string
Email *string
Phone *string
PlusOnes *int
}
// Update applies the patch to (guestID, eventID). Event scoping in the
// WHERE clause prevents a host from patching guests on another host's
// event even if they guess the guest_id. Returns ErrGuestNotFound when
// the guest doesn't exist on the event.
func (r *GuestRepo) Update(ctx context.Context, eventID, guestID uuid.UUID, p UpdateGuestParams) (*domain.Guest, error) {
sets := []string{}
args := []any{guestID, eventID}
add := func(col string, val any) {
args = append(args, val)
sets = append(sets, fmt.Sprintf("%s = $%d", col, len(args)))
}
if p.Name != nil {
add("name", strings.TrimSpace(*p.Name))
}
if p.Email != nil {
if strings.TrimSpace(*p.Email) == "" {
sets = append(sets, "email = NULL")
} else {
add("email", strings.ToLower(strings.TrimSpace(*p.Email)))
}
}
if p.Phone != nil {
if strings.TrimSpace(*p.Phone) == "" {
sets = append(sets, "phone = NULL")
} else {
add("phone", strings.TrimSpace(*p.Phone))
}
}
if p.PlusOnes != nil {
add("plus_ones", *p.PlusOnes)
}
if len(sets) == 0 {
return r.Get(ctx, guestID)
}
q := fmt.Sprintf(`
UPDATE guests SET %s
WHERE id = $1 AND event_id = $2
RETURNING id, event_id, name, email, phone, plus_ones, dietary_notes, table_number, created_at
`, strings.Join(sets, ", "))
g, err := scanGuest(r.pool.QueryRow(ctx, q, args...))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, domain.ErrGuestNotFound
}
return nil, err
}
return g, nil
}
// Delete removes a guest from an event. Cascade-deletes any tokens,
// rsvps, access_logs, and notifications tied to the guest. Event scoping
// in the WHERE clause stops cross-tenant deletes.
func (r *GuestRepo) Delete(ctx context.Context, eventID, guestID uuid.UUID) error {
tag, err := r.pool.Exec(ctx,
`DELETE FROM guests WHERE id = $1 AND event_id = $2`,
guestID, eventID)
if err != nil {
return err
}
if tag.RowsAffected() == 0 {
return domain.ErrGuestNotFound
}
return nil
}
// GuestForInvitation is the minimum data the bulk-invite path needs about
// each candidate. Pulled in a single query joined against tokens so the
// caller knows up-front who's already received an invitation.
type GuestForInvitation struct {
ID uuid.UUID
Name string
Email string // empty when the guest has no email on file
HasToken bool
}
// ListGuestsForInvitation returns every guest on `eventID`, joined with
// the tokens table so the caller can skip guests that already have one.
// When `onlyIDs` is non-nil/non-empty, the result is filtered to that
// subset (used for explicit selection in the bulk-send UI).
func (r *GuestRepo) ListGuestsForInvitation(ctx context.Context, eventID uuid.UUID, onlyIDs []uuid.UUID) ([]GuestForInvitation, error) {
var (
rows pgx.Rows
err error
)
if len(onlyIDs) == 0 {
rows, err = r.pool.Query(ctx, `
SELECT g.id, g.name, COALESCE(g.email,''),
(t.id IS NOT NULL) AS has_token
FROM guests g
LEFT JOIN tokens t ON t.guest_id = g.id
WHERE g.event_id = $1
ORDER BY g.created_at ASC
`, eventID)
} else {
rows, err = r.pool.Query(ctx, `
SELECT g.id, g.name, COALESCE(g.email,''),
(t.id IS NOT NULL) AS has_token
FROM guests g
LEFT JOIN tokens t ON t.guest_id = g.id
WHERE g.event_id = $1 AND g.id = ANY($2)
ORDER BY g.created_at ASC
`, eventID, onlyIDs)
}
if err != nil {
return nil, err
}
defer rows.Close()
var out []GuestForInvitation
for rows.Next() {
var g GuestForInvitation
if err := rows.Scan(&g.ID, &g.Name, &g.Email, &g.HasToken); err != nil {
return nil, err
}
out = append(out, g)
}
return out, rows.Err()
}
// GuestWithRSVP is the dashboard view: a guest plus the RSVP submitted
// against their token, if any. RSVP fields are nil when no response yet.
type GuestWithRSVP struct {

Some files were not shown because too many files have changed in this diff Show More