feat: ship Tier 1 — auth, authz, rate limits, real notifications, CSV import, billing, backups/DR, privacy
Closes every block in docs/TIER1_PLAN.md from the Claude-scope side. The
homelab / cloud setup steps (SES verification, restore drill, lawyer-
drafted ToS) remain operator-owned but are unblocked.
Block A — Authentication
- Migration 0003: password_hash, email_verified, email_verification_tokens,
password_reset_tokens, refresh_tokens (with replaced_by family chain).
- Bcrypt hasher, HS256 JWT signer, single-use refresh tokens with rotation
+ replay-detection (revokes the family on reuse).
- /auth/signup, /login, /refresh, /logout, /verify-email,
/forgot-password, /reset-password — enumeration-safe.
- requireAuth middleware + GET /me.
- Frontend useAuth/useApi with auto-refresh-on-401, login/signup/verify/
forgot/reset pages, route-guard middleware.
Block B — Authorisation
- EventRepo.GetForHost; Update/Delete scoped by host_id.
- All host routes behind requireAuth + ownership; cross-tenant returns
404 (no enumeration). ?host_id removed.
- WS auth via short-lived single-use tickets (POST /auth/ws-ticket).
- Tests: TestCrossTenantIsolation — 9 probes.
Block C — Rate limiting
- Redis sliding-window via Lua (atomic ZADD+ZCARD+PEXPIRE).
- Per-route limits matching the plan (signup IP, login IP+email, RSVP/
access by token, events/guests/tokens by user_id).
- 429 with Retry-After header and JSON body.
- Auth lockout: 5 failed logins → account locked, only password reset
clears it.
- Frontend: useErrMessage normalises 429 + locked messaging.
Block D — Real notifications
- Migration 0004: provider_message_id, bounce_type, complained columns
+ unsubscribes (CITEXT) suppression table.
- Branded HTML + plaintext templates for verification, reset, invitation,
confirmation, reminder. Per-page templates avoid html/template's
contextual-escape collisions.
- Senders: SESv2, Twilio (SMS), SMTP (Mailpit-friendly), Resend HTTP.
- PickEmailSender priority Resend > SMTP > SES > Log — system boots
cleanly in dev with Mailpit; production flips one env var.
- Webhook endpoints (Twilio status + SES SNS) — bounces add to suppression;
signature verification stubbed pending creds.
- Auto-send: POST /tokens publishes invitation.send; notifier renders +
delivers via the configured backend; suppression list honoured.
- Bulk + per-row invitation flow: POST /events/{id}/guests/invitations/bulk
returns per-guest tokens so phone-only guests can be SMS'd manually.
- Unsubscribe: signed HMAC token (no TTL) + /unsubscribe/[token] page.
- WhatsApp Option A+: wa.me click-to-chat wizard with per-guest progress
tracking, isLikelyE164 validation, edit-from-wizard.
- Token rotate (POST /tokens/rotate) invalidates the old URL — used by
the regenerate-link flow.
- Mailpit added to docker-compose for dev inbox.
Block E — CSV import
- Streaming parser: tolerant header detection, UTF-8 BOM + UTF-16 LE/BE
decoding, row-level validation, 5,000-row cap.
- Strict E.164 phone validation with helpful error message.
- POST /preview + /import + GET /template; preview UI on event page;
atomic per-batch with dedup on existing emails.
Phone capture across UI
- PhoneInput component: country picker (~50 ISO codes) + national input +
live E.164 preview + inline length validation.
- Used in Add Guest and Edit Guest modals. Smart paste-handling extracts
country code from full E.164 strings.
Block F — Billing (Stripe)
- Migration 0005: subscriptions table (user_id → tier/status/period_end +
Stripe customer/sub ids). Partial unique index keeps one granting sub
per user.
- internal/billing: Tier + Limits model (Free 1/50, Pro 10/1000, Business
∞/5000), Stripe SDK wrapper with IgnoreAPIVersionMismatch for newer
account API versions.
- /billing/checkout-session, /billing/portal, /billing/status,
/webhooks/stripe (signature-verified, lifecycle events).
- Tier enforcement: 402 on POST /events, /guests, /import with
{error, reason, tier, used, limit, upgrade_url} body.
- Frontend: useBilling composable, /dashboard/billing page (current plan,
usage bars, tier cards), global UpgradeModal triggered by useApi's
402 interceptor.
- Customer portal kept for self-service cancel/payment-method changes.
Block G — Backups & DR (application side)
- Every migration has a tested .down.sql.
- TestMigrationRoundtrip applies all ups → all downs → all ups against a
fresh container; catches asymmetric down migrations.
- cmd/restore-verify: 28-check post-restore invariant tool (schema
presence, no orphans across 10 FK relationships, email uniqueness,
single-active subscription, row-count snapshot).
- docs/RUNBOOK_RESTORE.md: 9-step restore procedure with RTO/RPO
targets, drill instructions, rollback path.
Block H — Privacy compliance (application side)
- Migration 0006: deleted_at + terms_accepted_at + privacy_policy_accepted_at
on users. Partial index on email for live-only uniqueness.
- GET /me/data-export — synchronous JSON dump (user, events, guests,
tokens, rsvps, access_logs, notifications).
- DELETE /me — soft-delete with PII scrub + refresh-token revocation;
re-signup with same email works.
- POST /me/accept-terms — idempotent consent recording.
- Frontend /privacy + /terms placeholder pages with substantive (pending
legal review) copy; footer links; signup terms checkbox; TermsGateModal
for accounts created before the rollout; export + delete buttons on
/dashboard/billing.
Tests
- All migrations verified up/down/up.
- Integration suite: TestE2EHappyPath, TestAuthFlow, TestCrossTenantIsolation,
TestRateLimitSignup, TestLoginLockout, TestUnsubscribeFlow,
TestSESBounceWebhook, TestTwilioStatusWebhook, TestCsvImportFlow,
TestCsvImportAtomicRollback, TestBulkIssueInvitations, TestBulkIssueExplicitSubset,
TestTokenIssuePublishesInvitation, TestTokenIssueWithoutGuestEmailSkipsInvitation,
TestGuestUpdate, TestGuestDelete, TestTokenRotate, TestSMTPSenderAgainstMailpit,
TestFreeTierEventLimit, TestFreeTierGuestLimit, TestBusinessTierBypassesLimits,
TestDataExport, TestDeleteMe, TestAcceptTerms, TestMigrationRoundtrip.
Full suite runs in ~120s against real Postgres + NATS + Redis + Mailpit.
- Unit suite green across internal/auth, internal/csvimport,
internal/notification, internal/ratelimit, internal/domain.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,9 @@ coverage.*
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Agent / per-developer config (launch.json with absolute paths, worktree state).
|
||||
.claude/
|
||||
|
||||
.DS_Store
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
+101
-10
@@ -11,10 +11,15 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/api"
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/config"
|
||||
"github.com/alchemistkay/guestguard/internal/fraud"
|
||||
"github.com/alchemistkay/guestguard/internal/natspub"
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
@@ -56,6 +61,16 @@ func run() error {
|
||||
}
|
||||
defer natsClient.Close()
|
||||
|
||||
logger.Info("connecting to redis", "addr", cfg.RedisAddr)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: cfg.RedisAddr})
|
||||
if err := rdb.Ping(rootCtx).Err(); err != nil {
|
||||
logger.Warn("redis ping failed — rate limits + lockout disabled", "err", err)
|
||||
_ = rdb.Close()
|
||||
rdb = nil
|
||||
} else {
|
||||
defer rdb.Close()
|
||||
}
|
||||
|
||||
logger.Info("dialing fraud engine", "addr", cfg.FraudGRPCAddr)
|
||||
fraudClient, err := fraud.Dial(rootCtx, cfg.FraudGRPCAddr, cfg.FraudGRPCTimeout, logger)
|
||||
if err != nil {
|
||||
@@ -118,17 +133,93 @@ func run() error {
|
||||
}
|
||||
defer rsvpConsumeCtx.Stop()
|
||||
|
||||
// Notification senders. If SES creds are configured, route auth +
|
||||
// guest emails through SES. Otherwise the log stub keeps the dev flow
|
||||
// (verification link in API logs) intact.
|
||||
tpls, err := notification.NewTemplates()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
suppressions := notification.NewSuppressionRepo(db)
|
||||
notifRepo := notification.NewRepo(db)
|
||||
unsubSigner := notification.NewUnsubscribeSigner(cfg.UnsubscribeSecret)
|
||||
|
||||
emailSenderCombined, backend, err := notification.PickEmailSender(rootCtx, notification.EmailSenderConfig{
|
||||
Resend: notification.ResendConfig{
|
||||
APIKey: cfg.ResendAPIKey,
|
||||
FromEmail: cfg.ResendFromEmail,
|
||||
FromName: cfg.ResendFromName,
|
||||
},
|
||||
SMTP: notification.SMTPConfig{
|
||||
Host: cfg.SMTPHost,
|
||||
Port: cfg.SMTPPort,
|
||||
Username: cfg.SMTPUsername,
|
||||
Password: cfg.SMTPPassword,
|
||||
FromEmail: cfg.SMTPFromEmail,
|
||||
FromName: cfg.SMTPFromName,
|
||||
TLS: cfg.SMTPTLS,
|
||||
},
|
||||
SES: notification.SESConfig{
|
||||
Region: cfg.SESRegion,
|
||||
FromEmail: cfg.SESFromEmail,
|
||||
FromName: cfg.SESFromName,
|
||||
ConfigurationSet: cfg.SESConfigurationSet,
|
||||
PublicBaseURL: cfg.PublicBaseURL,
|
||||
},
|
||||
}, tpls, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("email backend selected", "backend", backend)
|
||||
var emailSender auth.EmailSender = emailSenderCombined
|
||||
|
||||
stripeClient, err := billing.NewClient(billing.Config{
|
||||
SecretKey: cfg.StripeSecretKey,
|
||||
WebhookSecret: cfg.StripeWebhookSecret,
|
||||
PriceProMonthly: cfg.StripePricePro,
|
||||
PriceBusiness: cfg.StripePriceBusiness,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stripeClient != nil && stripeClient.Enabled() {
|
||||
logger.Info("billing enabled via stripe")
|
||||
} else {
|
||||
logger.Info("billing disabled — free tier limits apply to all users")
|
||||
}
|
||||
|
||||
apiSrv, err := api.NewServer(api.ServerDeps{
|
||||
Logger: logger,
|
||||
DB: db,
|
||||
Hub: hub,
|
||||
AccessPublisher: natsClient,
|
||||
RSVPPublisher: natsClient,
|
||||
InvitationPublisher: natsClient,
|
||||
FraudScorer: fraudClient,
|
||||
TokenTTL: cfg.TokenTTL,
|
||||
JWTSecret: cfg.JWTSecret,
|
||||
JWTIssuer: cfg.JWTIssuer,
|
||||
AccessTokenTTL: cfg.AccessTokenTTL,
|
||||
RefreshTokenTTL: cfg.RefreshTokenTTL,
|
||||
EmailVerificationTTL: cfg.EmailVerificationTTL,
|
||||
PasswordResetTTL: cfg.PasswordResetTTL,
|
||||
PublicBaseURL: cfg.PublicBaseURL,
|
||||
RefreshCookieDomain: cfg.RefreshCookieDomain,
|
||||
RefreshCookieSecure: cfg.RefreshCookieSecure,
|
||||
Redis: rdb,
|
||||
EmailSender: emailSender,
|
||||
NotificationRepo: notifRepo,
|
||||
SuppressionRepo: suppressions,
|
||||
UnsubscribeSigner: unsubSigner,
|
||||
StripeClient: stripeClient,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: cfg.HTTPAddr,
|
||||
Handler: api.NewServer(api.ServerDeps{
|
||||
Logger: logger,
|
||||
DB: db,
|
||||
Hub: hub,
|
||||
AccessPublisher: natsClient,
|
||||
RSVPPublisher: natsClient,
|
||||
FraudScorer: fraudClient,
|
||||
TokenTTL: cfg.TokenTTL,
|
||||
}).Handler(),
|
||||
Addr: cfg.HTTPAddr,
|
||||
Handler: apiSrv.Handler(),
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 0, // 0 lets WS connections live; per-request handlers still bound by their own ctx
|
||||
|
||||
+141
-1
@@ -48,7 +48,61 @@ func run() error {
|
||||
defer natsClient.Close()
|
||||
|
||||
repo := notification.NewRepo(db)
|
||||
sender := notification.LogSender{}
|
||||
suppressions := notification.NewSuppressionRepo(db)
|
||||
tpls, err := notification.NewTemplates()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Email dispatcher: Resend > SMTP > SES > log stub, same picker as cmd/api.
|
||||
combinedEmail, backend, err := notification.PickEmailSender(rootCtx, notification.EmailSenderConfig{
|
||||
Resend: notification.ResendConfig{
|
||||
APIKey: cfg.ResendAPIKey,
|
||||
FromEmail: cfg.ResendFromEmail,
|
||||
FromName: cfg.ResendFromName,
|
||||
},
|
||||
SMTP: notification.SMTPConfig{
|
||||
Host: cfg.SMTPHost,
|
||||
Port: cfg.SMTPPort,
|
||||
Username: cfg.SMTPUsername,
|
||||
Password: cfg.SMTPPassword,
|
||||
FromEmail: cfg.SMTPFromEmail,
|
||||
FromName: cfg.SMTPFromName,
|
||||
TLS: cfg.SMTPTLS,
|
||||
},
|
||||
SES: notification.SESConfig{
|
||||
Region: cfg.SESRegion,
|
||||
FromEmail: cfg.SESFromEmail,
|
||||
FromName: cfg.SESFromName,
|
||||
ConfigurationSet: cfg.SESConfigurationSet,
|
||||
PublicBaseURL: cfg.PublicBaseURL,
|
||||
},
|
||||
}, tpls, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("email backend selected", "backend", backend)
|
||||
emailSender := notification.NewEmailSender(combinedEmail, suppressions)
|
||||
|
||||
// SMS: Twilio when creds are set, otherwise no-op log sender.
|
||||
var smsSender notification.Sender
|
||||
if cfg.TwilioAccountSID != "" && cfg.TwilioAuthToken != "" && cfg.TwilioFromNumber != "" {
|
||||
t, err := notification.NewTwilioSender(notification.TwilioConfig{
|
||||
AccountSID: cfg.TwilioAccountSID,
|
||||
AuthToken: cfg.TwilioAuthToken,
|
||||
FromNumber: cfg.TwilioFromNumber,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
smsSender = t
|
||||
logger.Info("twilio sms sender configured", "from", cfg.TwilioFromNumber)
|
||||
} else {
|
||||
smsSender = notification.LogSender{}
|
||||
logger.Info("twilio not configured — SMS will use the log stub")
|
||||
}
|
||||
|
||||
sender := notification.NewRouter(emailSender, smsSender)
|
||||
|
||||
rsvpSub, err := natspub.NewRSVPConfirmedSubscriber(
|
||||
rootCtx, natsClient, "notifier-rsvp-confirmed",
|
||||
@@ -82,6 +136,22 @@ func run() error {
|
||||
}
|
||||
defer fraudCC.Stop()
|
||||
|
||||
invitationSub, err := natspub.NewInvitationSendSubscriber(
|
||||
rootCtx, natsClient, "notifier-invitation-send",
|
||||
func(ctx context.Context, evt natspub.InvitationSend) error {
|
||||
return handleInvitationSend(ctx, logger, repo, sender, evt)
|
||||
},
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
invitationCC, err := invitationSub.Start(rootCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer invitationCC.Stop()
|
||||
|
||||
logger.Info("notifier started")
|
||||
<-rootCtx.Done()
|
||||
logger.Info("notifier shutting down")
|
||||
@@ -201,6 +271,76 @@ func handleFraudScored(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvitationSend renders the invitation template + sends through
|
||||
// the configured sender (Resend in prod, Mailpit in dev), then writes a
|
||||
// notification row so the host can audit the delivery history.
|
||||
func handleInvitationSend(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
repo *notification.Repo,
|
||||
sender notification.Sender,
|
||||
evt natspub.InvitationSend,
|
||||
) error {
|
||||
if evt.GuestEmail == "" {
|
||||
// Nothing to do — host-managed delivery.
|
||||
return nil
|
||||
}
|
||||
|
||||
eventDate := ""
|
||||
if !evt.EventDate.IsZero() {
|
||||
eventDate = evt.EventDate.Format("Mon, 02 Jan 2006 · 15:04")
|
||||
}
|
||||
|
||||
msg := notification.OutboundMessage{
|
||||
GuestID: evt.GuestID,
|
||||
Channel: notification.ChannelEmail,
|
||||
Type: notification.TypeInvitation,
|
||||
Subject: "You're invited — " + evt.EventName,
|
||||
Metadata: map[string]any{
|
||||
"to": evt.GuestEmail,
|
||||
"GuestName": evt.GuestName,
|
||||
"HostName": evt.HostName,
|
||||
"EventName": evt.EventName,
|
||||
"Venue": evt.Venue,
|
||||
"EventDate": eventDate,
|
||||
"Link": evt.Link,
|
||||
},
|
||||
}
|
||||
|
||||
providerID, sendErr := sender.Send(ctx, msg)
|
||||
status := notification.StatusSent
|
||||
errStr := ""
|
||||
if sendErr != nil {
|
||||
status = notification.StatusFailed
|
||||
errStr = sendErr.Error()
|
||||
}
|
||||
|
||||
id, err := repo.Record(ctx, notification.RecordParams{
|
||||
GuestID: evt.GuestID,
|
||||
Channel: msg.Channel,
|
||||
Type: msg.Type,
|
||||
Status: status,
|
||||
ProviderMessageID: providerID,
|
||||
Error: errStr,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("invitation dispatched",
|
||||
"notification_id", id,
|
||||
"guest_id", evt.GuestID,
|
||||
"event_id", evt.EventID,
|
||||
"to", evt.GuestEmail,
|
||||
"status", status,
|
||||
"provider_message_id", providerID,
|
||||
)
|
||||
if sendErr != nil {
|
||||
return sendErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func levelFor(env string) slog.Level {
|
||||
if env == "development" {
|
||||
return slog.LevelDebug
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
// restore-verify is a post-restore sanity tool. Point it at a freshly
|
||||
// restored Postgres instance and it asserts that the schema and data
|
||||
// are coherent — no orphan rows, expected uniqueness invariants hold,
|
||||
// every table is present. Exits non-zero on any failure so it slots
|
||||
// into a "restore drill" runbook step.
|
||||
//
|
||||
// Usage:
|
||||
// GG_DATABASE_URL=postgres://... ./restore-verify [--verbose]
|
||||
//
|
||||
// The intent is "would I bet my Sunday on this restore being usable?".
|
||||
// Failing fast here keeps a bad restore from being promoted to traffic.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
func main() {
|
||||
verbose := flag.Bool("verbose", false, "print every check's result, not just failures")
|
||||
flag.Parse()
|
||||
|
||||
dsn := os.Getenv("GG_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
fmt.Fprintln(os.Stderr, "restore-verify: GG_DATABASE_URL is required")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "restore-verify: connect: %v\n", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "restore-verify: ping: %v\n", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
fmt.Println("restore-verify: checking", maskDSN(dsn))
|
||||
fmt.Println()
|
||||
|
||||
checks := allChecks()
|
||||
var failed []string
|
||||
for _, c := range checks {
|
||||
result, err := c.fn(ctx, pool)
|
||||
if err != nil {
|
||||
failed = append(failed, c.name)
|
||||
fmt.Printf(" ✗ %-50s FAIL: %v\n", c.name, err)
|
||||
continue
|
||||
}
|
||||
if *verbose {
|
||||
fmt.Printf(" ✓ %-50s %s\n", c.name, result)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
if len(failed) > 0 {
|
||||
fmt.Printf("FAILED: %d check%s — %s\n",
|
||||
len(failed), pluralS(len(failed)), strings.Join(failed, ", "))
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("OK: all %d checks passed\n", len(checks))
|
||||
}
|
||||
|
||||
type check struct {
|
||||
name string
|
||||
fn func(ctx context.Context, pool *pgxpool.Pool) (string, error)
|
||||
}
|
||||
|
||||
func allChecks() []check {
|
||||
return []check{
|
||||
// --- schema presence ---
|
||||
tableExists("users"),
|
||||
tableExists("events"),
|
||||
tableExists("guests"),
|
||||
tableExists("tokens"),
|
||||
tableExists("rsvps"),
|
||||
tableExists("access_logs"),
|
||||
tableExists("notifications"),
|
||||
tableExists("schema_migrations"),
|
||||
tableExists("email_verification_tokens"),
|
||||
tableExists("password_reset_tokens"),
|
||||
tableExists("refresh_tokens"),
|
||||
tableExists("unsubscribes"),
|
||||
tableExists("subscriptions"),
|
||||
|
||||
// --- migrations applied ---
|
||||
{
|
||||
name: "schema_migrations: 5+ migrations applied",
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
var n int
|
||||
if err := pool.QueryRow(ctx,
|
||||
`SELECT count(*) FROM schema_migrations`).Scan(&n); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n < 5 {
|
||||
return "", fmt.Errorf("only %d migrations recorded — incomplete restore", n)
|
||||
}
|
||||
return fmt.Sprintf("%d migrations", n), nil
|
||||
},
|
||||
},
|
||||
|
||||
// --- referential integrity (FK constraints catch most of this,
|
||||
// but a bad logical dump or partial restore can sneak rows in) ---
|
||||
noOrphans("events", "host_id", "users", "id"),
|
||||
noOrphans("guests", "event_id", "events", "id"),
|
||||
noOrphans("tokens", "guest_id", "guests", "id"),
|
||||
noOrphans("rsvps", "guest_id", "guests", "id"),
|
||||
noOrphans("access_logs", "guest_id", "guests", "id"),
|
||||
noOrphans("notifications", "guest_id", "guests", "id"),
|
||||
noOrphans("email_verification_tokens", "user_id", "users", "id"),
|
||||
noOrphans("password_reset_tokens", "user_id", "users", "id"),
|
||||
noOrphans("refresh_tokens", "user_id", "users", "id"),
|
||||
noOrphans("subscriptions", "user_id", "users", "id"),
|
||||
|
||||
// --- domain invariants ---
|
||||
{
|
||||
name: "users.email is unique (case-insensitive)",
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
var dupes int
|
||||
err := pool.QueryRow(ctx, `
|
||||
SELECT count(*) FROM (
|
||||
SELECT lower(email) FROM users GROUP BY lower(email) HAVING count(*) > 1
|
||||
) t
|
||||
`).Scan(&dupes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if dupes > 0 {
|
||||
return "", fmt.Errorf("%d duplicate email(s) found", dupes)
|
||||
}
|
||||
return "no duplicate emails", nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "guests with rsvp_response have an existing rsvp row",
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
var n int
|
||||
err := pool.QueryRow(ctx, `
|
||||
SELECT count(*) FROM rsvps r
|
||||
WHERE NOT EXISTS (SELECT 1 FROM guests g WHERE g.id = r.guest_id)
|
||||
`).Scan(&n)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n > 0 {
|
||||
return "", fmt.Errorf("%d rsvp(s) reference a missing guest", n)
|
||||
}
|
||||
return "0 orphan rsvps", nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "at most one granting subscription per user",
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
var n int
|
||||
err := pool.QueryRow(ctx, `
|
||||
SELECT count(*) FROM (
|
||||
SELECT user_id FROM subscriptions
|
||||
WHERE status IN ('active','past_due','trialing')
|
||||
GROUP BY user_id HAVING count(*) > 1
|
||||
) t
|
||||
`).Scan(&n)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n > 0 {
|
||||
return "", fmt.Errorf("%d user(s) have multiple granting subscriptions", n)
|
||||
}
|
||||
return "all users single-active", nil
|
||||
},
|
||||
},
|
||||
|
||||
// --- soft constraints worth noticing (not failures, but logged) ---
|
||||
{
|
||||
name: "row counts snapshot",
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
parts := []string{}
|
||||
for _, t := range []string{
|
||||
"users", "events", "guests", "tokens", "rsvps",
|
||||
"access_logs", "notifications", "subscriptions",
|
||||
} {
|
||||
var n int
|
||||
if err := pool.QueryRow(ctx,
|
||||
fmt.Sprintf("SELECT count(*) FROM %s", t)).Scan(&n); err != nil {
|
||||
return "", err
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%s=%d", t, n))
|
||||
}
|
||||
return strings.Join(parts, " "), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func tableExists(name string) check {
|
||||
return check{
|
||||
name: fmt.Sprintf("table %q exists", name),
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
var exists bool
|
||||
err := pool.QueryRow(ctx,
|
||||
`SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema='public' AND table_name=$1)`,
|
||||
name,
|
||||
).Scan(&exists)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", errors.New("missing")
|
||||
}
|
||||
return "ok", nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func noOrphans(childTable, childFK, parentTable, parentPK string) check {
|
||||
return check{
|
||||
name: fmt.Sprintf("no orphans: %s.%s -> %s.%s", childTable, childFK, parentTable, parentPK),
|
||||
fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) {
|
||||
q := fmt.Sprintf(`
|
||||
SELECT count(*) FROM %s c
|
||||
WHERE c.%s IS NOT NULL
|
||||
AND NOT EXISTS (SELECT 1 FROM %s p WHERE p.%s = c.%s)
|
||||
`, childTable, childFK, parentTable, parentPK, childFK)
|
||||
var n int
|
||||
if err := pool.QueryRow(ctx, q).Scan(&n); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n > 0 {
|
||||
return "", fmt.Errorf("%d orphan row(s)", n)
|
||||
}
|
||||
return "clean", nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func maskDSN(dsn string) string {
|
||||
// Crude: redact password between '://user:' and '@'.
|
||||
at := strings.LastIndex(dsn, "@")
|
||||
scheme := strings.Index(dsn, "://")
|
||||
if at < 0 || scheme < 0 || at <= scheme {
|
||||
return dsn
|
||||
}
|
||||
userInfo := dsn[scheme+3 : at]
|
||||
if colon := strings.Index(userInfo, ":"); colon >= 0 {
|
||||
userInfo = userInfo[:colon] + ":****"
|
||||
}
|
||||
return dsn[:scheme+3] + userInfo + dsn[at:]
|
||||
}
|
||||
|
||||
func pluralS(n int) string {
|
||||
if n == 1 {
|
||||
return ""
|
||||
}
|
||||
return "s"
|
||||
}
|
||||
@@ -15,6 +15,28 @@ services:
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
|
||||
mailpit:
|
||||
image: axllent/mailpit:latest
|
||||
ports:
|
||||
- "1025:1025" # SMTP
|
||||
- "8025:8025" # Web UI
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-qO-", "http://localhost:8025/api/v1/info"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: ["redis-server", "--save", "", "--appendonly", "no"]
|
||||
ports:
|
||||
- "6379:6379"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
|
||||
nats:
|
||||
image: nats:2.10-alpine
|
||||
command:
|
||||
@@ -45,7 +67,25 @@ services:
|
||||
GG_NATS_URL: nats://nats:4222
|
||||
GG_FRAUD_GRPC_ADDR: fraud-engine:9091
|
||||
GG_FRAUD_GRPC_TIMEOUT: 250ms
|
||||
GG_REDIS_ADDR: redis:6379
|
||||
GG_SMTP_HOST: mailpit
|
||||
GG_SMTP_PORT: "1025"
|
||||
GG_SMTP_FROM_EMAIL: noreply@guestguard.local
|
||||
GG_SMTP_FROM_NAME: GuestGuard (dev)
|
||||
# Resend overrides SMTP when GG_RESEND_API_KEY is set in .env at the
|
||||
# project root. Leave unset and Mailpit (above) handles delivery.
|
||||
GG_RESEND_API_KEY: ${GG_RESEND_API_KEY:-}
|
||||
GG_RESEND_FROM_EMAIL: ${GG_RESEND_FROM_EMAIL:-}
|
||||
GG_RESEND_FROM_NAME: ${GG_RESEND_FROM_NAME:-GuestGuard}
|
||||
GG_TOKEN_SECRET: dev-only-insecure-secret-change-me
|
||||
GG_JWT_SECRET: dev-only-insecure-jwt-secret-change-me-32+bytes
|
||||
GG_PUBLIC_BASE_URL: http://localhost:3000
|
||||
# Stripe billing — empty values leave billing disabled. Set in
|
||||
# .env at the project root to enable.
|
||||
GG_STRIPE_SECRET_KEY: ${GG_STRIPE_SECRET_KEY:-}
|
||||
GG_STRIPE_WEBHOOK_SECRET: ${GG_STRIPE_WEBHOOK_SECRET:-}
|
||||
GG_STRIPE_PRICE_PRO: ${GG_STRIPE_PRICE_PRO:-}
|
||||
GG_STRIPE_PRICE_BUSINESS: ${GG_STRIPE_PRICE_BUSINESS:-}
|
||||
ports:
|
||||
- "8080:8080"
|
||||
depends_on:
|
||||
@@ -53,6 +93,8 @@ services:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
|
||||
fraud-engine:
|
||||
@@ -80,6 +122,14 @@ services:
|
||||
GG_ENV: development
|
||||
GG_DATABASE_URL: postgres://guestguard:guestguard@postgres:5432/guestguard?sslmode=disable
|
||||
GG_NATS_URL: nats://nats:4222
|
||||
GG_PUBLIC_BASE_URL: http://localhost:3000
|
||||
GG_SMTP_HOST: mailpit
|
||||
GG_SMTP_PORT: "1025"
|
||||
GG_SMTP_FROM_EMAIL: noreply@guestguard.local
|
||||
GG_SMTP_FROM_NAME: GuestGuard (dev)
|
||||
GG_RESEND_API_KEY: ${GG_RESEND_API_KEY:-}
|
||||
GG_RESEND_FROM_EMAIL: ${GG_RESEND_FROM_EMAIL:-}
|
||||
GG_RESEND_FROM_NAME: ${GG_RESEND_FROM_NAME:-GuestGuard}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
# Runbook — Postgres restore
|
||||
|
||||
This is the procedure to bring GuestGuard back from a Postgres backup
|
||||
after data loss. It assumes the infra side of Block G (`pg_basebackup` +
|
||||
WAL archiving to S3, daily logical dumps, cross-region replication) is
|
||||
already in place — see the homelab repo for those.
|
||||
|
||||
The application side — migration down-scripts, the [`restore-verify`](../cmd/restore-verify/main.go)
|
||||
tool, and this document — lives here in the GuestGuard repo so it ships
|
||||
in lockstep with the schema.
|
||||
|
||||
---
|
||||
|
||||
## Targets
|
||||
|
||||
| Metric | Target |
|
||||
|---|---|
|
||||
| RTO (recovery time objective) | ≤ 1 hour from "go" decision to traffic-serving |
|
||||
| RPO (recovery point objective) | ≤ 5 minutes of data loss (WAL ships every 60s, S3 PUT every 5min) |
|
||||
|
||||
If RTO is going to slip past 1 hour, escalate per the comms plan in `docs/INCIDENT_RESPONSE.md` (infra repo).
|
||||
|
||||
## When to invoke this
|
||||
|
||||
- Primary Postgres is unreachable AND the standby has also failed
|
||||
- Logical corruption discovered (e.g., a bad migration deleted rows)
|
||||
- Region-wide outage at the primary's location
|
||||
- A "what if we restored last Tuesday" drill (see [Drill](#drill-procedure))
|
||||
|
||||
If only the primary is unreachable and the standby is healthy, promote
|
||||
the standby (separate runbook). Don't use this procedure unnecessarily —
|
||||
restores are expensive.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting:
|
||||
|
||||
- [ ] Decision authority has approved the restore (CTO or on-call lead)
|
||||
- [ ] Read access to the S3 backup bucket: `s3://guestguard-pg-backups`
|
||||
- [ ] `psql`, `pg_basebackup`, `wal-g` (or chosen WAL tool) installed
|
||||
- [ ] Empty target Postgres instance provisioned (Kubernetes Statefulset,
|
||||
RDS, or homelab box — same major version as the backup)
|
||||
- [ ] `GG_DATABASE_URL` env var ready for the new instance
|
||||
- [ ] Maintenance page deployed to the frontend (`/dashboard` returns 503)
|
||||
- [ ] API + notifier pods scaled to 0 (`kubectl scale --replicas=0`)
|
||||
- [ ] This document open in another tab
|
||||
|
||||
## Steps
|
||||
|
||||
### 1. Stop write traffic
|
||||
|
||||
```bash
|
||||
# k8s
|
||||
kubectl scale deployment/guestguard-api --replicas=0
|
||||
kubectl scale deployment/guestguard-notifier --replicas=0
|
||||
|
||||
# Confirm no connections to the (broken) primary
|
||||
kubectl exec -n postgres guestguard-pg-0 -- psql -U postgres -c \
|
||||
"SELECT count(*) FROM pg_stat_activity WHERE datname='guestguard'"
|
||||
```
|
||||
|
||||
If using docker-compose locally: `docker compose stop api notifier`.
|
||||
|
||||
### 2. Identify the recovery point
|
||||
|
||||
Pick the latest backup that's known-good. For corruption scenarios,
|
||||
this may mean going further back than the most recent dump.
|
||||
|
||||
```bash
|
||||
# List base backups (most recent first)
|
||||
wal-g backup-list 2>/dev/null | tail -10
|
||||
|
||||
# Pick the timestamp (e.g. base_000000010000000000000007) and decide
|
||||
# the LSN target if doing point-in-time recovery
|
||||
```
|
||||
|
||||
For corruption: pick the latest backup created **before** the corrupting
|
||||
event. For "ransomware / bad migration", probably 1–2 days back.
|
||||
|
||||
### 3. Restore the base backup
|
||||
|
||||
```bash
|
||||
# Replace BACKUP_NAME with the chosen base
|
||||
wal-g backup-fetch /var/lib/postgresql/data BACKUP_NAME
|
||||
|
||||
# Configure recovery target (omit recovery_target_time for "latest")
|
||||
cat >> /var/lib/postgresql/data/postgresql.conf <<EOF
|
||||
restore_command = 'wal-g wal-fetch "%f" "%p"'
|
||||
recovery_target_time = '2026-05-13 14:30:00 UTC' # set if doing PITR
|
||||
EOF
|
||||
|
||||
touch /var/lib/postgresql/data/recovery.signal
|
||||
```
|
||||
|
||||
### 4. Start Postgres and let it replay WAL
|
||||
|
||||
```bash
|
||||
systemctl start postgresql # or your equivalent
|
||||
|
||||
# Watch the log — should see "consistent recovery state reached"
|
||||
tail -f /var/log/postgresql/postgresql-16-main.log
|
||||
```
|
||||
|
||||
Wait until recovery completes and Postgres is in normal (not recovery)
|
||||
mode:
|
||||
|
||||
```bash
|
||||
psql -c "SELECT pg_is_in_recovery()"
|
||||
# Expected: f (false)
|
||||
```
|
||||
|
||||
### 5. Verify the restored database
|
||||
|
||||
This is the critical gate before any application traffic touches it.
|
||||
|
||||
```bash
|
||||
# Build the verifier (only needed once)
|
||||
go build -o restore-verify ./cmd/restore-verify
|
||||
|
||||
# Run it against the restored instance
|
||||
GG_DATABASE_URL='postgres://guestguard:CHANGE_ME@RESTORED_HOST:5432/guestguard?sslmode=require' \
|
||||
./restore-verify --verbose
|
||||
```
|
||||
|
||||
Expected output: `OK: all N checks passed`. The tool checks:
|
||||
|
||||
- All expected tables exist (users, events, guests, tokens, rsvps, etc.)
|
||||
- All migrations recorded in `schema_migrations`
|
||||
- No orphan rows across the ten FK relationships we care about
|
||||
- `users.email` is still unique (case-insensitive)
|
||||
- No more than one "granting" subscription per user
|
||||
- Row-count snapshot (for sanity, not pass/fail)
|
||||
|
||||
**If any check fails: STOP.** The restore is corrupt — go back to step 2
|
||||
with an earlier backup OR escalate.
|
||||
|
||||
### 6. Apply pending migrations
|
||||
|
||||
If the backup is from before a recent migration that shipped to prod,
|
||||
catch up:
|
||||
|
||||
```bash
|
||||
# The API auto-migrates on boot, but we want to apply migrations
|
||||
# before traffic, so kick a one-off:
|
||||
docker run --rm \
|
||||
-e GG_DATABASE_URL='postgres://...' \
|
||||
ghcr.io/alchemistkay/guestguard-api:latest \
|
||||
/app/api --migrate-only
|
||||
|
||||
# Or via psql, applying each .up.sql in order if you don't have the image:
|
||||
for m in 0001_init 0002_rsvps 0003_auth 0004_notifications_d 0005_billing; do
|
||||
psql -f internal/storage/migrations/${m}.up.sql || break
|
||||
done
|
||||
```
|
||||
|
||||
Run `restore-verify` again after migrations to confirm everything's
|
||||
still coherent.
|
||||
|
||||
### 7. Bring the API back up
|
||||
|
||||
```bash
|
||||
kubectl scale deployment/guestguard-api --replicas=2
|
||||
kubectl scale deployment/guestguard-notifier --replicas=1
|
||||
|
||||
# Watch the logs — expect "http server starting" + "billing enabled via stripe"
|
||||
kubectl logs -f deployment/guestguard-api --tail=20
|
||||
```
|
||||
|
||||
### 8. Smoke test
|
||||
|
||||
- [ ] Hit `/health` → 200
|
||||
- [ ] Sign in as a known test user → dashboard loads, recent events visible
|
||||
- [ ] Create a new event → succeeds, appears in list
|
||||
- [ ] Tail API logs for 5 minutes → no 5xx storms
|
||||
|
||||
### 9. Re-enable traffic
|
||||
|
||||
- [ ] Remove the maintenance page from the frontend
|
||||
- [ ] Announce restoration in the status channel + status page
|
||||
- [ ] Note actual RTO + RPO achieved for the post-mortem
|
||||
|
||||
## Drill procedure
|
||||
|
||||
Run this monthly with no real outage to keep the team's hands warm.
|
||||
|
||||
1. Provision a throwaway Postgres instance (`postgres-drill-YYYYMM`).
|
||||
2. Run steps 2–5 against it (skip 1, 7, 8, 9 — production stays untouched).
|
||||
3. `restore-verify` MUST pass.
|
||||
4. Bonus: spin up an API pointed at the drill DB on a one-off port and
|
||||
walk through the smoke-test scenarios.
|
||||
5. Tear down the drill DB.
|
||||
6. Log the time taken in `docs/RESTORE_DRILL_LOG.md` (or wherever your
|
||||
team tracks operational drills).
|
||||
|
||||
If any step fails during a drill, the production fail-over procedure is
|
||||
**unreliable** — treat as a P1 to fix before the next real failure.
|
||||
|
||||
## Rollback (if restore is wrong)
|
||||
|
||||
If you complete the restore and discover it's the wrong data:
|
||||
|
||||
1. Scale API back to 0
|
||||
2. Find the next earlier backup
|
||||
3. Drop and recreate the database on the restored instance
|
||||
4. Repeat from step 3
|
||||
|
||||
**Never** point production at a known-bad restored DB hoping to fix it
|
||||
later — the API will write new data on top of the corruption and the
|
||||
salvage gets exponentially harder.
|
||||
|
||||
## Migration down-scripts
|
||||
|
||||
Every `.up.sql` in `internal/storage/migrations/` has a matching
|
||||
`.down.sql`. They're tested as part of CI and not exercised during
|
||||
normal restores (the up-only sequence in step 6 is the path used).
|
||||
They exist for:
|
||||
|
||||
- Drill scenarios where you want to "rewind" the schema
|
||||
- Emergency rollback of a bad shipped migration
|
||||
|
||||
Down-script integrity: run the `TestMigrationRoundtrip` integration
|
||||
test, which applies every migration up → down → up against a fresh
|
||||
container.
|
||||
|
||||
## Application config that supports restored DBs
|
||||
|
||||
`GG_DATABASE_URL` is the single source of truth — no hardcoded
|
||||
hostnames anywhere in the codebase. Verified by:
|
||||
|
||||
```bash
|
||||
grep -rE 'postgres://|host=.*5432' --include='*.go' . | grep -v _test.go | grep -v config.go
|
||||
# Expected: (empty)
|
||||
```
|
||||
|
||||
If anything surfaces from that grep, file a bug — it'll bite the next
|
||||
restore.
|
||||
|
||||
## Escalation
|
||||
|
||||
| Step fails | Who to call |
|
||||
|---|---|
|
||||
| Steps 1–4 | On-call infra lead |
|
||||
| Step 5 (`restore-verify`) | On-call backend lead + DBA |
|
||||
| Steps 7–8 (app won't start / smoke fails) | On-call backend lead |
|
||||
| Drill failure | File P1 ticket, link the drill log |
|
||||
|
||||
## Change log
|
||||
|
||||
| Date | Author | Change |
|
||||
|---|---|---|
|
||||
| 2026-05-16 | kay | initial version (Block G) |
|
||||
+29
-11
@@ -1,13 +1,13 @@
|
||||
<script setup lang="ts">
|
||||
const { host, clear } = useHost()
|
||||
const auth = useAuth()
|
||||
const route = useRoute()
|
||||
|
||||
// GitHub icon is a "marketing" affordance — only show it on the public landing
|
||||
// page. Inside the app it just clutters the chrome.
|
||||
const showGithub = computed(() => route.path === '/')
|
||||
|
||||
function logout() {
|
||||
clear()
|
||||
async function signOut() {
|
||||
await auth.logout()
|
||||
navigateTo('/')
|
||||
}
|
||||
</script>
|
||||
@@ -34,12 +34,17 @@ function logout() {
|
||||
<path d="M12 0C5.374 0 0 5.373 0 12c0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23A11.509 11.509 0 0112 5.803c1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576C20.566 21.797 24 17.3 24 12c0-6.627-5.373-12-12-12z"/>
|
||||
</svg>
|
||||
</a>
|
||||
<template v-if="host">
|
||||
<button class="transition hover:text-zinc-100" @click="logout">Sign out</button>
|
||||
</template>
|
||||
<template v-else>
|
||||
<NuxtLink to="/dashboard" class="transition hover:text-zinc-100">Sign in</NuxtLink>
|
||||
</template>
|
||||
<ClientOnly>
|
||||
<template v-if="auth.isAuthenticated.value">
|
||||
<NuxtLink to="/dashboard" class="transition hover:text-zinc-100">Dashboard</NuxtLink>
|
||||
<NuxtLink to="/dashboard/billing" class="transition hover:text-zinc-100">Billing</NuxtLink>
|
||||
<button class="transition hover:text-zinc-100" @click="signOut">Sign out</button>
|
||||
</template>
|
||||
<template v-else>
|
||||
<NuxtLink to="/login" class="transition hover:text-zinc-100">Sign in</NuxtLink>
|
||||
<NuxtLink to="/signup" class="btn-primary !px-3 !py-1.5 text-xs">Get started</NuxtLink>
|
||||
</template>
|
||||
</ClientOnly>
|
||||
</nav>
|
||||
</div>
|
||||
</header>
|
||||
@@ -48,9 +53,22 @@ function logout() {
|
||||
<NuxtPage />
|
||||
</main>
|
||||
|
||||
<!-- Global "plan limit reached" prompt — surfaces whenever any API
|
||||
call returns 402. Lives at app-root so every page benefits
|
||||
without per-page wiring. -->
|
||||
<UpgradeModal />
|
||||
|
||||
<!-- Privacy / terms onboarding gate. Auto-shows when the signed-in
|
||||
user hasn't accepted the current policies yet. No-op otherwise. -->
|
||||
<TermsGateModal />
|
||||
|
||||
<footer class="mt-16 border-t border-zinc-900">
|
||||
<div class="mx-auto max-w-6xl px-6 py-6 text-xs text-zinc-500">
|
||||
© 2025 GuestGuard — Hassle-free RSVPs for every occasion.
|
||||
<div class="mx-auto flex max-w-6xl flex-wrap items-center justify-between gap-3 px-6 py-6 text-xs text-zinc-500">
|
||||
<span>© 2025 GuestGuard — Hassle-free RSVPs for every occasion.</span>
|
||||
<span class="flex items-center gap-4">
|
||||
<NuxtLink to="/privacy" class="hover:text-zinc-300">Privacy</NuxtLink>
|
||||
<NuxtLink to="/terms" class="hover:text-zinc-300">Terms</NuxtLink>
|
||||
</span>
|
||||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,217 @@
|
||||
<script setup lang="ts">
|
||||
interface ParsedRow {
|
||||
Name: string
|
||||
Email: string
|
||||
Phone: string
|
||||
PlusOnes: number
|
||||
}
|
||||
interface RowError {
|
||||
row: number
|
||||
reason: string
|
||||
}
|
||||
interface PreviewResponse {
|
||||
rows: ParsedRow[]
|
||||
errors?: RowError[]
|
||||
total_count: number
|
||||
}
|
||||
interface ImportResponse {
|
||||
added: number
|
||||
skipped: number
|
||||
skipped_emails?: string[]
|
||||
errors?: RowError[]
|
||||
total_count: number
|
||||
}
|
||||
|
||||
const props = defineProps<{ eventId: string }>()
|
||||
const emit = defineEmits<{ (e: 'imported'): void }>()
|
||||
|
||||
const config = useRuntimeConfig()
|
||||
const apiBase = config.public.apiBase as string
|
||||
|
||||
type Stage = 'idle' | 'preview' | 'committing' | 'done'
|
||||
const stage = ref<Stage>('idle')
|
||||
const fileName = ref('')
|
||||
const dragging = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const preview = ref<PreviewResponse | null>(null)
|
||||
const result = ref<ImportResponse | null>(null)
|
||||
let pendingFile: File | null = null
|
||||
|
||||
function reset() {
|
||||
stage.value = 'idle'
|
||||
fileName.value = ''
|
||||
preview.value = null
|
||||
result.value = null
|
||||
error.value = null
|
||||
pendingFile = null
|
||||
}
|
||||
|
||||
async function onFiles(files: FileList | File[] | null) {
|
||||
if (!files || !files.length) return
|
||||
const f = files[0]
|
||||
if (f.size > 1024 * 1024) {
|
||||
error.value = 'File is larger than 1MB.'
|
||||
return
|
||||
}
|
||||
fileName.value = f.name
|
||||
error.value = null
|
||||
pendingFile = f
|
||||
await runPreview()
|
||||
}
|
||||
|
||||
async function runPreview() {
|
||||
if (!pendingFile) return
|
||||
try {
|
||||
const fd = new FormData()
|
||||
fd.append('file', pendingFile)
|
||||
preview.value = await useApi<PreviewResponse>(
|
||||
`/events/${props.eventId}/guests/import/preview`,
|
||||
{ method: 'POST', body: fd },
|
||||
)
|
||||
stage.value = 'preview'
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Could not parse that CSV')
|
||||
stage.value = 'idle'
|
||||
pendingFile = null
|
||||
}
|
||||
}
|
||||
|
||||
async function commit() {
|
||||
if (!pendingFile) return
|
||||
stage.value = 'committing'
|
||||
error.value = null
|
||||
try {
|
||||
const fd = new FormData()
|
||||
fd.append('file', pendingFile)
|
||||
result.value = await useApi<ImportResponse>(
|
||||
`/events/${props.eventId}/guests/import`,
|
||||
{ method: 'POST', body: fd },
|
||||
)
|
||||
stage.value = 'done'
|
||||
emit('imported')
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Import failed')
|
||||
stage.value = 'preview'
|
||||
}
|
||||
}
|
||||
|
||||
function onDrop(e: DragEvent) {
|
||||
dragging.value = false
|
||||
onFiles(e.dataTransfer?.files ?? null)
|
||||
}
|
||||
|
||||
const templateUrl = computed(() => `${apiBase}/events/${props.eventId}/guests/import/template`)
|
||||
|
||||
async function downloadTemplate() {
|
||||
// The endpoint requires a Bearer header, which a plain <a download> can't
|
||||
// attach — fetch through useApi (which adds auth + handles refresh) then
|
||||
// synthesise an anchor click on the resulting blob.
|
||||
try {
|
||||
const res: any = await useApi(`/events/${props.eventId}/guests/import/template`, {
|
||||
method: 'GET',
|
||||
})
|
||||
// useApi auto-parses JSON; for CSV the response body is plain text.
|
||||
const text = typeof res === 'string' ? res : JSON.stringify(res)
|
||||
const blob = new Blob([text], { type: 'text/csv;charset=utf-8' })
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = 'guestguard-import-template.csv'
|
||||
a.click()
|
||||
URL.revokeObjectURL(url)
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Could not download template')
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<!-- No outer .card chrome: this component is now embedded inside a
|
||||
modal that supplies the surface. -->
|
||||
<div>
|
||||
<div class="mb-3 flex items-center justify-end">
|
||||
<button class="text-xs text-zinc-400 hover:text-zinc-200" @click="downloadTemplate">
|
||||
Download template
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Stage 1: drag-drop zone -->
|
||||
<div v-if="stage === 'idle'">
|
||||
<label
|
||||
class="flex cursor-pointer flex-col items-center justify-center rounded-lg border-2 border-dashed bg-zinc-900/50 p-6 text-center transition"
|
||||
:class="dragging ? 'border-brand-500 bg-brand-500/5' : 'border-zinc-700 hover:border-zinc-600'"
|
||||
@dragover.prevent="dragging = true"
|
||||
@dragleave.prevent="dragging = false"
|
||||
@drop.prevent="onDrop"
|
||||
>
|
||||
<span class="mb-1 text-sm text-zinc-200">Drop a CSV here or click to choose</span>
|
||||
<span class="text-xs text-zinc-500">Up to 5,000 rows · 1MB max</span>
|
||||
<input type="file" accept=".csv,text/csv" class="hidden" @change="onFiles(($event.target as HTMLInputElement).files)" />
|
||||
</label>
|
||||
<p class="mt-2 text-xs text-zinc-500">
|
||||
Required column: <code>name</code>. Optional: <code>email</code>, <code>phone</code>, <code>plus_ones</code>.
|
||||
</p>
|
||||
<p v-if="error" class="mt-3 text-sm text-red-400">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Stage 2: preview -->
|
||||
<div v-else-if="stage === 'preview' && preview">
|
||||
<p class="mb-2 text-sm">
|
||||
<span class="text-zinc-200">{{ fileName }}</span>
|
||||
<span class="text-zinc-500"> — {{ preview.rows.length }} valid row{{ preview.rows.length === 1 ? '' : 's' }}, {{ preview.errors?.length || 0 }} error{{ preview.errors?.length === 1 ? '' : 's' }}.</span>
|
||||
</p>
|
||||
|
||||
<div v-if="preview.errors && preview.errors.length" class="mb-3 max-h-32 overflow-auto rounded border border-amber-900/40 bg-amber-950/20 p-2 text-xs text-amber-200">
|
||||
<p class="mb-1 font-medium">Rows with problems (these will be skipped):</p>
|
||||
<ul class="space-y-0.5">
|
||||
<li v-for="(err, i) in preview.errors" :key="i">
|
||||
row {{ err.row }} — {{ err.reason }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div v-if="preview.rows.length" class="mb-3 max-h-64 overflow-auto rounded border border-zinc-800">
|
||||
<table class="w-full text-xs">
|
||||
<thead class="bg-zinc-900 text-zinc-400">
|
||||
<tr>
|
||||
<th class="px-2 py-1.5 text-left">Name</th>
|
||||
<th class="px-2 py-1.5 text-left">Email</th>
|
||||
<th class="px-2 py-1.5 text-left">Phone</th>
|
||||
<th class="px-2 py-1.5 text-right">+1s</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="divide-y divide-zinc-800">
|
||||
<tr v-for="(r, i) in preview.rows.slice(0, 100)" :key="i" class="text-zinc-200">
|
||||
<td class="px-2 py-1">{{ r.Name }}</td>
|
||||
<td class="px-2 py-1 text-zinc-400">{{ r.Email || '—' }}</td>
|
||||
<td class="px-2 py-1 text-zinc-400">{{ r.Phone || '—' }}</td>
|
||||
<td class="px-2 py-1 text-right tabular-nums">{{ r.PlusOnes }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
<p v-if="preview.rows.length > 100" class="px-2 py-1 text-xs text-zinc-500">
|
||||
+ {{ preview.rows.length - 100 }} more not shown.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<button class="btn-primary" :disabled="!preview.rows.length || stage === 'committing'" @click="commit">
|
||||
Looks good — import {{ preview.rows.length }} guest{{ preview.rows.length === 1 ? '' : 's' }}
|
||||
</button>
|
||||
<button class="btn-ghost" @click="reset">Cancel</button>
|
||||
</div>
|
||||
<p v-if="error" class="mt-3 text-sm text-red-400">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Stage 3: committing -->
|
||||
<div v-else-if="stage === 'committing'" class="text-sm text-zinc-400">Importing…</div>
|
||||
|
||||
<!-- Stage 4: done -->
|
||||
<div v-else-if="stage === 'done' && result" class="text-sm">
|
||||
<p class="mb-1 font-medium text-brand-300">Imported {{ result.added }} guest{{ result.added === 1 ? '' : 's' }}.</p>
|
||||
<p v-if="result.skipped" class="text-zinc-400">Skipped {{ result.skipped }} duplicate{{ result.skipped === 1 ? '' : 's' }} (already on this event).</p>
|
||||
<p v-if="result.errors?.length" class="text-zinc-400">{{ result.errors.length }} row{{ result.errors.length === 1 ? '' : 's' }} had errors and were not imported.</p>
|
||||
<button class="btn-ghost mt-3" @click="reset">Import another file</button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
@@ -0,0 +1,240 @@
|
||||
<script setup lang="ts">
|
||||
// PhoneInput — country code picker + national digits input.
|
||||
//
|
||||
// Emits an E.164 string via v-model (e.g. "+233244123456"). Empty input
|
||||
// emits an empty string. The country list covers the ~50 most likely
|
||||
// origins for event guests; uncommon ones can still be typed by picking
|
||||
// the closest country and entering the full number — the live preview
|
||||
// shows the host what's being saved.
|
||||
//
|
||||
// Industry-standard UX: WhatsApp / Stripe / airline-booking pattern.
|
||||
|
||||
interface Country {
|
||||
code: string // ISO-3166 alpha-2
|
||||
name: string
|
||||
dialCode: string // includes leading "+"
|
||||
}
|
||||
|
||||
// Curated list — alphabetical-by-name, weighted toward common GuestGuard
|
||||
// audience (UK/EU + Africa + diaspora destinations). The host can always
|
||||
// type the full international number into another country's row if their
|
||||
// country isn't here, but most won't need to.
|
||||
const COUNTRIES: Country[] = [
|
||||
{ code: 'AR', name: 'Argentina', dialCode: '+54' },
|
||||
{ code: 'AU', name: 'Australia', dialCode: '+61' },
|
||||
{ code: 'AT', name: 'Austria', dialCode: '+43' },
|
||||
{ code: 'BE', name: 'Belgium', dialCode: '+32' },
|
||||
{ code: 'BR', name: 'Brazil', dialCode: '+55' },
|
||||
{ code: 'CM', name: 'Cameroon', dialCode: '+237' },
|
||||
{ code: 'CA', name: 'Canada', dialCode: '+1' },
|
||||
{ code: 'CL', name: 'Chile', dialCode: '+56' },
|
||||
{ code: 'CN', name: 'China', dialCode: '+86' },
|
||||
{ code: 'CI', name: "Côte d'Ivoire", dialCode: '+225' },
|
||||
{ code: 'DK', name: 'Denmark', dialCode: '+45' },
|
||||
{ code: 'EG', name: 'Egypt', dialCode: '+20' },
|
||||
{ code: 'ET', name: 'Ethiopia', dialCode: '+251' },
|
||||
{ code: 'FI', name: 'Finland', dialCode: '+358' },
|
||||
{ code: 'FR', name: 'France', dialCode: '+33' },
|
||||
{ code: 'DE', name: 'Germany', dialCode: '+49' },
|
||||
{ code: 'GH', name: 'Ghana', dialCode: '+233' },
|
||||
{ code: 'HK', name: 'Hong Kong', dialCode: '+852' },
|
||||
{ code: 'IN', name: 'India', dialCode: '+91' },
|
||||
{ code: 'ID', name: 'Indonesia', dialCode: '+62' },
|
||||
{ code: 'IE', name: 'Ireland', dialCode: '+353' },
|
||||
{ code: 'IL', name: 'Israel', dialCode: '+972' },
|
||||
{ code: 'IT', name: 'Italy', dialCode: '+39' },
|
||||
{ code: 'JP', name: 'Japan', dialCode: '+81' },
|
||||
{ code: 'KE', name: 'Kenya', dialCode: '+254' },
|
||||
{ code: 'MY', name: 'Malaysia', dialCode: '+60' },
|
||||
{ code: 'MX', name: 'Mexico', dialCode: '+52' },
|
||||
{ code: 'MA', name: 'Morocco', dialCode: '+212' },
|
||||
{ code: 'NL', name: 'Netherlands', dialCode: '+31' },
|
||||
{ code: 'NZ', name: 'New Zealand', dialCode: '+64' },
|
||||
{ code: 'NG', name: 'Nigeria', dialCode: '+234' },
|
||||
{ code: 'NO', name: 'Norway', dialCode: '+47' },
|
||||
{ code: 'PH', name: 'Philippines', dialCode: '+63' },
|
||||
{ code: 'PL', name: 'Poland', dialCode: '+48' },
|
||||
{ code: 'PT', name: 'Portugal', dialCode: '+351' },
|
||||
{ code: 'RW', name: 'Rwanda', dialCode: '+250' },
|
||||
{ code: 'SA', name: 'Saudi Arabia', dialCode: '+966' },
|
||||
{ code: 'SN', name: 'Senegal', dialCode: '+221' },
|
||||
{ code: 'SG', name: 'Singapore', dialCode: '+65' },
|
||||
{ code: 'ZA', name: 'South Africa', dialCode: '+27' },
|
||||
{ code: 'KR', name: 'South Korea', dialCode: '+82' },
|
||||
{ code: 'ES', name: 'Spain', dialCode: '+34' },
|
||||
{ code: 'SE', name: 'Sweden', dialCode: '+46' },
|
||||
{ code: 'CH', name: 'Switzerland', dialCode: '+41' },
|
||||
{ code: 'TW', name: 'Taiwan', dialCode: '+886' },
|
||||
{ code: 'TZ', name: 'Tanzania', dialCode: '+255' },
|
||||
{ code: 'TH', name: 'Thailand', dialCode: '+66' },
|
||||
{ code: 'TR', name: 'Turkey', dialCode: '+90' },
|
||||
{ code: 'AE', name: 'UAE', dialCode: '+971' },
|
||||
{ code: 'UG', name: 'Uganda', dialCode: '+256' },
|
||||
{ code: 'GB', name: 'United Kingdom', dialCode: '+44' },
|
||||
{ code: 'US', name: 'United States', dialCode: '+1' },
|
||||
]
|
||||
|
||||
// Longer dial codes first so "+1" doesn't shadow "+1...". Sorted once.
|
||||
const SORTED_BY_DIAL = [...COUNTRIES].sort((a, b) => b.dialCode.length - a.dialCode.length)
|
||||
|
||||
const props = defineProps<{
|
||||
modelValue: string
|
||||
/** ISO-3166 alpha-2 to use as default. Browser locale is consulted if omitted. */
|
||||
defaultCountry?: string
|
||||
}>()
|
||||
const emit = defineEmits<{ (e: 'update:modelValue', value: string): void }>()
|
||||
|
||||
function findByCode(code: string): Country | undefined {
|
||||
return COUNTRIES.find((c) => c.code === code.toUpperCase())
|
||||
}
|
||||
|
||||
function parsePhone(v: string): { country: Country | null; national: string } {
|
||||
if (!v) return { country: null, national: '' }
|
||||
const trimmed = v.trim().replace(/[\s\-()]/g, '')
|
||||
for (const c of SORTED_BY_DIAL) {
|
||||
if (trimmed.startsWith(c.dialCode)) {
|
||||
return { country: c, national: trimmed.slice(c.dialCode.length) }
|
||||
}
|
||||
}
|
||||
// Fall through: leading +XX didn't match anything (or no leading +).
|
||||
// Strip a leading + and a leading 0 from local-format numbers so the
|
||||
// host's national digits show up cleanly in the input.
|
||||
return { country: null, national: trimmed.replace(/^\+/, '').replace(/^0+/, '') }
|
||||
}
|
||||
|
||||
function detectDefault(): Country {
|
||||
if (props.defaultCountry) {
|
||||
const c = findByCode(props.defaultCountry)
|
||||
if (c) return c
|
||||
}
|
||||
if (typeof navigator !== 'undefined' && navigator.language) {
|
||||
const region = navigator.language.split('-')[1] || ''
|
||||
const c = findByCode(region)
|
||||
if (c) return c
|
||||
}
|
||||
return findByCode('GB')!
|
||||
}
|
||||
|
||||
const initial = parsePhone(props.modelValue)
|
||||
const country = ref<Country>(initial.country || detectDefault())
|
||||
const national = ref(initial.national)
|
||||
|
||||
// Digits the host typed, with leading 0s stripped (local-format helper).
|
||||
const nationalDigits = computed(() => national.value.replace(/\D/g, '').replace(/^0+/, ''))
|
||||
|
||||
// E.164 composed from current state. Empty when no digits — keeps the
|
||||
// stored value clean (don't save "+44" with no number).
|
||||
const composed = computed(() => {
|
||||
return nationalDigits.value ? `${country.value.dialCode}${nationalDigits.value}` : ''
|
||||
})
|
||||
|
||||
// Inline validation. We don't try to encode per-country length rules
|
||||
// (that's libphonenumber's job and overkill here) — instead we apply a
|
||||
// generous floor/ceiling on the national digit count. Catches obvious
|
||||
// typos without false-positives on shorter formats like Iceland (+354 7).
|
||||
type Validation = 'empty' | 'short' | 'long' | 'ok'
|
||||
const validation = computed<Validation>(() => {
|
||||
const n = nationalDigits.value.length
|
||||
if (n === 0) return 'empty'
|
||||
if (n < 6) return 'short'
|
||||
if (n > 13) return 'long'
|
||||
return 'ok'
|
||||
})
|
||||
|
||||
// Emit on user changes — guard reentrancy so an external prop update
|
||||
// doesn't bounce back into our own watcher.
|
||||
let emitting = false
|
||||
watch(composed, (v) => {
|
||||
emitting = true
|
||||
emit('update:modelValue', v)
|
||||
Promise.resolve().then(() => { emitting = false })
|
||||
})
|
||||
|
||||
// Re-sync from prop changes (e.g., parent resets the form, edit modal
|
||||
// reopens for a different guest). Skip when our own emit caused it.
|
||||
watch(() => props.modelValue, (v) => {
|
||||
if (emitting) return
|
||||
const p = parsePhone(v)
|
||||
if (p.country) country.value = p.country
|
||||
national.value = p.national
|
||||
})
|
||||
|
||||
// Smart input: if the host pastes (or types) a full E.164 with leading
|
||||
// "+" into the national field, extract the country code and split it
|
||||
// into the picker + the national digits. Rescues the common mistake of
|
||||
// pasting "+233244123456" into the digits field instead of using the
|
||||
// picker — without this, the value would get double-prefixed.
|
||||
function onNationalInput(e: Event) {
|
||||
const v = (e.target as HTMLInputElement).value
|
||||
if (v.startsWith('+')) {
|
||||
const parsed = parsePhone(v)
|
||||
if (parsed.country) {
|
||||
country.value = parsed.country
|
||||
national.value = parsed.national
|
||||
return
|
||||
}
|
||||
}
|
||||
national.value = v
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div>
|
||||
<div class="flex gap-2">
|
||||
<select
|
||||
v-model="country"
|
||||
class="input w-28 shrink-0 cursor-pointer"
|
||||
aria-label="Country code"
|
||||
>
|
||||
<option v-for="c in COUNTRIES" :key="c.code" :value="c">
|
||||
{{ c.dialCode }} {{ c.code }}
|
||||
</option>
|
||||
</select>
|
||||
<input
|
||||
:value="national"
|
||||
@input="onNationalInput"
|
||||
type="tel"
|
||||
inputmode="tel"
|
||||
autocomplete="tel-national"
|
||||
class="input flex-1"
|
||||
:class="{
|
||||
'border-amber-700/60 focus:border-amber-500 focus:ring-amber-500': validation === 'short' || validation === 'long',
|
||||
}"
|
||||
placeholder="Phone number"
|
||||
:aria-invalid="validation === 'short' || validation === 'long' || undefined"
|
||||
/>
|
||||
</div>
|
||||
<!-- Live feedback. Different message per validation state:
|
||||
empty → optional hint
|
||||
short → amber warning, hostsees it before pressing Save
|
||||
long → amber warning
|
||||
ok → green confirmation with the canonical E.164 visible -->
|
||||
<p class="mt-1 flex items-center gap-1 text-xs">
|
||||
<template v-if="validation === 'empty'">
|
||||
<span class="text-zinc-500">
|
||||
Optional — include the country code so guests on any network can be reached.
|
||||
</span>
|
||||
</template>
|
||||
<template v-else-if="validation === 'short'">
|
||||
<svg class="h-3.5 w-3.5 shrink-0 text-amber-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
|
||||
<path fill-rule="evenodd" d="M8.485 3.495c.673-1.167 2.357-1.167 3.03 0l6.28 10.875c.673 1.167-.17 2.625-1.516 2.625H3.72c-1.347 0-2.189-1.458-1.515-2.625L8.485 3.495zM10 8a1 1 0 01.993.883L11 9v3a1 1 0 01-1.993.117L9 12V9a1 1 0 011-1zm0 6a1 1 0 110 2 1 1 0 010-2z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
<span class="text-amber-300">Looks too short — make sure you've entered all the digits.</span>
|
||||
</template>
|
||||
<template v-else-if="validation === 'long'">
|
||||
<svg class="h-3.5 w-3.5 shrink-0 text-amber-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
|
||||
<path fill-rule="evenodd" d="M8.485 3.495c.673-1.167 2.357-1.167 3.03 0l6.28 10.875c.673 1.167-.17 2.625-1.516 2.625H3.72c-1.347 0-2.189-1.458-1.515-2.625L8.485 3.495zM10 8a1 1 0 01.993.883L11 9v3a1 1 0 01-1.993.117L9 12V9a1 1 0 011-1zm0 6a1 1 0 110 2 1 1 0 010-2z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
<span class="text-amber-300">Looks too long — check for extra digits.</span>
|
||||
</template>
|
||||
<template v-else>
|
||||
<svg class="h-3.5 w-3.5 shrink-0 text-brand-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
|
||||
<path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
<span class="text-zinc-500">
|
||||
Saved as <span class="font-mono text-zinc-300">{{ composed }}</span>
|
||||
</span>
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
</template>
|
||||
@@ -0,0 +1,80 @@
|
||||
<script setup lang="ts">
|
||||
// Shown the first time a host signs into the dashboard after T&C
|
||||
// enforcement is rolled out. Existing accounts created before this
|
||||
// feature don't have terms_accepted_at set — they're re-prompted once
|
||||
// here, then they're set going forward.
|
||||
//
|
||||
// Lives in app.vue so any authed page can be the host's first stop.
|
||||
|
||||
const auth = useAuth()
|
||||
const accepted = ref(false)
|
||||
const submitting = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
// Only fires when the user is signed in AND we know they haven't
|
||||
// accepted. /me payload includes terms_accepted_at; if absent → prompt.
|
||||
const needsAcceptance = computed(() => {
|
||||
const u = auth.user.value
|
||||
return !!u && !u.terms_accepted_at && !u.privacy_policy_accepted_at
|
||||
})
|
||||
|
||||
async function accept() {
|
||||
if (!accepted.value) return
|
||||
submitting.value = true
|
||||
error.value = null
|
||||
try {
|
||||
await useApi('/me/accept-terms', { method: 'POST' })
|
||||
// Refresh auth state so the user object reflects the new fields.
|
||||
await auth.refresh()
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Could not record acceptance')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="needsAcceptance"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm"
|
||||
>
|
||||
<div
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
aria-labelledby="terms-title"
|
||||
class="w-full max-w-md rounded-lg border border-zinc-800 bg-zinc-900 p-5 shadow-2xl"
|
||||
>
|
||||
<h3 id="terms-title" class="mb-1 text-base font-semibold">One quick thing</h3>
|
||||
<p class="mb-4 text-sm text-zinc-400">
|
||||
We've updated how GuestGuard handles your data. Before you carry on,
|
||||
please confirm you've read and agree to the current policies.
|
||||
</p>
|
||||
|
||||
<label class="flex cursor-pointer items-start gap-2 text-sm text-zinc-200">
|
||||
<input
|
||||
v-model="accepted"
|
||||
type="checkbox"
|
||||
class="mt-0.5 h-4 w-4 cursor-pointer accent-brand-500"
|
||||
/>
|
||||
<span>
|
||||
I agree to GuestGuard's
|
||||
<NuxtLink to="/terms" target="_blank" class="text-brand-400 hover:text-brand-300">Terms of Service</NuxtLink>
|
||||
and
|
||||
<NuxtLink to="/privacy" target="_blank" class="text-brand-400 hover:text-brand-300">Privacy Policy</NuxtLink>.
|
||||
</span>
|
||||
</label>
|
||||
|
||||
<button
|
||||
class="btn-primary mt-4 w-full disabled:opacity-50"
|
||||
:disabled="!accepted || submitting"
|
||||
@click="accept"
|
||||
>
|
||||
{{ submitting ? 'Saving…' : 'Continue' }}
|
||||
</button>
|
||||
<p v-if="error" class="mt-3 text-sm text-red-400">{{ error }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
@@ -0,0 +1,131 @@
|
||||
<script setup lang="ts">
|
||||
// Global "plan limit reached" modal. Bound to the upgrade-prompt state
|
||||
// in useBilling, which is populated by the 402 interceptor in useApi.
|
||||
// Lives in app.vue so any page that fires an API call benefits from it
|
||||
// without per-page wiring.
|
||||
|
||||
const billing = useBilling()
|
||||
const router = useRouter()
|
||||
const upgrading = ref<'pro' | 'business' | null>(null)
|
||||
|
||||
const reasonText = computed(() => {
|
||||
const p = billing.prompt.value
|
||||
if (!p) return ''
|
||||
if (p.reason === 'events_per_month') {
|
||||
return `You've used ${p.used} of ${p.limit} events this month on the ${labelTier(p.tier)} plan.`
|
||||
}
|
||||
if (p.reason === 'guests_per_event') {
|
||||
return `This event already has ${p.used} of ${p.limit} guests allowed on the ${labelTier(p.tier)} plan.`
|
||||
}
|
||||
return p.error
|
||||
})
|
||||
|
||||
function labelTier(t: string): string {
|
||||
return t.charAt(0).toUpperCase() + t.slice(1)
|
||||
}
|
||||
|
||||
async function quickUpgrade(tier: 'pro' | 'business') {
|
||||
upgrading.value = tier
|
||||
try {
|
||||
await billing.startCheckout(tier)
|
||||
// startCheckout navigates the page — nothing else to do.
|
||||
} catch {
|
||||
// Fallback: send the user to the billing page so they see the error
|
||||
// in context.
|
||||
await router.push('/dashboard/billing')
|
||||
billing.dismissUpgradePrompt()
|
||||
} finally {
|
||||
upgrading.value = null
|
||||
}
|
||||
}
|
||||
|
||||
function viewPlans() {
|
||||
billing.dismissUpgradePrompt()
|
||||
router.push('/dashboard/billing')
|
||||
}
|
||||
|
||||
// Esc closes the modal — keeps interaction symmetrical with all the
|
||||
// other dialogs across the app.
|
||||
function onKeydown(e: KeyboardEvent) {
|
||||
if (e.key === 'Escape' && billing.prompt.value) {
|
||||
billing.dismissUpgradePrompt()
|
||||
}
|
||||
}
|
||||
if (import.meta.client) {
|
||||
onMounted(() => window.addEventListener('keydown', onKeydown))
|
||||
onUnmounted(() => window.removeEventListener('keydown', onKeydown))
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="billing.prompt.value"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center bg-black/60 p-4 backdrop-blur-sm"
|
||||
@click.self="billing.dismissUpgradePrompt()"
|
||||
>
|
||||
<div
|
||||
role="alertdialog"
|
||||
aria-modal="true"
|
||||
aria-labelledby="upgrade-title"
|
||||
aria-describedby="upgrade-desc"
|
||||
class="w-full max-w-md rounded-lg border border-zinc-800 bg-zinc-900 p-5 shadow-2xl"
|
||||
>
|
||||
<div class="mb-3 flex items-start gap-3">
|
||||
<div class="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-brand-500/15">
|
||||
<svg class="h-4 w-4 text-brand-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
|
||||
<path fill-rule="evenodd" d="M5 9V7a5 5 0 0110 0v2a2 2 0 012 2v5a2 2 0 01-2 2H5a2 2 0 01-2-2v-5a2 2 0 012-2zm8-2v2H7V7a3 3 0 016 0z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
</div>
|
||||
<div class="min-w-0">
|
||||
<h3 id="upgrade-title" class="text-base font-semibold text-zinc-100">Plan limit reached</h3>
|
||||
<p id="upgrade-desc" class="mt-1 text-sm text-zinc-400">{{ reasonText }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p class="mb-4 text-sm text-zinc-300">Pick a plan to keep going — your work isn't lost.</p>
|
||||
|
||||
<div class="space-y-2">
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between rounded-md border border-brand-700/60 bg-brand-500/10 px-3 py-3 text-left transition hover:bg-brand-500/15 disabled:opacity-50"
|
||||
:disabled="upgrading !== null"
|
||||
@click="quickUpgrade('pro')"
|
||||
>
|
||||
<span>
|
||||
<span class="block text-sm font-medium text-zinc-100">Upgrade to Pro</span>
|
||||
<span class="block text-xs text-zinc-500">$49 / month · 10 events · 1,000 guests per event</span>
|
||||
</span>
|
||||
<span class="text-xs text-zinc-400">{{ upgrading === 'pro' ? 'Opening checkout…' : '→' }}</span>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between rounded-md border border-zinc-700 bg-zinc-950 px-3 py-3 text-left transition hover:border-zinc-500 hover:bg-zinc-900 disabled:opacity-50"
|
||||
:disabled="upgrading !== null"
|
||||
@click="quickUpgrade('business')"
|
||||
>
|
||||
<span>
|
||||
<span class="block text-sm font-medium text-zinc-100">Upgrade to Business</span>
|
||||
<span class="block text-xs text-zinc-500">$199 / month · Unlimited events · 5,000 guests per event</span>
|
||||
</span>
|
||||
<span class="text-xs text-zinc-400">{{ upgrading === 'business' ? 'Opening checkout…' : '→' }}</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="mt-4 flex items-center justify-between text-xs">
|
||||
<button
|
||||
type="button"
|
||||
class="text-zinc-400 hover:text-zinc-200"
|
||||
@click="viewPlans"
|
||||
>Compare all plans</button>
|
||||
<button
|
||||
type="button"
|
||||
class="text-zinc-500 hover:text-zinc-300"
|
||||
@click="billing.dismissUpgradePrompt()"
|
||||
>Maybe later</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
@@ -1,16 +1,57 @@
|
||||
// Typed wrapper around $fetch with the configured API base.
|
||||
// Usage: const events = await useApi<EventList>('/events')
|
||||
//
|
||||
// Adds `Authorization: Bearer <access_token>` when the caller is signed in,
|
||||
// and on a 401 transparently asks `/auth/refresh` for a new token and retries
|
||||
// once. Failed refresh clears local auth state — pages can rely on the
|
||||
// returned error to redirect to /login.
|
||||
|
||||
export async function useApi<T = unknown>(
|
||||
path: string,
|
||||
opts: { method?: string; body?: unknown; query?: Record<string, unknown> } = {},
|
||||
): Promise<T> {
|
||||
const config = useRuntimeConfig()
|
||||
const base = config.public.apiBase as string
|
||||
return await $fetch<T>(path, {
|
||||
baseURL: base,
|
||||
method: (opts.method ?? 'GET') as any,
|
||||
body: opts.body,
|
||||
query: opts.query,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
})
|
||||
const auth = useAuth()
|
||||
|
||||
const request = async (token: string | null): Promise<T> => {
|
||||
const headers: Record<string, string> = {}
|
||||
// Let the browser set Content-Type (with the multipart boundary) when
|
||||
// the body is FormData / Blob; otherwise default to JSON.
|
||||
const isMultipart = typeof FormData !== 'undefined' && opts.body instanceof FormData
|
||||
if (!isMultipart) headers['Content-Type'] = 'application/json'
|
||||
if (token) headers.Authorization = `Bearer ${token}`
|
||||
return await $fetch<T>(path, {
|
||||
baseURL: base,
|
||||
method: (opts.method ?? 'GET') as any,
|
||||
body: opts.body as any,
|
||||
query: opts.query,
|
||||
headers,
|
||||
credentials: 'include',
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
return await request(auth.liveAccessToken())
|
||||
} catch (err: any) {
|
||||
const status = err?.response?.status ?? err?.statusCode
|
||||
|
||||
// 402 Payment Required — plan limit hit. Surface the backend's
|
||||
// upgrade payload on a global state slot; the UpgradeModal in
|
||||
// app.vue reads it and prompts the host to upgrade. We still
|
||||
// rethrow so the caller can stop its own UI flow if it wants.
|
||||
if (status === 402) {
|
||||
const data = err?.data
|
||||
if (data && data.upgrade_url) {
|
||||
useBilling().showUpgradePrompt(data)
|
||||
}
|
||||
throw err
|
||||
}
|
||||
|
||||
if (status !== 401) throw err
|
||||
// /auth/* endpoints set the cookie themselves — never retry-refresh them.
|
||||
if (path.startsWith('/auth/')) throw err
|
||||
const refreshed = await auth.refresh()
|
||||
if (!refreshed) throw err
|
||||
return await request(auth.liveAccessToken())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
// Auth state for the host-facing app.
|
||||
//
|
||||
// The access token lives only in memory (useState — Nuxt's SSR-safe wrapper).
|
||||
// The refresh token lives in an HttpOnly cookie set by the API at
|
||||
// `/auth/refresh` scope, so JavaScript here can never read it. On a hard
|
||||
// reload we lose the access token but the cookie survives, so `bootstrap()`
|
||||
// calls `/auth/refresh` to mint a fresh access token + reload the user.
|
||||
|
||||
interface AuthUser {
|
||||
id: string
|
||||
email: string
|
||||
name: string
|
||||
email_verified: boolean
|
||||
}
|
||||
|
||||
interface AuthSuccess {
|
||||
access_token: string
|
||||
expires_at: string
|
||||
user: AuthUser
|
||||
}
|
||||
|
||||
interface AuthState {
|
||||
user: AuthUser | null
|
||||
accessToken: string | null
|
||||
expiresAt: number | null // unix ms
|
||||
bootstrapped: boolean
|
||||
}
|
||||
|
||||
function emptyState(): AuthState {
|
||||
return { user: null, accessToken: null, expiresAt: null, bootstrapped: false }
|
||||
}
|
||||
|
||||
function apiBase(): string {
|
||||
return useRuntimeConfig().public.apiBase as string
|
||||
}
|
||||
|
||||
async function postJSON<T>(path: string, body?: unknown): Promise<T> {
|
||||
return await $fetch<T>(path, {
|
||||
baseURL: apiBase(),
|
||||
method: 'POST',
|
||||
body,
|
||||
credentials: 'include',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
})
|
||||
}
|
||||
|
||||
export function useAuth() {
|
||||
const state = useState<AuthState>('gg-auth', emptyState)
|
||||
|
||||
function setSession(s: AuthSuccess) {
|
||||
state.value = {
|
||||
user: s.user,
|
||||
accessToken: s.access_token,
|
||||
expiresAt: Date.parse(s.expires_at) || (Date.now() + 14 * 60 * 1000),
|
||||
bootstrapped: true,
|
||||
}
|
||||
}
|
||||
|
||||
function clearSession() {
|
||||
state.value = { ...emptyState(), bootstrapped: true }
|
||||
}
|
||||
|
||||
async function signup(email: string, name: string, password: string, acceptTerms = false) {
|
||||
return await postJSON<{ status: string }>('/auth/signup', {
|
||||
email, name, password,
|
||||
accept_terms: acceptTerms,
|
||||
})
|
||||
}
|
||||
|
||||
async function login(email: string, password: string) {
|
||||
const s = await postJSON<AuthSuccess>('/auth/login', { email, password })
|
||||
setSession(s)
|
||||
return s
|
||||
}
|
||||
|
||||
async function refresh(): Promise<boolean> {
|
||||
try {
|
||||
const s = await postJSON<AuthSuccess>('/auth/refresh')
|
||||
setSession(s)
|
||||
return true
|
||||
} catch {
|
||||
clearSession()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
async function logout() {
|
||||
try {
|
||||
await postJSON<void>('/auth/logout')
|
||||
} catch {
|
||||
// Best-effort — clear local state regardless.
|
||||
}
|
||||
clearSession()
|
||||
}
|
||||
|
||||
async function verifyEmail(token: string) {
|
||||
return await postJSON<{ status: string }>('/auth/verify-email', { token })
|
||||
}
|
||||
|
||||
async function forgotPassword(email: string) {
|
||||
return await postJSON<{ status: string }>('/auth/forgot-password', { email })
|
||||
}
|
||||
|
||||
async function resetPassword(token: string, newPassword: string) {
|
||||
return await postJSON<{ status: string }>('/auth/reset-password', {
|
||||
token,
|
||||
new_password: newPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// Call on app entry / route guards. Returns true if the caller has a valid
|
||||
// session by the time it resolves.
|
||||
async function bootstrap(): Promise<boolean> {
|
||||
if (!import.meta.client) return false
|
||||
if (state.value.bootstrapped && state.value.user) return true
|
||||
if (state.value.bootstrapped && !state.value.user) return false
|
||||
return await refresh()
|
||||
}
|
||||
|
||||
// Hint to useApi: returns the current token if not yet expired.
|
||||
function liveAccessToken(): string | null {
|
||||
if (!state.value.accessToken || !state.value.expiresAt) return null
|
||||
// 5s skew to avoid sending a just-expired token.
|
||||
if (Date.now() + 5000 >= state.value.expiresAt) return null
|
||||
return state.value.accessToken
|
||||
}
|
||||
|
||||
const isAuthenticated = computed(() => !!state.value.user)
|
||||
const user = computed(() => state.value.user)
|
||||
const bootstrapped = computed(() => state.value.bootstrapped)
|
||||
|
||||
return {
|
||||
user,
|
||||
isAuthenticated,
|
||||
bootstrapped,
|
||||
signup,
|
||||
login,
|
||||
refresh,
|
||||
logout,
|
||||
verifyEmail,
|
||||
forgotPassword,
|
||||
resetPassword,
|
||||
bootstrap,
|
||||
liveAccessToken,
|
||||
clearSession,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
// Billing client: fetches subscription status, kicks off checkout/portal
|
||||
// flows, and owns the global "upgrade required" prompt shown by the
|
||||
// 402 interceptor in useApi. State is shared across components via
|
||||
// useState so the prompt can be triggered from any handler.
|
||||
|
||||
export interface BillingStatus {
|
||||
tier: 'free' | 'pro' | 'business'
|
||||
status: string
|
||||
current_period_end?: string
|
||||
cancel_at_period_end: boolean
|
||||
limits: {
|
||||
events_per_month: number
|
||||
guests_per_event: number
|
||||
}
|
||||
usage: {
|
||||
events_this_month: number
|
||||
}
|
||||
portal_available: boolean
|
||||
}
|
||||
|
||||
// UpgradePrompt mirrors the 402 body the backend returns when a limit is
|
||||
// hit. Shown globally as a modal until dismissed or acted upon.
|
||||
export interface UpgradePrompt {
|
||||
error: string
|
||||
reason: string
|
||||
tier: string
|
||||
used: number
|
||||
limit: number
|
||||
upgrade_url: string
|
||||
}
|
||||
|
||||
const FREE_DEFAULT: BillingStatus = {
|
||||
tier: 'free',
|
||||
status: 'active',
|
||||
cancel_at_period_end: false,
|
||||
limits: { events_per_month: 1, guests_per_event: 50 },
|
||||
usage: { events_this_month: 0 },
|
||||
portal_available: false,
|
||||
}
|
||||
|
||||
export function useBilling() {
|
||||
const status = useState<BillingStatus | null>('gg-billing-status', () => null)
|
||||
const loading = useState<boolean>('gg-billing-loading', () => false)
|
||||
const prompt = useState<UpgradePrompt | null>('gg-upgrade-prompt', () => null)
|
||||
|
||||
async function fetchStatus(): Promise<BillingStatus> {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await useApi<BillingStatus>('/billing/status')
|
||||
status.value = res
|
||||
return res
|
||||
} catch (e: any) {
|
||||
// 401/refresh edge → caller redirected to /login by useApi. If the
|
||||
// backend has billing wired but the response is malformed, fall
|
||||
// back to free defaults so the page renders something usable.
|
||||
status.value = FREE_DEFAULT
|
||||
return FREE_DEFAULT
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function startCheckout(tier: 'pro' | 'business'): Promise<void> {
|
||||
const res = await useApi<{ url: string }>('/billing/checkout-session', {
|
||||
method: 'POST',
|
||||
body: { tier },
|
||||
})
|
||||
if (import.meta.client) window.location.href = res.url
|
||||
}
|
||||
|
||||
async function openPortal(): Promise<void> {
|
||||
const res = await useApi<{ url: string }>('/billing/portal', { method: 'POST' })
|
||||
if (import.meta.client) window.location.href = res.url
|
||||
}
|
||||
|
||||
function showUpgradePrompt(info: UpgradePrompt) {
|
||||
prompt.value = info
|
||||
}
|
||||
function dismissUpgradePrompt() {
|
||||
prompt.value = null
|
||||
}
|
||||
|
||||
return {
|
||||
status,
|
||||
loading,
|
||||
prompt,
|
||||
fetchStatus,
|
||||
startCheckout,
|
||||
openPortal,
|
||||
showUpgradePrompt,
|
||||
dismissUpgradePrompt,
|
||||
}
|
||||
}
|
||||
|
||||
// Static pricing copy. Keep in sync with internal/billing/tiers.go.
|
||||
// One source of truth for the marketing-page-style cards in
|
||||
// /dashboard/billing and the UpgradeModal.
|
||||
export interface TierCard {
|
||||
id: 'free' | 'pro' | 'business'
|
||||
name: string
|
||||
price: string
|
||||
priceSubtitle: string
|
||||
tagline: string
|
||||
features: string[]
|
||||
highlight?: boolean
|
||||
}
|
||||
|
||||
export const TIER_CARDS: TierCard[] = [
|
||||
{
|
||||
id: 'free',
|
||||
name: 'Free',
|
||||
price: '$0',
|
||||
priceSubtitle: 'forever',
|
||||
tagline: 'Try GuestGuard with a single event.',
|
||||
features: [
|
||||
'1 event per month',
|
||||
'Up to 50 guests per event',
|
||||
'Branded email invitations',
|
||||
'Real-time RSVP dashboard',
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 'pro',
|
||||
name: 'Pro',
|
||||
price: '$49',
|
||||
priceSubtitle: 'per month',
|
||||
tagline: 'For active hosts running several events.',
|
||||
features: [
|
||||
'10 events per month',
|
||||
'Up to 1,000 guests per event',
|
||||
'WhatsApp + email invitations',
|
||||
'CSV import + bulk send',
|
||||
'Priority email support',
|
||||
],
|
||||
highlight: true,
|
||||
},
|
||||
{
|
||||
id: 'business',
|
||||
name: 'Business',
|
||||
price: '$199',
|
||||
priceSubtitle: 'per month',
|
||||
tagline: 'For agencies and corporate events teams.',
|
||||
features: [
|
||||
'Unlimited events',
|
||||
'Up to 5,000 guests per event',
|
||||
'Everything in Pro',
|
||||
'Signed DPA on request',
|
||||
'SLA with response targets',
|
||||
],
|
||||
},
|
||||
]
|
||||
@@ -0,0 +1,28 @@
|
||||
// Friendly error messages for API failures, with first-class handling of
|
||||
// 429 (rate-limited) and 403 account-lockout responses.
|
||||
export function useErrMessage(e: any, fallback = 'Something went wrong'): string {
|
||||
const status: number | undefined = e?.response?.status ?? e?.statusCode
|
||||
const data = e?.data
|
||||
const serverMsg: string | undefined = data?.error
|
||||
|
||||
if (status === 429) {
|
||||
const retry: number | undefined = data?.retry_after
|
||||
if (typeof retry === 'number' && retry > 0) {
|
||||
return `You're going too fast — try again in ${formatSeconds(retry)}.`
|
||||
}
|
||||
return "You're going too fast — please try again in a moment."
|
||||
}
|
||||
|
||||
if (status === 403 && typeof serverMsg === 'string' && serverMsg.toLowerCase().includes('locked')) {
|
||||
return 'Your account is locked after too many failed sign-in attempts. Reset your password to unlock it.'
|
||||
}
|
||||
|
||||
if (serverMsg) return serverMsg
|
||||
return e?.message || fallback
|
||||
}
|
||||
|
||||
function formatSeconds(s: number): string {
|
||||
if (s < 60) return `${s} second${s === 1 ? '' : 's'}`
|
||||
const m = Math.ceil(s / 60)
|
||||
return `${m} minute${m === 1 ? '' : 's'}`
|
||||
}
|
||||
@@ -1,7 +1,12 @@
|
||||
// Subscribes to /ws/events/:id and emits per-message callbacks.
|
||||
//
|
||||
// Auto-reconnects with exponential backoff up to 30s. Returns a cleanup
|
||||
// fn the caller invokes (e.g. inside onUnmounted).
|
||||
// Authenticates via short-lived ticket: before each connect we POST
|
||||
// /auth/ws-ticket (bearer-authed) to mint a one-shot ticket, then pass it
|
||||
// on the WS handshake as `?ticket=…`. Tickets expire ~60s after mint, so we
|
||||
// always mint fresh — even on reconnects.
|
||||
//
|
||||
// Auto-reconnects with exponential backoff up to 30s. Returns a cleanup fn
|
||||
// the caller invokes (e.g. inside onUnmounted).
|
||||
|
||||
interface WSMessage {
|
||||
type: string
|
||||
@@ -10,21 +15,46 @@ interface WSMessage {
|
||||
timestamp: string
|
||||
}
|
||||
|
||||
interface WSTicket {
|
||||
ticket: string
|
||||
expires_at: string
|
||||
}
|
||||
|
||||
export function useEventWS(eventId: string, onMessage: (msg: WSMessage) => void) {
|
||||
if (import.meta.server) return () => {}
|
||||
|
||||
const config = useRuntimeConfig()
|
||||
const base = (config.public.wsBase as string) || ''
|
||||
const url = `${base}/ws/events/${eventId}`
|
||||
const wsBase = (config.public.wsBase as string) || ''
|
||||
|
||||
let ws: WebSocket | null = null
|
||||
let attempt = 0
|
||||
let stopped = false
|
||||
let reconnectTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
function connect() {
|
||||
async function mintTicket(): Promise<string | null> {
|
||||
try {
|
||||
const t = await useApi<WSTicket>('/auth/ws-ticket', {
|
||||
method: 'POST',
|
||||
body: { event_id: eventId },
|
||||
})
|
||||
return t.ticket
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
async function connect() {
|
||||
if (stopped) return
|
||||
ws = new WebSocket(url)
|
||||
const ticket = await mintTicket()
|
||||
if (stopped) return
|
||||
if (!ticket) {
|
||||
// Couldn't get a ticket (likely 401 — session expired). Back off and
|
||||
// retry; useApi will have already attempted refresh on its own.
|
||||
const backoff = Math.min(30_000, 500 * Math.pow(2, attempt++))
|
||||
reconnectTimer = setTimeout(connect, backoff)
|
||||
return
|
||||
}
|
||||
ws = new WebSocket(`${wsBase}/ws/events/${eventId}?ticket=${encodeURIComponent(ticket)}`)
|
||||
|
||||
ws.onopen = () => {
|
||||
attempt = 0
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
// Demo-grade host bootstrap. Real auth would replace this entirely; for now
|
||||
// we upsert by email and stash the host id in localStorage.
|
||||
interface User {
|
||||
id: string
|
||||
email: string
|
||||
name: string
|
||||
}
|
||||
|
||||
const STORAGE_KEY = 'gg.host'
|
||||
|
||||
export function useHost() {
|
||||
const host = useState<User | null>('gg-host', () => null)
|
||||
|
||||
if (import.meta.client && !host.value) {
|
||||
const raw = window.localStorage.getItem(STORAGE_KEY)
|
||||
if (raw) {
|
||||
try {
|
||||
host.value = JSON.parse(raw)
|
||||
} catch {
|
||||
window.localStorage.removeItem(STORAGE_KEY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function bootstrap(email: string, name: string) {
|
||||
const u = await useApi<User>('/users', {
|
||||
method: 'POST',
|
||||
body: { email, name },
|
||||
})
|
||||
host.value = u
|
||||
if (import.meta.client) {
|
||||
window.localStorage.setItem(STORAGE_KEY, JSON.stringify(u))
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
function clear() {
|
||||
host.value = null
|
||||
if (import.meta.client) window.localStorage.removeItem(STORAGE_KEY)
|
||||
}
|
||||
|
||||
return { host, bootstrap, clear }
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
// Route guard: bootstraps the session (server -> client cookie roundtrip,
|
||||
// then optional /auth/refresh), and redirects to /login if no session.
|
||||
//
|
||||
// Skip on SSR: auth state lives entirely in the browser. Dashboard pages
|
||||
// render their own client-side loading state while we resolve.
|
||||
|
||||
export default defineNuxtRouteMiddleware(async (to) => {
|
||||
if (import.meta.server) return
|
||||
const auth = useAuth()
|
||||
const ok = await auth.bootstrap()
|
||||
if (!ok) {
|
||||
return navigateTo({ path: '/login', query: { redirect: to.fullPath } })
|
||||
}
|
||||
})
|
||||
Generated
+15
-4
@@ -5092,12 +5092,14 @@
|
||||
}
|
||||
},
|
||||
"node_modules/commander": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz",
|
||||
"integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==",
|
||||
"version": "13.1.0",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-13.1.0.tgz",
|
||||
"integrity": "sha512-/rFeCpNJQbhSZjGVwO9RFV3xPqbnERS8MmIQzCtD/zl6gpJuV/bMLuN92oG3F7d8oDEHHRrujSXNUr8fpjntKw==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/common-path-prefix": {
|
||||
@@ -8500,6 +8502,15 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/lambda-local/node_modules/commander": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz",
|
||||
"integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
}
|
||||
},
|
||||
"node_modules/lambda-local/node_modules/dotenv": {
|
||||
"version": "16.6.1",
|
||||
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.6.1.tgz",
|
||||
|
||||
@@ -0,0 +1,393 @@
|
||||
<script setup lang="ts">
|
||||
definePageMeta({ middleware: ['auth'] })
|
||||
|
||||
const route = useRoute()
|
||||
const router = useRouter()
|
||||
const billing = useBilling()
|
||||
const auth = useAuth()
|
||||
const config = useRuntimeConfig()
|
||||
const switching = ref<'pro' | 'business' | null>(null)
|
||||
const portalLoading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const toast = ref<string | null>(null)
|
||||
let toastTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
// Your data
|
||||
const exporting = ref(false)
|
||||
const deleteConfirmOpen = ref(false)
|
||||
const deleteConfirmation = ref('')
|
||||
const deleting = ref(false)
|
||||
const deleteError = ref<string | null>(null)
|
||||
|
||||
async function exportData() {
|
||||
exporting.value = true
|
||||
try {
|
||||
const apiBase = config.public.apiBase as string
|
||||
const token = auth.liveAccessToken()
|
||||
// Plain fetch (not useApi) so the response is treated as a download.
|
||||
const res = await fetch(`${apiBase}/me/data-export`, {
|
||||
headers: token ? { Authorization: `Bearer ${token}` } : {},
|
||||
credentials: 'include',
|
||||
})
|
||||
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
||||
const blob = await res.blob()
|
||||
const url = URL.createObjectURL(blob)
|
||||
const a = document.createElement('a')
|
||||
a.href = url
|
||||
a.download = 'guestguard-data-export.json'
|
||||
a.click()
|
||||
URL.revokeObjectURL(url)
|
||||
showToast('Export downloaded.')
|
||||
} catch (e: any) {
|
||||
showToast(useErrMessage(e, 'Export failed'))
|
||||
} finally {
|
||||
exporting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function confirmDelete() {
|
||||
deleting.value = true
|
||||
deleteError.value = null
|
||||
try {
|
||||
await useApi('/me', { method: 'DELETE' })
|
||||
// Soft-delete revoked our refresh token; clear local session and
|
||||
// bounce to the marketing landing.
|
||||
auth.clearSession()
|
||||
await router.push('/')
|
||||
} catch (e: any) {
|
||||
deleteError.value = useErrMessage(e, 'Could not delete account')
|
||||
} finally {
|
||||
deleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function showToast(text: string) {
|
||||
toast.value = text
|
||||
if (toastTimer) clearTimeout(toastTimer)
|
||||
toastTimer = setTimeout(() => { toast.value = null }, 5000)
|
||||
}
|
||||
|
||||
onMounted(async () => {
|
||||
await billing.fetchStatus()
|
||||
|
||||
// Handle return-from-Stripe query params. ?billing=success means the
|
||||
// checkout completed; Stripe also fires the webhook server-side so we
|
||||
// refetch status to pick up the new tier without a hard reload.
|
||||
const flag = route.query.billing
|
||||
if (flag === 'success') {
|
||||
showToast('Subscription updated — welcome aboard!')
|
||||
// Stripe's webhook may take ~1s to land. Poll a couple of times.
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await new Promise((r) => setTimeout(r, 1500))
|
||||
await billing.fetchStatus()
|
||||
if (billing.status.value?.tier !== 'free') break
|
||||
}
|
||||
} else if (flag === 'cancelled') {
|
||||
showToast('No worries — your plan is unchanged.')
|
||||
}
|
||||
})
|
||||
|
||||
async function upgrade(tier: 'pro' | 'business') {
|
||||
switching.value = tier
|
||||
error.value = null
|
||||
try {
|
||||
await billing.startCheckout(tier)
|
||||
} catch (e: any) {
|
||||
if (e?.response?.status === 503) {
|
||||
error.value = 'Billing isn\'t enabled on this environment yet — contact support.'
|
||||
} else {
|
||||
error.value = useErrMessage(e, 'Could not start checkout')
|
||||
}
|
||||
} finally {
|
||||
switching.value = null
|
||||
}
|
||||
}
|
||||
|
||||
async function manageSubscription() {
|
||||
portalLoading.value = true
|
||||
error.value = null
|
||||
try {
|
||||
await billing.openPortal()
|
||||
} catch (e: any) {
|
||||
if (e?.response?.status === 503) {
|
||||
error.value = 'Billing isn\'t enabled on this environment yet.'
|
||||
} else {
|
||||
error.value = useErrMessage(e, 'Could not open the billing portal')
|
||||
}
|
||||
} finally {
|
||||
portalLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Usage bar percentage — clamps to [0, 100] for the progress indicator.
|
||||
const eventsUsagePct = computed(() => {
|
||||
const s = billing.status.value
|
||||
if (!s) return 0
|
||||
const limit = s.limits.events_per_month
|
||||
if (limit < 0) return 0 // unlimited — show empty bar
|
||||
if (limit === 0) return 100
|
||||
return Math.min(100, Math.round((s.usage.events_this_month / limit) * 100))
|
||||
})
|
||||
|
||||
function formatLimit(n: number): string {
|
||||
return n < 0 ? 'Unlimited' : n.toLocaleString()
|
||||
}
|
||||
|
||||
function periodEndLabel(iso?: string): string {
|
||||
if (!iso) return ''
|
||||
try {
|
||||
return new Date(iso).toLocaleDateString(undefined, { day: 'numeric', month: 'short', year: 'numeric' })
|
||||
} catch {
|
||||
return iso
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="space-y-6">
|
||||
<div>
|
||||
<NuxtLink to="/dashboard" class="mb-2 inline-block text-sm text-zinc-400 hover:text-zinc-200">
|
||||
← Back to dashboard
|
||||
</NuxtLink>
|
||||
<h1 class="text-2xl font-semibold">Billing & plan</h1>
|
||||
<p class="mt-1 text-sm text-zinc-400">
|
||||
Change your plan, see your usage, or update your payment method.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<ClientOnly>
|
||||
<!-- Current plan + usage -->
|
||||
<div class="card">
|
||||
<div class="flex flex-wrap items-start justify-between gap-4">
|
||||
<div>
|
||||
<p class="text-xs font-medium uppercase tracking-wider text-zinc-500">Current plan</p>
|
||||
<div class="mt-1 flex items-baseline gap-2">
|
||||
<span class="text-2xl font-semibold capitalize text-zinc-100">{{ billing.status.value?.tier || '—' }}</span>
|
||||
<span
|
||||
v-if="billing.status.value && billing.status.value.tier !== 'free'"
|
||||
class="rounded-full bg-zinc-800 px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide text-zinc-400"
|
||||
>{{ billing.status.value.status }}</span>
|
||||
</div>
|
||||
<p
|
||||
v-if="billing.status.value?.current_period_end && billing.status.value.tier !== 'free'"
|
||||
class="mt-1 text-xs text-zinc-500"
|
||||
>
|
||||
<template v-if="billing.status.value.cancel_at_period_end">
|
||||
Cancels on {{ periodEndLabel(billing.status.value.current_period_end) }}.
|
||||
</template>
|
||||
<template v-else>
|
||||
Renews on {{ periodEndLabel(billing.status.value.current_period_end) }}.
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
v-if="billing.status.value?.portal_available"
|
||||
type="button"
|
||||
class="rounded-md border border-zinc-700 px-3 py-1.5 text-sm text-zinc-200 transition hover:border-zinc-500 hover:bg-zinc-800 disabled:opacity-50"
|
||||
:disabled="portalLoading"
|
||||
@click="manageSubscription"
|
||||
>
|
||||
{{ portalLoading ? 'Opening…' : 'Manage subscription' }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Usage bar -->
|
||||
<div class="mt-5">
|
||||
<div class="mb-1.5 flex items-center justify-between text-xs">
|
||||
<span class="text-zinc-300">Events this month</span>
|
||||
<span class="tabular-nums text-zinc-400">
|
||||
{{ billing.status.value?.usage.events_this_month ?? 0 }}
|
||||
of
|
||||
{{ formatLimit(billing.status.value?.limits.events_per_month ?? 0) }}
|
||||
</span>
|
||||
</div>
|
||||
<div class="h-2 w-full overflow-hidden rounded-full bg-zinc-800">
|
||||
<div
|
||||
class="h-full rounded-full transition-all"
|
||||
:class="eventsUsagePct >= 90 ? 'bg-amber-400' : 'bg-brand-500'"
|
||||
:style="{ width: `${eventsUsagePct}%` }"
|
||||
></div>
|
||||
</div>
|
||||
<p class="mt-1.5 text-xs text-zinc-500">
|
||||
Guest cap per event: {{ formatLimit(billing.status.value?.limits.guests_per_event ?? 0) }}.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Pricing cards -->
|
||||
<div class="grid gap-4 md:grid-cols-3">
|
||||
<div
|
||||
v-for="t in TIER_CARDS"
|
||||
:key="t.id"
|
||||
class="card relative flex flex-col gap-4"
|
||||
:class="t.highlight ? 'border-brand-700/60 bg-brand-500/[0.04]' : ''"
|
||||
>
|
||||
<span
|
||||
v-if="t.id === billing.status.value?.tier"
|
||||
class="absolute right-3 top-3 rounded-full bg-zinc-800 px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide text-zinc-300"
|
||||
>Current</span>
|
||||
<span
|
||||
v-else-if="t.highlight"
|
||||
class="absolute right-3 top-3 rounded-full bg-brand-500 px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide text-zinc-950"
|
||||
>Most popular</span>
|
||||
|
||||
<div>
|
||||
<h3 class="text-lg font-semibold capitalize text-zinc-100">{{ t.name }}</h3>
|
||||
<p class="mt-1 text-xs text-zinc-500">{{ t.tagline }}</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<span class="text-3xl font-semibold tabular-nums text-zinc-100">{{ t.price }}</span>
|
||||
<span class="ml-1 text-xs text-zinc-500">{{ t.priceSubtitle }}</span>
|
||||
</div>
|
||||
|
||||
<ul class="space-y-1.5 text-sm text-zinc-300">
|
||||
<li v-for="f in t.features" :key="f" class="flex items-start gap-2">
|
||||
<svg class="mt-0.5 h-3.5 w-3.5 shrink-0 text-brand-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
|
||||
<path fill-rule="evenodd" d="M16.704 5.296a1 1 0 010 1.408l-8 8a1 1 0 01-1.408 0l-4-4a1 1 0 011.408-1.408L8 12.592l7.296-7.296a1 1 0 011.408 0z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
<span>{{ f }}</span>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<div class="mt-auto pt-2">
|
||||
<button
|
||||
v-if="t.id === billing.status.value?.tier"
|
||||
type="button"
|
||||
class="w-full cursor-default rounded-md border border-zinc-800 px-3 py-2 text-sm text-zinc-500"
|
||||
disabled
|
||||
>Current plan</button>
|
||||
<button
|
||||
v-else-if="t.id === 'free'"
|
||||
type="button"
|
||||
class="w-full cursor-default rounded-md border border-zinc-800 px-3 py-2 text-sm text-zinc-500"
|
||||
disabled
|
||||
>Downgrade in the billing portal</button>
|
||||
<button
|
||||
v-else
|
||||
type="button"
|
||||
class="btn-primary w-full disabled:opacity-50"
|
||||
:disabled="switching === t.id"
|
||||
@click="upgrade(t.id)"
|
||||
>
|
||||
{{ switching === t.id ? 'Opening checkout…' : `Upgrade to ${t.name}` }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
|
||||
|
||||
<p class="text-xs text-zinc-500">
|
||||
Receipts and invoices are emailed automatically by Stripe.
|
||||
Need to cancel? Use <a href="#" class="text-brand-400 hover:text-brand-300" @click.prevent="manageSubscription">Manage subscription</a> above.
|
||||
</p>
|
||||
|
||||
<!-- ===== Your data ===== -->
|
||||
<div class="card mt-2">
|
||||
<h2 class="mb-1 text-lg font-semibold">Your data</h2>
|
||||
<p class="mb-4 text-xs text-zinc-500">
|
||||
Export a copy of everything we hold about you, or delete your account.
|
||||
</p>
|
||||
<div class="space-y-3">
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between rounded-md border border-zinc-700 bg-zinc-950 px-3 py-3 text-left transition hover:border-zinc-500 hover:bg-zinc-900 disabled:opacity-50"
|
||||
:disabled="exporting"
|
||||
@click="exportData"
|
||||
>
|
||||
<span>
|
||||
<span class="block text-sm font-medium text-zinc-100">Export my data</span>
|
||||
<span class="block text-xs text-zinc-500">
|
||||
Download a JSON file with your events, guests, RSVPs, and account info.
|
||||
</span>
|
||||
</span>
|
||||
<span class="text-xs text-zinc-400">{{ exporting ? '…' : '↓' }}</span>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="flex w-full items-center justify-between rounded-md border border-red-800/40 bg-red-950/10 px-3 py-3 text-left transition hover:border-red-700 hover:bg-red-950/20"
|
||||
@click="deleteConfirmOpen = true"
|
||||
>
|
||||
<span>
|
||||
<span class="block text-sm font-medium text-red-300">Delete my account</span>
|
||||
<span class="block text-xs text-red-400/70">
|
||||
Soft-deleted immediately, permanently erased after 30 days. You'll be signed out everywhere.
|
||||
</span>
|
||||
</span>
|
||||
<span class="text-xs text-red-400">→</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #fallback>
|
||||
<div class="card text-sm text-zinc-500">Loading…</div>
|
||||
</template>
|
||||
</ClientOnly>
|
||||
|
||||
<!-- Delete-account confirmation -->
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="deleteConfirmOpen"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center bg-black/60 p-4 backdrop-blur-sm"
|
||||
@click.self="deleteConfirmOpen = false"
|
||||
>
|
||||
<div
|
||||
role="alertdialog"
|
||||
aria-modal="true"
|
||||
aria-labelledby="del-acct-title"
|
||||
class="w-full max-w-md rounded-lg border border-zinc-800 bg-zinc-900 p-5 shadow-2xl"
|
||||
>
|
||||
<h3 id="del-acct-title" class="mb-1 text-base font-semibold">Delete account?</h3>
|
||||
<p class="mb-3 text-sm text-zinc-400">
|
||||
Your account will be soft-deleted now and permanently erased
|
||||
after 30 days. All your events, guests, and RSVP history go
|
||||
with it. You'll be signed out from every device.
|
||||
</p>
|
||||
<p class="mb-3 text-xs text-zinc-500">
|
||||
Type <code class="rounded bg-zinc-800 px-1 py-0.5 font-mono text-zinc-300">delete</code>
|
||||
to confirm.
|
||||
</p>
|
||||
<input
|
||||
v-model="deleteConfirmation"
|
||||
type="text"
|
||||
placeholder="delete"
|
||||
class="input mb-3 font-mono"
|
||||
autocomplete="off"
|
||||
/>
|
||||
<div class="flex items-center justify-end gap-2">
|
||||
<button class="text-sm text-zinc-400 hover:text-zinc-200" :disabled="deleting" @click="deleteConfirmOpen = false">Cancel</button>
|
||||
<button
|
||||
class="rounded-md bg-red-500/90 px-3 py-1.5 text-sm font-medium text-white shadow-sm transition hover:bg-red-500 disabled:opacity-40"
|
||||
:disabled="deleting || deleteConfirmation.trim().toLowerCase() !== 'delete'"
|
||||
@click="confirmDelete"
|
||||
>
|
||||
{{ deleting ? 'Deleting…' : 'Delete forever' }}
|
||||
</button>
|
||||
</div>
|
||||
<p v-if="deleteError" class="mt-3 text-sm text-red-400">{{ deleteError }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
|
||||
<!-- Toast for return-from-Stripe -->
|
||||
<Transition
|
||||
enter-active-class="transition duration-200 ease-out"
|
||||
enter-from-class="translate-y-2 opacity-0"
|
||||
enter-to-class="translate-y-0 opacity-100"
|
||||
leave-active-class="transition duration-200 ease-in"
|
||||
leave-from-class="translate-y-0 opacity-100"
|
||||
leave-to-class="translate-y-2 opacity-0"
|
||||
>
|
||||
<button
|
||||
v-if="toast"
|
||||
type="button"
|
||||
class="fixed bottom-6 right-6 z-50 max-w-sm rounded-lg border border-brand-700/60 bg-brand-950/90 px-4 py-3 text-left text-sm text-brand-100 shadow-lg backdrop-blur"
|
||||
@click="toast = null"
|
||||
>
|
||||
<span aria-hidden="true" class="mr-2">✓</span>{{ toast }}
|
||||
</button>
|
||||
</Transition>
|
||||
</section>
|
||||
</template>
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,8 @@
|
||||
<script setup lang="ts">
|
||||
const { host } = useHost()
|
||||
definePageMeta({ middleware: ['auth'] })
|
||||
|
||||
const auth = useAuth()
|
||||
const host = auth.user
|
||||
|
||||
const name = ref('')
|
||||
const slug = ref('')
|
||||
@@ -17,7 +20,6 @@ async function submit() {
|
||||
const created = await useApi<{ id: string }>('/events', {
|
||||
method: 'POST',
|
||||
body: {
|
||||
host_id: host.value.id,
|
||||
name: name.value,
|
||||
slug: slug.value,
|
||||
event_date: new Date(eventDate.value).toISOString(),
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
<script setup lang="ts">
|
||||
definePageMeta({ middleware: ['auth'] })
|
||||
|
||||
interface EventSummary {
|
||||
id: string
|
||||
name: string
|
||||
@@ -13,40 +15,28 @@ interface EventsResponse {
|
||||
events: EventSummary[]
|
||||
}
|
||||
|
||||
const { host, bootstrap } = useHost()
|
||||
|
||||
const email = ref('')
|
||||
const name = ref('')
|
||||
const bootstrapping = ref(false)
|
||||
const bootstrapError = ref<string | null>(null)
|
||||
|
||||
async function onBootstrap() {
|
||||
bootstrapError.value = null
|
||||
bootstrapping.value = true
|
||||
try {
|
||||
await bootstrap(email.value, name.value)
|
||||
} catch (e: any) {
|
||||
bootstrapError.value = e?.data?.error || e?.message || 'Failed to bootstrap'
|
||||
} finally {
|
||||
bootstrapping.value = false
|
||||
}
|
||||
}
|
||||
const auth = useAuth()
|
||||
|
||||
const events = ref<EventSummary[]>([])
|
||||
const loadingEvents = ref(false)
|
||||
const loadError = ref<string | null>(null)
|
||||
|
||||
async function loadEvents() {
|
||||
if (!host.value) return
|
||||
if (!auth.user.value) return
|
||||
loadingEvents.value = true
|
||||
loadError.value = null
|
||||
try {
|
||||
const res = await useApi<EventsResponse>('/events', { query: { host_id: host.value.id } })
|
||||
// host is derived server-side from the session — no query param needed.
|
||||
const res = await useApi<EventsResponse>('/events')
|
||||
events.value = res.events
|
||||
} catch (e: any) {
|
||||
loadError.value = e?.data?.error || e?.message || 'Failed to load events'
|
||||
} finally {
|
||||
loadingEvents.value = false
|
||||
}
|
||||
}
|
||||
|
||||
watch(host, loadEvents, { immediate: true })
|
||||
watch(() => auth.user.value, loadEvents, { immediate: true })
|
||||
|
||||
function fmtDate(iso: string) {
|
||||
try { return new Date(iso).toLocaleString() } catch { return iso }
|
||||
@@ -55,40 +45,22 @@ function fmtDate(iso: string) {
|
||||
|
||||
<template>
|
||||
<section>
|
||||
<!--
|
||||
The dashboard is auth-gated by a localStorage-backed host. Rendering
|
||||
that conditional on the server (where there's no localStorage) and
|
||||
then again on the client (where there is) causes a hydration
|
||||
mismatch that leaves the layout stuck at the bootstrap card's width
|
||||
after a hard refresh. Skipping SSR for this block fixes both the
|
||||
flash and the layout shrink.
|
||||
-->
|
||||
<ClientOnly>
|
||||
<div v-if="!host" class="card max-w-md">
|
||||
<h1 class="mb-2 text-xl font-semibold">Get started</h1>
|
||||
<p class="mb-4 text-sm text-zinc-400">
|
||||
Demo bootstrap — enter an email + name to provision a host. We don't store passwords.
|
||||
</p>
|
||||
<label class="label">Email</label>
|
||||
<input v-model="email" type="email" class="input mb-3" placeholder="you@example.com" />
|
||||
<label class="label">Name</label>
|
||||
<input v-model="name" type="text" class="input mb-4" placeholder="Your name" />
|
||||
<button class="btn-primary w-full" :disabled="bootstrapping || !email || !name" @click="onBootstrap">
|
||||
{{ bootstrapping ? 'Setting up…' : 'Continue' }}
|
||||
</button>
|
||||
<p v-if="bootstrapError" class="mt-3 text-sm text-red-400">{{ bootstrapError }}</p>
|
||||
<div v-if="!auth.bootstrapped.value || !auth.user.value" class="text-sm text-zinc-500">
|
||||
Loading dashboard…
|
||||
</div>
|
||||
|
||||
<div v-else>
|
||||
<div class="mb-6 flex items-center justify-between">
|
||||
<div>
|
||||
<h1 class="text-2xl font-semibold">Your events</h1>
|
||||
<p class="text-sm text-zinc-400">Signed in as {{ host.name }} ({{ host.email }})</p>
|
||||
<p class="text-sm text-zinc-400">Signed in as {{ auth.user.value.name }} ({{ auth.user.value.email }})</p>
|
||||
</div>
|
||||
<NuxtLink to="/dashboard/events/new" class="btn-primary">New event</NuxtLink>
|
||||
</div>
|
||||
|
||||
<div v-if="loadingEvents" class="text-sm text-zinc-500">Loading…</div>
|
||||
<div v-else-if="loadError" class="card text-sm text-red-400">{{ loadError }}</div>
|
||||
<div v-else-if="events.length === 0" class="card text-sm text-zinc-400">
|
||||
No events yet. Create one to get started.
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
<script setup lang="ts">
|
||||
const auth = useAuth()
|
||||
|
||||
const email = ref('')
|
||||
const submitting = ref(false)
|
||||
const sent = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
async function submit() {
|
||||
error.value = null
|
||||
submitting.value = true
|
||||
try {
|
||||
await auth.forgotPassword(email.value)
|
||||
sent.value = true
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Request failed')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="mx-auto max-w-md py-12">
|
||||
<h1 class="mb-2 text-2xl font-semibold">Forgot your password?</h1>
|
||||
<p class="mb-6 text-sm text-zinc-400">
|
||||
Enter your email and we'll send a reset link if there's an account on file.
|
||||
</p>
|
||||
|
||||
<div v-if="sent" class="card text-sm">
|
||||
<p class="mb-2 font-medium text-brand-300">Check your inbox.</p>
|
||||
<p class="text-zinc-400">
|
||||
If <span class="text-zinc-200">{{ email }}</span> is registered, a reset link is on its way.
|
||||
</p>
|
||||
<NuxtLink to="/login" class="btn-ghost mt-4 w-full">Back to sign in</NuxtLink>
|
||||
</div>
|
||||
|
||||
<form v-else class="card space-y-4" @submit.prevent="submit">
|
||||
<div>
|
||||
<label class="label">Email</label>
|
||||
<input v-model="email" type="email" class="input" autocomplete="email" required />
|
||||
</div>
|
||||
<button class="btn-primary w-full" :disabled="submitting || !email">
|
||||
{{ submitting ? 'Sending…' : 'Send reset link' }}
|
||||
</button>
|
||||
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
|
||||
</form>
|
||||
</section>
|
||||
</template>
|
||||
@@ -0,0 +1,53 @@
|
||||
<script setup lang="ts">
|
||||
const auth = useAuth()
|
||||
const route = useRoute()
|
||||
|
||||
const email = ref('')
|
||||
const password = ref('')
|
||||
const submitting = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
async function submit() {
|
||||
error.value = null
|
||||
submitting.value = true
|
||||
try {
|
||||
await auth.login(email.value, password.value)
|
||||
const redirect = (route.query.redirect as string) || '/dashboard'
|
||||
await navigateTo(redirect)
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Login failed')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="mx-auto max-w-md py-12">
|
||||
<h1 class="mb-2 text-2xl font-semibold">Sign in</h1>
|
||||
<p class="mb-6 text-sm text-zinc-400">Welcome back. Sign in to manage your events.</p>
|
||||
|
||||
<form class="card space-y-4" @submit.prevent="submit">
|
||||
<div>
|
||||
<label class="label">Email</label>
|
||||
<input v-model="email" type="email" class="input" autocomplete="email" required />
|
||||
</div>
|
||||
<div>
|
||||
<label class="label">Password</label>
|
||||
<input v-model="password" type="password" class="input" autocomplete="current-password" required />
|
||||
<div class="mt-1 text-right text-xs">
|
||||
<NuxtLink to="/forgot-password" class="text-zinc-400 hover:text-zinc-200">Forgot password?</NuxtLink>
|
||||
</div>
|
||||
</div>
|
||||
<button class="btn-primary w-full" :disabled="submitting || !email || !password">
|
||||
{{ submitting ? 'Signing in…' : 'Sign in' }}
|
||||
</button>
|
||||
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
|
||||
</form>
|
||||
|
||||
<p class="mt-6 text-center text-sm text-zinc-400">
|
||||
Don't have an account?
|
||||
<NuxtLink to="/signup" class="text-brand-400 hover:text-brand-300">Sign up</NuxtLink>
|
||||
</p>
|
||||
</section>
|
||||
</template>
|
||||
@@ -0,0 +1,51 @@
|
||||
<script setup lang="ts">
|
||||
useHead({ title: 'Privacy policy · GuestGuard' })
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<article class="prose prose-invert mx-auto max-w-2xl py-8 text-zinc-300">
|
||||
<h1 class="text-2xl font-semibold text-zinc-100">Privacy policy</h1>
|
||||
<p class="text-sm text-zinc-500">Last updated: <strong>placeholder — pending legal review</strong></p>
|
||||
|
||||
<p>
|
||||
This page is a placeholder while we have proper privacy copy
|
||||
reviewed by a lawyer. The substance below reflects how the
|
||||
product actually handles data today — the final wording will
|
||||
replace this page before public launch.
|
||||
</p>
|
||||
|
||||
<h2 class="mt-6 text-lg font-semibold text-zinc-100">What we collect</h2>
|
||||
<ul class="list-disc pl-6">
|
||||
<li><strong>Host account</strong>: email, name, hashed password.</li>
|
||||
<li><strong>Guest list</strong>: names, emails, phone numbers — only what
|
||||
the host enters or imports.</li>
|
||||
<li><strong>RSVP responses</strong>: the answer the guest sends back.</li>
|
||||
<li><strong>Access logs</strong>: IP, device fingerprint, and a fraud
|
||||
risk score, for each invitation-link open. Used to flag suspicious
|
||||
access (e.g. someone forwarded the link).</li>
|
||||
<li><strong>Billing</strong>: handled by Stripe — we store only the
|
||||
Stripe customer + subscription IDs locally, never card details.</li>
|
||||
</ul>
|
||||
|
||||
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Your rights</h2>
|
||||
<ul class="list-disc pl-6">
|
||||
<li><strong>Export</strong>: download a full JSON dump of your data via
|
||||
Settings → "Export my data".</li>
|
||||
<li><strong>Delete</strong>: delete your account via Settings → "Delete
|
||||
account". Your row is soft-deleted immediately and hard-deleted 30
|
||||
days later (kept briefly in case of accidental clicks).</li>
|
||||
<li><strong>Question</strong>: email
|
||||
<a href="mailto:privacy@gg.k4scloud.com" class="text-brand-400">privacy@gg.k4scloud.com</a>.</li>
|
||||
</ul>
|
||||
|
||||
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Sub-processors</h2>
|
||||
<p>We use these third parties to run the service:</p>
|
||||
<ul class="list-disc pl-6">
|
||||
<li><strong>Stripe</strong> — billing</li>
|
||||
<li><strong>Resend</strong> — transactional email delivery</li>
|
||||
<li><strong>AWS S3</strong> — encrypted database backups</li>
|
||||
</ul>
|
||||
|
||||
<NuxtLink to="/" class="mt-8 inline-block text-sm text-brand-400 hover:text-brand-300">← Back home</NuxtLink>
|
||||
</article>
|
||||
</template>
|
||||
@@ -0,0 +1,73 @@
|
||||
<script setup lang="ts">
|
||||
const auth = useAuth()
|
||||
const route = useRoute()
|
||||
|
||||
const password = ref('')
|
||||
const confirm = ref('')
|
||||
const submitting = ref(false)
|
||||
const done = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
const token = computed(() => String(route.params.token || ''))
|
||||
|
||||
async function submit() {
|
||||
error.value = null
|
||||
if (password.value !== confirm.value) {
|
||||
error.value = 'Passwords do not match.'
|
||||
return
|
||||
}
|
||||
submitting.value = true
|
||||
try {
|
||||
await auth.resetPassword(token.value, password.value)
|
||||
done.value = true
|
||||
} catch (e: any) {
|
||||
error.value = e?.data?.error || e?.message || 'Reset failed'
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="mx-auto max-w-md py-12">
|
||||
<h1 class="mb-6 text-2xl font-semibold">Choose a new password</h1>
|
||||
|
||||
<div v-if="done" class="card text-sm">
|
||||
<p class="mb-2 font-medium text-brand-300">Password updated.</p>
|
||||
<p class="mb-4 text-zinc-400">All previous sessions have been signed out. Sign in to continue.</p>
|
||||
<NuxtLink to="/login" class="btn-primary w-full">Sign in</NuxtLink>
|
||||
</div>
|
||||
|
||||
<form v-else class="card space-y-4" @submit.prevent="submit">
|
||||
<div>
|
||||
<label class="label">New password</label>
|
||||
<input
|
||||
v-model="password"
|
||||
type="password"
|
||||
class="input"
|
||||
autocomplete="new-password"
|
||||
minlength="8"
|
||||
maxlength="72"
|
||||
required
|
||||
/>
|
||||
<p class="mt-1 text-xs text-zinc-500">At least 8 characters.</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="label">Confirm password</label>
|
||||
<input
|
||||
v-model="confirm"
|
||||
type="password"
|
||||
class="input"
|
||||
autocomplete="new-password"
|
||||
minlength="8"
|
||||
maxlength="72"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<button class="btn-primary w-full" :disabled="submitting || password.length < 8 || !confirm">
|
||||
{{ submitting ? 'Updating…' : 'Update password' }}
|
||||
</button>
|
||||
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
|
||||
</form>
|
||||
</section>
|
||||
</template>
|
||||
@@ -0,0 +1,90 @@
|
||||
<script setup lang="ts">
|
||||
const auth = useAuth()
|
||||
|
||||
const email = ref('')
|
||||
const name = ref('')
|
||||
const password = ref('')
|
||||
const acceptTerms = ref(false)
|
||||
const submitting = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const sent = ref(false)
|
||||
|
||||
async function submit() {
|
||||
error.value = null
|
||||
submitting.value = true
|
||||
try {
|
||||
await auth.signup(email.value, name.value, password.value, acceptTerms.value)
|
||||
sent.value = true
|
||||
} catch (e: any) {
|
||||
error.value = useErrMessage(e, 'Sign-up failed')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="mx-auto max-w-md py-12">
|
||||
<h1 class="mb-2 text-2xl font-semibold">Create your account</h1>
|
||||
<p class="mb-6 text-sm text-zinc-400">Start managing your event guest lists in minutes.</p>
|
||||
|
||||
<div v-if="sent" class="card text-sm">
|
||||
<p class="mb-2 font-medium text-brand-300">Check your inbox.</p>
|
||||
<p class="text-zinc-400">
|
||||
If <span class="text-zinc-200">{{ email }}</span> is reachable, we've sent a verification link.
|
||||
Click it to finish setting up your account.
|
||||
</p>
|
||||
<NuxtLink to="/login" class="btn-ghost mt-4 w-full">Back to sign in</NuxtLink>
|
||||
</div>
|
||||
|
||||
<form v-else class="card space-y-4" @submit.prevent="submit">
|
||||
<div>
|
||||
<label class="label">Name</label>
|
||||
<input v-model="name" type="text" class="input" autocomplete="name" required />
|
||||
</div>
|
||||
<div>
|
||||
<label class="label">Email</label>
|
||||
<input v-model="email" type="email" class="input" autocomplete="email" required />
|
||||
</div>
|
||||
<div>
|
||||
<label class="label">Password</label>
|
||||
<input
|
||||
v-model="password"
|
||||
type="password"
|
||||
class="input"
|
||||
autocomplete="new-password"
|
||||
minlength="8"
|
||||
maxlength="72"
|
||||
required
|
||||
/>
|
||||
<p class="mt-1 text-xs text-zinc-500">At least 8 characters.</p>
|
||||
</div>
|
||||
<label class="flex cursor-pointer items-start gap-2 text-xs text-zinc-400">
|
||||
<input
|
||||
v-model="acceptTerms"
|
||||
type="checkbox"
|
||||
class="mt-0.5 h-4 w-4 cursor-pointer accent-brand-500"
|
||||
required
|
||||
/>
|
||||
<span>
|
||||
I agree to GuestGuard's
|
||||
<NuxtLink to="/terms" target="_blank" class="text-brand-400 hover:text-brand-300">Terms of Service</NuxtLink>
|
||||
and
|
||||
<NuxtLink to="/privacy" target="_blank" class="text-brand-400 hover:text-brand-300">Privacy Policy</NuxtLink>.
|
||||
</span>
|
||||
</label>
|
||||
<button
|
||||
class="btn-primary w-full"
|
||||
:disabled="submitting || !email || !name || password.length < 8 || !acceptTerms"
|
||||
>
|
||||
{{ submitting ? 'Creating…' : 'Create account' }}
|
||||
</button>
|
||||
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
|
||||
</form>
|
||||
|
||||
<p class="mt-6 text-center text-sm text-zinc-400">
|
||||
Already have an account?
|
||||
<NuxtLink to="/login" class="text-brand-400 hover:text-brand-300">Sign in</NuxtLink>
|
||||
</p>
|
||||
</section>
|
||||
</template>
|
||||
@@ -0,0 +1,50 @@
|
||||
<script setup lang="ts">
|
||||
useHead({ title: 'Terms of service · GuestGuard' })
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<article class="prose prose-invert mx-auto max-w-2xl py-8 text-zinc-300">
|
||||
<h1 class="text-2xl font-semibold text-zinc-100">Terms of service</h1>
|
||||
<p class="text-sm text-zinc-500">Last updated: <strong>placeholder — pending legal review</strong></p>
|
||||
|
||||
<p>
|
||||
Placeholder copy while a lawyer drafts the real document. The points
|
||||
below capture the substance of how we'd like the relationship to
|
||||
work — the binding version will replace this page before public
|
||||
launch.
|
||||
</p>
|
||||
|
||||
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Summary</h2>
|
||||
<ul class="list-disc pl-6">
|
||||
<li>You're responsible for your guest list — make sure you have
|
||||
permission to message the people on it.</li>
|
||||
<li>We're responsible for keeping the service running, your data
|
||||
backed up, and our handling of it transparent (see the privacy
|
||||
page).</li>
|
||||
<li>Either of us can end the relationship — you by deleting your
|
||||
account, us with reasonable notice if we have to wind down.</li>
|
||||
<li>The service is provided "as-is" — we don't promise zero downtime
|
||||
or that no spam-filter on earth will misfilter your invitations.
|
||||
We try hard, but we're not insurers.</li>
|
||||
</ul>
|
||||
|
||||
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Acceptable use</h2>
|
||||
<ul class="list-disc pl-6">
|
||||
<li>Don't use GuestGuard to send unsolicited bulk mail (spam).</li>
|
||||
<li>Don't use it to harass, deceive, or impersonate.</li>
|
||||
<li>Don't poke at the security of the service — if you find a bug,
|
||||
please tell us at
|
||||
<a href="mailto:security@gg.k4scloud.com" class="text-brand-400">security@gg.k4scloud.com</a>.</li>
|
||||
</ul>
|
||||
|
||||
<h2 class="mt-6 text-lg font-semibold text-zinc-100">Payment</h2>
|
||||
<p>
|
||||
Subscriptions auto-renew until cancelled. Cancel any time from
|
||||
Billing → Manage subscription. Refunds: we'll consider them case by
|
||||
case — email
|
||||
<a href="mailto:support@gg.k4scloud.com" class="text-brand-400">support@gg.k4scloud.com</a>.
|
||||
</p>
|
||||
|
||||
<NuxtLink to="/" class="mt-8 inline-block text-sm text-brand-400 hover:text-brand-300">← Back home</NuxtLink>
|
||||
</article>
|
||||
</template>
|
||||
@@ -0,0 +1,69 @@
|
||||
<script setup lang="ts">
|
||||
const route = useRoute()
|
||||
const config = useRuntimeConfig()
|
||||
const apiBase = config.public.apiBase as string
|
||||
|
||||
const token = computed(() => String(route.params.token || ''))
|
||||
const status = ref<'loading' | 'ready' | 'done' | 'error'>('loading')
|
||||
const email = ref('')
|
||||
const error = ref<string | null>(null)
|
||||
const submitting = ref(false)
|
||||
|
||||
onMounted(async () => {
|
||||
try {
|
||||
const r = await $fetch<{ email: string }>(`/unsubscribe/${token.value}`, { baseURL: apiBase })
|
||||
email.value = r.email
|
||||
status.value = 'ready'
|
||||
} catch (e: any) {
|
||||
status.value = 'error'
|
||||
error.value = e?.data?.error || 'This link is invalid or has expired.'
|
||||
}
|
||||
})
|
||||
|
||||
async function confirm() {
|
||||
submitting.value = true
|
||||
error.value = null
|
||||
try {
|
||||
await $fetch(`/unsubscribe/${token.value}`, { baseURL: apiBase, method: 'POST' })
|
||||
status.value = 'done'
|
||||
} catch (e: any) {
|
||||
error.value = e?.data?.error || 'Something went wrong — please try again.'
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="mx-auto max-w-md py-12">
|
||||
<h1 class="mb-6 text-2xl font-semibold">Unsubscribe</h1>
|
||||
|
||||
<div v-if="status === 'loading'" class="card text-sm text-zinc-400">Loading…</div>
|
||||
|
||||
<div v-else-if="status === 'ready'" class="card space-y-4 text-sm">
|
||||
<p>You'll stop receiving GuestGuard emails sent to:</p>
|
||||
<p class="font-mono text-zinc-200">{{ email }}</p>
|
||||
<p class="text-zinc-400">
|
||||
This includes RSVPs, reminders, and host messages. Account-related emails
|
||||
(security, password resets) will still reach you.
|
||||
</p>
|
||||
<button class="btn-primary w-full" :disabled="submitting" @click="confirm">
|
||||
{{ submitting ? 'Unsubscribing…' : 'Unsubscribe me' }}
|
||||
</button>
|
||||
<p v-if="error" class="text-sm text-red-400">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<div v-else-if="status === 'done'" class="card text-sm">
|
||||
<p class="mb-2 font-medium text-brand-300">Done.</p>
|
||||
<p class="text-zinc-400">
|
||||
We've added <span class="text-zinc-200">{{ email }}</span> to our suppression list.
|
||||
Future emails to that address will be silently dropped.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-else class="card text-sm">
|
||||
<p class="mb-2 font-medium text-red-400">Can't unsubscribe</p>
|
||||
<p class="text-zinc-400">{{ error }}</p>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
@@ -0,0 +1,44 @@
|
||||
<script setup lang="ts">
|
||||
const auth = useAuth()
|
||||
const route = useRoute()
|
||||
|
||||
type Status = 'pending' | 'success' | 'error'
|
||||
const status = ref<Status>('pending')
|
||||
const error = ref<string | null>(null)
|
||||
|
||||
onMounted(async () => {
|
||||
const token = route.query.token
|
||||
if (typeof token !== 'string' || !token) {
|
||||
status.value = 'error'
|
||||
error.value = 'Missing verification token.'
|
||||
return
|
||||
}
|
||||
try {
|
||||
await auth.verifyEmail(token)
|
||||
status.value = 'success'
|
||||
} catch (e: any) {
|
||||
status.value = 'error'
|
||||
error.value = e?.data?.error || e?.message || 'Verification failed.'
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section class="mx-auto max-w-md py-12">
|
||||
<h1 class="mb-6 text-2xl font-semibold">Verify your email</h1>
|
||||
|
||||
<div class="card text-sm">
|
||||
<p v-if="status === 'pending'" class="text-zinc-400">Verifying…</p>
|
||||
<template v-else-if="status === 'success'">
|
||||
<p class="mb-3 font-medium text-brand-300">Email verified.</p>
|
||||
<p class="mb-4 text-zinc-400">You can now sign in to your account.</p>
|
||||
<NuxtLink to="/login" class="btn-primary w-full">Sign in</NuxtLink>
|
||||
</template>
|
||||
<template v-else>
|
||||
<p class="mb-3 font-medium text-red-400">We couldn't verify that link.</p>
|
||||
<p class="mb-4 text-zinc-400">{{ error }}</p>
|
||||
<NuxtLink to="/login" class="btn-ghost w-full">Back to sign in</NuxtLink>
|
||||
</template>
|
||||
</div>
|
||||
</section>
|
||||
</template>
|
||||
@@ -4,11 +4,13 @@ go 1.26.2
|
||||
|
||||
require (
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.9.2
|
||||
github.com/nats-io/nats.go v1.52.0
|
||||
github.com/testcontainers/testcontainers-go v0.42.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
google.golang.org/grpc v1.81.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
)
|
||||
@@ -17,6 +19,22 @@ require (
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/alicebob/miniredis/v2 v2.38.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
|
||||
github.com/aws/smithy-go v1.25.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
@@ -33,6 +51,7 @@ require (
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
@@ -52,24 +71,29 @@ require (
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/redis/go-redis/v9 v9.19.0 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.26.3 // indirect
|
||||
github.com/sirupsen/logrus v1.9.4 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/stripe/stripe-go/v82 v82.5.1 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.16 // indirect
|
||||
github.com/tklauser/numcpus v0.11.0 // indirect
|
||||
github.com/twilio/twilio-go v1.30.9 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -6,6 +6,38 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw=
|
||||
github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4 h1:X/PtmuX/EwPivJ9lHCf3Auo8AktdNc4a9ury4zmGPC4=
|
||||
github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.4/go.mod h1:l5cTwZSX9kzxDHz9IpgZC0XIJ/cc43JL6hZzCd0iTwI=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
||||
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
||||
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -44,6 +76,10 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
@@ -101,10 +137,14 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k=
|
||||
github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/shirou/gopsutil/v4 v4.26.3 h1:2ESdQt90yU3oXF/CdOlRCJxrP+Am1aBYubTMTfxJ1qc=
|
||||
@@ -118,6 +158,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/stripe/stripe-go/v82 v82.5.1 h1:05q6ZDKoe8PLMpQV072obF74HCgP4XJeJYoNuRSX2+8=
|
||||
github.com/stripe/stripe-go/v82 v82.5.1/go.mod h1:majCQX6AfObAvJiHraPi/5udwHi4ojRvJnnxckvHrX8=
|
||||
github.com/testcontainers/testcontainers-go v0.42.0 h1:He3IhTzTZOygSXLJPMX7n44XtK+qhjat1nI9cneBbUY=
|
||||
github.com/testcontainers/testcontainers-go v0.42.0/go.mod h1:vZjdY1YmUA1qEForxOIOazfsrdyORJAbhi0bp8plN30=
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 h1:GCbb1ndrF7OTDiIvxXyItaDab4qkzTFJ48LKFdM7EIo=
|
||||
@@ -126,6 +168,11 @@ github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI
|
||||
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
|
||||
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
|
||||
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
|
||||
github.com/twilio/twilio-go v1.30.9 h1:4W4GEV2q0sLQ9xsr1N/97JQlt0c82hZ0ij4qTErstv8=
|
||||
github.com/twilio/twilio-go v1.30.9/go.mod h1:QbitvbvtkV77Jn4BABAKVmxabYSjMyQG4tHey9gfPqg=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
@@ -142,22 +189,48 @@ go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfC
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
@@ -42,16 +40,15 @@ type activityItem struct {
|
||||
// for an event, sorted newest first. Frontends use this on dashboard mount
|
||||
// to backfill the live monitor with history.
|
||||
func (h *activityHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := h.events.Get(r.Context(), eventID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,557 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/ratelimit"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
const refreshCookieName = "gg_refresh"
|
||||
|
||||
type authHandler struct {
|
||||
logger *slog.Logger
|
||||
users *storage.UserRepo
|
||||
verifications *storage.EmailVerificationRepo
|
||||
resets *storage.PasswordResetRepo
|
||||
refreshes *storage.RefreshTokenRepo
|
||||
hasher *auth.PasswordHasher
|
||||
signer *auth.JWTSigner
|
||||
emails auth.EmailSender
|
||||
lockout *auth.LockoutTracker
|
||||
limiter *ratelimit.Limiter
|
||||
|
||||
publicBaseURL string
|
||||
emailVerificationTTL time.Duration
|
||||
passwordResetTTL time.Duration
|
||||
refreshTTL time.Duration
|
||||
cookieDomain string
|
||||
cookieSecure bool
|
||||
}
|
||||
|
||||
type authHandlerDeps struct {
|
||||
Logger *slog.Logger
|
||||
Users *storage.UserRepo
|
||||
Verifications *storage.EmailVerificationRepo
|
||||
Resets *storage.PasswordResetRepo
|
||||
Refreshes *storage.RefreshTokenRepo
|
||||
Hasher *auth.PasswordHasher
|
||||
Signer *auth.JWTSigner
|
||||
Emails auth.EmailSender
|
||||
Lockout *auth.LockoutTracker
|
||||
Limiter *ratelimit.Limiter
|
||||
|
||||
PublicBaseURL string
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
RefreshTTL time.Duration
|
||||
CookieDomain string
|
||||
CookieSecure bool
|
||||
}
|
||||
|
||||
func newAuthHandler(d authHandlerDeps) *authHandler {
|
||||
return &authHandler{
|
||||
logger: d.Logger,
|
||||
users: d.Users,
|
||||
verifications: d.Verifications,
|
||||
resets: d.Resets,
|
||||
refreshes: d.Refreshes,
|
||||
hasher: d.Hasher,
|
||||
signer: d.Signer,
|
||||
emails: d.Emails,
|
||||
lockout: d.Lockout,
|
||||
limiter: d.Limiter,
|
||||
publicBaseURL: strings.TrimRight(d.PublicBaseURL, "/"),
|
||||
emailVerificationTTL: d.EmailVerificationTTL,
|
||||
passwordResetTTL: d.PasswordResetTTL,
|
||||
refreshTTL: d.RefreshTTL,
|
||||
cookieDomain: d.CookieDomain,
|
||||
cookieSecure: d.CookieSecure,
|
||||
}
|
||||
}
|
||||
|
||||
// --- request/response types ---
|
||||
|
||||
type signupRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
AcceptTerms bool `json:"accept_terms"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type verifyEmailRequest struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type forgotPasswordRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type resetPasswordRequest struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type authSuccess struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
User *domain.User `json:"user"`
|
||||
}
|
||||
|
||||
// --- handlers ---
|
||||
|
||||
// POST /auth/signup
|
||||
func (h *authHandler) signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req signupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "email is invalid")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
hash, err := h.hasher.Hash(req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("hash password", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.Create(r.Context(), storage.CreateUserParams{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
PasswordHash: hash,
|
||||
AcceptTerms: req.AcceptTerms,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEmailTaken) {
|
||||
// Don't leak which addresses are registered. Still return 201 and
|
||||
// trigger a "if-you-already-have-an-account" email asynchronously
|
||||
// (skipped for the stub). On real auth this should send a "you
|
||||
// tried to sign up again, here's a reset link" email.
|
||||
h.logger.Info("signup attempted with existing email", "email", req.Email)
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("create user", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.sendVerificationEmail(r.Context(), u); err != nil {
|
||||
h.logger.Error("send verification email", "err", err, "user_id", u.ID)
|
||||
// Don't fail the signup — user can request a resend.
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
|
||||
}
|
||||
|
||||
// POST /auth/login
|
||||
func (h *authHandler) login(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Email == "" || req.Password == "" {
|
||||
writeError(w, http.StatusBadRequest, "email and password required")
|
||||
return
|
||||
}
|
||||
|
||||
// Per-(IP + email) sliding-window — 10 per 5 minutes per the plan.
|
||||
if !h.checkRate(w, r, "login", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
|
||||
10, 5*time.Minute) {
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.GetByEmail(r.Context(), req.Email)
|
||||
if err != nil || u.PasswordHash == "" {
|
||||
_, _ = h.lockout.RecordFailure(r.Context(), req.Email, nil)
|
||||
writeError(w, http.StatusUnauthorized, "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
// If the account is already locked, reject before doing a bcrypt compare.
|
||||
locked, _ := h.lockout.IsLocked(r.Context(), u.ID)
|
||||
if locked {
|
||||
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.hasher.Verify(u.PasswordHash, req.Password); err != nil {
|
||||
locked, _ := h.lockout.RecordFailure(r.Context(), req.Email, &u.ID)
|
||||
if locked {
|
||||
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid email or password")
|
||||
return
|
||||
}
|
||||
if !u.EmailVerified {
|
||||
writeError(w, http.StatusForbidden, "email not verified")
|
||||
return
|
||||
}
|
||||
|
||||
h.lockout.ClearOnSuccess(r.Context(), req.Email)
|
||||
if err := h.issueSession(w, r, u); err != nil {
|
||||
h.logger.Error("issue session", "err", err, "user_id", u.ID)
|
||||
writeError(w, http.StatusInternalServerError, "failed to start session")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// checkRate consults the limiter (when one is configured) and writes a 429
|
||||
// response if the budget is exhausted. Returns false if the caller should
|
||||
// stop handling the request.
|
||||
func (h *authHandler) checkRate(w http.ResponseWriter, r *http.Request, name, key string, limit int, window time.Duration) bool {
|
||||
if h.limiter == nil || key == "" {
|
||||
return true
|
||||
}
|
||||
res, err := h.limiter.Allow(r.Context(), name, key, limit, window)
|
||||
if err != nil {
|
||||
h.logger.Warn("ratelimit error (failing open)", "rule", name, "err", err)
|
||||
return true
|
||||
}
|
||||
if !res.Allowed {
|
||||
retry := int(res.RetryAfter.Round(time.Second).Seconds())
|
||||
if retry < 1 {
|
||||
retry = 1
|
||||
}
|
||||
w.Header().Set("Retry-After", strconv.Itoa(retry))
|
||||
writeJSON(w, http.StatusTooManyRequests, map[string]any{
|
||||
"error": "rate limit exceeded",
|
||||
"retry_after": retry,
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// POST /auth/refresh
|
||||
func (h *authHandler) refresh(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(refreshCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing refresh token")
|
||||
return
|
||||
}
|
||||
oldHash := auth.HashOpaque(cookie.Value)
|
||||
|
||||
existing, err := h.refreshes.Get(r.Context(), oldHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrAuthTokenNotFound) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "invalid refresh token")
|
||||
return
|
||||
}
|
||||
h.logger.Error("lookup refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
if existing.RevokedAt != nil {
|
||||
// Replay of a revoked token. Revoke the family.
|
||||
_ = h.refreshes.RevokeAllForUser(r.Context(), existing.UserID)
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token reused")
|
||||
return
|
||||
}
|
||||
if time.Now().After(existing.ExpiresAt) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token expired")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.GetByID(r.Context(), existing.UserID)
|
||||
if err != nil {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
newRaw, newHash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
h.logger.Error("mint refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
exp := time.Now().Add(h.refreshTTL)
|
||||
if err := h.refreshes.Rotate(r.Context(), oldHash, storage.CreateRefreshTokenParams{
|
||||
Hash: newHash,
|
||||
UserID: u.ID,
|
||||
ExpiresAt: exp,
|
||||
UserAgent: r.UserAgent(),
|
||||
IPAddress: clientIP(r),
|
||||
}); err != nil {
|
||||
if errors.Is(err, domain.ErrRefreshTokenRevoked) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token reused")
|
||||
return
|
||||
}
|
||||
h.logger.Error("rotate refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
|
||||
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
|
||||
if err != nil {
|
||||
h.logger.Error("sign access", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
h.setRefreshCookie(w, newRaw, exp)
|
||||
writeJSON(w, http.StatusOK, authSuccess{
|
||||
AccessToken: access,
|
||||
ExpiresAt: accessExp,
|
||||
User: u,
|
||||
})
|
||||
}
|
||||
|
||||
// POST /auth/logout
|
||||
func (h *authHandler) logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(refreshCookieName)
|
||||
if err == nil && cookie.Value != "" {
|
||||
_ = h.refreshes.Revoke(r.Context(), auth.HashOpaque(cookie.Value))
|
||||
}
|
||||
h.clearRefreshCookie(w)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /auth/verify-email
|
||||
func (h *authHandler) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
||||
var req verifyEmailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token required")
|
||||
return
|
||||
}
|
||||
uid, err := h.verifications.Consume(r.Context(), auth.HashOpaque(req.Token))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrAuthTokenNotFound):
|
||||
writeError(w, http.StatusBadRequest, "invalid token")
|
||||
case errors.Is(err, domain.ErrAuthTokenConsumed):
|
||||
writeError(w, http.StatusBadRequest, "token already used")
|
||||
case errors.Is(err, domain.ErrAuthTokenExpired):
|
||||
writeError(w, http.StatusBadRequest, "token expired")
|
||||
default:
|
||||
h.logger.Error("consume verification", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "verification failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.users.MarkEmailVerified(r.Context(), uid); err != nil {
|
||||
h.logger.Error("mark verified", "err", err, "user_id", uid)
|
||||
writeError(w, http.StatusInternalServerError, "verification failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "verified"})
|
||||
}
|
||||
|
||||
// POST /auth/forgot-password
|
||||
func (h *authHandler) forgotPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req forgotPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if !h.checkRate(w, r, "forgot_password", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
|
||||
3, time.Hour) {
|
||||
return
|
||||
}
|
||||
// Always respond 202 to avoid leaking whether the email exists.
|
||||
defer func() { writeJSON(w, http.StatusAccepted, map[string]string{"status": "if_known_email_sent"}) }()
|
||||
|
||||
u, err := h.users.GetByEmail(r.Context(), req.Email)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
h.logger.Error("mint reset", "err", err)
|
||||
return
|
||||
}
|
||||
exp := time.Now().Add(h.passwordResetTTL)
|
||||
if err := h.resets.Create(r.Context(), u.ID, hash, exp); err != nil {
|
||||
h.logger.Error("persist reset", "err", err)
|
||||
return
|
||||
}
|
||||
link := h.publicBaseURL + "/reset-password/" + url.PathEscape(raw)
|
||||
if err := h.emails.SendPasswordReset(r.Context(), u.Email, u.Name, link); err != nil {
|
||||
h.logger.Error("send reset email", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// POST /auth/reset-password
|
||||
func (h *authHandler) resetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req resetPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token and new_password required")
|
||||
return
|
||||
}
|
||||
newHash, err := h.hasher.Hash(req.NewPassword)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("hash password", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
return
|
||||
}
|
||||
uid, err := h.resets.Consume(r.Context(), auth.HashOpaque(req.Token))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrAuthTokenNotFound):
|
||||
writeError(w, http.StatusBadRequest, "invalid token")
|
||||
case errors.Is(err, domain.ErrAuthTokenConsumed):
|
||||
writeError(w, http.StatusBadRequest, "token already used")
|
||||
case errors.Is(err, domain.ErrAuthTokenExpired):
|
||||
writeError(w, http.StatusBadRequest, "token expired")
|
||||
default:
|
||||
h.logger.Error("consume reset", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.users.UpdatePasswordHash(r.Context(), uid, newHash); err != nil {
|
||||
h.logger.Error("update password", "err", err, "user_id", uid)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
return
|
||||
}
|
||||
// Invalidate all existing sessions.
|
||||
_ = h.refreshes.RevokeAllForUser(r.Context(), uid)
|
||||
// Resetting the password is the canonical "unlock" path for the
|
||||
// account lockout that triggers after repeated bad-credential attempts.
|
||||
if u, err := h.users.GetByID(r.Context(), uid); err == nil {
|
||||
_ = h.lockout.ClearForUser(r.Context(), uid, u.Email)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "password_reset"})
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func (h *authHandler) sendVerificationEmail(ctx context.Context, u *domain.User) error {
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifications.Create(ctx, u.ID, hash, time.Now().Add(h.emailVerificationTTL)); err != nil {
|
||||
return err
|
||||
}
|
||||
link := h.publicBaseURL + "/verify-email?token=" + url.QueryEscape(raw)
|
||||
return h.emails.SendVerification(ctx, u.Email, u.Name, link)
|
||||
}
|
||||
|
||||
func (h *authHandler) issueSession(w http.ResponseWriter, r *http.Request, u *domain.User) error {
|
||||
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
refreshExp := time.Now().Add(h.refreshTTL)
|
||||
if err := h.refreshes.Create(r.Context(), storage.CreateRefreshTokenParams{
|
||||
Hash: hash,
|
||||
UserID: u.ID,
|
||||
ExpiresAt: refreshExp,
|
||||
UserAgent: r.UserAgent(),
|
||||
IPAddress: clientIP(r),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
h.setRefreshCookie(w, raw, refreshExp)
|
||||
writeJSON(w, http.StatusOK, authSuccess{
|
||||
AccessToken: access,
|
||||
ExpiresAt: accessExp,
|
||||
User: u,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *authHandler) setRefreshCookie(w http.ResponseWriter, value string, expires time.Time) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: refreshCookieName,
|
||||
Value: value,
|
||||
Path: "/auth",
|
||||
Domain: h.cookieDomain,
|
||||
Expires: expires,
|
||||
MaxAge: int(time.Until(expires).Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: h.cookieSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *authHandler) clearRefreshCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: refreshCookieName,
|
||||
Value: "",
|
||||
Path: "/auth",
|
||||
Domain: h.cookieDomain,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: h.cookieSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// --- requireAuth middleware ---
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const userIDCtxKey ctxKey = iota
|
||||
|
||||
func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(userIDCtxKey).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func requireAuth(signer *auth.JWTSigner) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(h, "Bearer ") {
|
||||
writeError(w, http.StatusUnauthorized, "missing bearer token")
|
||||
return
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(h, "Bearer "))
|
||||
claims, err := signer.Parse(raw)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrExpiredJWT) {
|
||||
writeError(w, http.StatusUnauthorized, "token expired")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid token")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userIDCtxKey, claims.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// hostFromContext returns the authed user's id, or writes 401 and returns
|
||||
// false. Used by host-facing handlers as the first line in the function.
|
||||
func hostFromContext(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthenticated")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
return uid, true
|
||||
}
|
||||
|
||||
// requireEventOwner fetches the event and confirms the authed user owns it.
|
||||
// On mismatch (or missing event) it returns 404 — never 403 — so a cross-
|
||||
// tenant probe cannot tell the difference between "event doesn't exist" and
|
||||
// "exists but belongs to someone else".
|
||||
func requireEventOwner(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
events *storage.EventRepo,
|
||||
eventID, hostID uuid.UUID,
|
||||
) (*domain.Event, bool) {
|
||||
ev, err := events.GetForHost(r.Context(), eventID, hostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return nil, false
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
return nil, false
|
||||
}
|
||||
return ev, true
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type billingHandler struct {
|
||||
logger *slog.Logger
|
||||
stripe *billing.Client
|
||||
users *storage.UserRepo
|
||||
subscriptions *storage.SubscriptionRepo
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
type checkoutSessionRequest struct {
|
||||
Tier string `json:"tier"`
|
||||
}
|
||||
|
||||
type checkoutSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// POST /billing/checkout-session — returns the Stripe Checkout URL the
|
||||
// frontend redirects the host to. Mints a Stripe customer on first use
|
||||
// and persists it so repeat calls reuse the same customer.
|
||||
func (h *billingHandler) checkoutSession(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.stripeEnabled(w) {
|
||||
return
|
||||
}
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req checkoutSessionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
tier := billing.Tier(strings.ToLower(req.Tier))
|
||||
if tier != billing.TierPro && tier != billing.TierBusiness {
|
||||
writeError(w, http.StatusBadRequest, "tier must be 'pro' or 'business'")
|
||||
return
|
||||
}
|
||||
price, err := h.stripe.PriceFor(tier)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "this tier is not configured yet — contact support")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.users.GetByID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
|
||||
existingCustomerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load billing record")
|
||||
return
|
||||
}
|
||||
customerID, err := h.stripe.CreateOrGetCustomer(hostID.String(), user.Email, user.Name, existingCustomerID)
|
||||
if err != nil {
|
||||
h.logger.Error("stripe customer", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe customer error")
|
||||
return
|
||||
}
|
||||
if existingCustomerID == "" {
|
||||
// First time — write a placeholder row so the customer id sticks.
|
||||
if _, err := h.subscriptions.Upsert(r.Context(), storage.UpsertParams{
|
||||
UserID: hostID,
|
||||
StripeCustomerID: customerID,
|
||||
}); err != nil {
|
||||
h.logger.Error("upsert sub placeholder", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
base := strings.TrimRight(h.publicBaseURL, "/")
|
||||
url, err := h.stripe.CreateCheckoutSession(billing.CheckoutSessionParams{
|
||||
CustomerID: customerID,
|
||||
PriceID: price,
|
||||
SuccessURL: base + "/dashboard?billing=success",
|
||||
CancelURL: base + "/dashboard?billing=cancelled",
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("stripe checkout session", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe checkout error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, checkoutSessionResponse{URL: url})
|
||||
}
|
||||
|
||||
type portalSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// POST /billing/portal — returns the customer portal URL so the user
|
||||
// can manage their payment method, view invoices, or cancel. 404 when
|
||||
// the user has no Stripe customer yet (they're still on free).
|
||||
func (h *billingHandler) portalSession(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.stripeEnabled(w) {
|
||||
return
|
||||
}
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
customerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load billing record")
|
||||
return
|
||||
}
|
||||
if customerID == "" {
|
||||
writeError(w, http.StatusNotFound, "no billing account yet — subscribe first")
|
||||
return
|
||||
}
|
||||
url, err := h.stripe.CreatePortalSession(customerID, strings.TrimRight(h.publicBaseURL, "/")+"/dashboard")
|
||||
if err != nil {
|
||||
h.logger.Error("stripe portal", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe portal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, portalSessionResponse{URL: url})
|
||||
}
|
||||
|
||||
type subscriptionStatusResponse struct {
|
||||
Tier string `json:"tier"`
|
||||
Status string `json:"status"`
|
||||
CurrentPeriodEnd string `json:"current_period_end,omitempty"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
Limits struct {
|
||||
EventsPerMonth int `json:"events_per_month"`
|
||||
GuestsPerEvent int `json:"guests_per_event"`
|
||||
} `json:"limits"`
|
||||
Usage struct {
|
||||
EventsThisMonth int `json:"events_this_month"`
|
||||
} `json:"usage"`
|
||||
PortalAvailable bool `json:"portal_available"`
|
||||
}
|
||||
|
||||
// GET /billing/status — returns the host's current tier + limits +
|
||||
// usage. The frontend uses this to render the billing page and the
|
||||
// 402-modal copy ("you used X of Y events this month").
|
||||
func (h *billingHandler) status(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tier := billing.TierFree
|
||||
status := "active"
|
||||
var periodEnd string
|
||||
cancelAtPeriodEnd := false
|
||||
portalAvailable := false
|
||||
|
||||
if h.subscriptions != nil {
|
||||
sub, err := h.subscriptions.GetActiveByUser(r.Context(), hostID)
|
||||
switch {
|
||||
case err == nil:
|
||||
tier = billing.Tier(sub.Tier)
|
||||
status = sub.Status
|
||||
cancelAtPeriodEnd = sub.CancelAtPeriodEnd
|
||||
if sub.CurrentPeriodEnd != nil {
|
||||
periodEnd = sub.CurrentPeriodEnd.Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
portalAvailable = sub.StripeCustomerID != ""
|
||||
case errors.Is(err, storage.ErrSubscriptionNotFound):
|
||||
// Free tier — leave defaults.
|
||||
default:
|
||||
writeError(w, http.StatusInternalServerError, "failed to load subscription")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limits := billing.LimitsFor(tier)
|
||||
events, _ := h.subscriptions.CountEventsInCurrentMonth(r.Context(), hostID)
|
||||
|
||||
resp := subscriptionStatusResponse{
|
||||
Tier: string(tier),
|
||||
Status: status,
|
||||
CurrentPeriodEnd: periodEnd,
|
||||
CancelAtPeriodEnd: cancelAtPeriodEnd,
|
||||
PortalAvailable: portalAvailable,
|
||||
}
|
||||
resp.Limits.EventsPerMonth = limits.EventsPerMonth
|
||||
resp.Limits.GuestsPerEvent = limits.GuestsPerEvent
|
||||
resp.Usage.EventsThisMonth = events
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// stripeEnabled returns true if the billing client is configured, else
|
||||
// writes 503 and returns false. The /billing/status path skips this so
|
||||
// the frontend can render a "free tier" page in dev environments.
|
||||
func (h *billingHandler) stripeEnabled(w http.ResponseWriter) bool {
|
||||
if h.stripe == nil || !h.stripe.Enabled() {
|
||||
writeError(w, http.StatusServiceUnavailable, "billing is not configured on this instance")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// tierEnforcer wraps the SubscriptionRepo with policy decisions. Lives
|
||||
// here (not in storage) because the policy is HTTP-shaped: we map the
|
||||
// outcome to 402 + an upgrade URL.
|
||||
type tierEnforcer struct {
|
||||
subs *storage.SubscriptionRepo
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
func newTierEnforcer(subs *storage.SubscriptionRepo, publicBaseURL string) *tierEnforcer {
|
||||
return &tierEnforcer{subs: subs, publicBaseURL: publicBaseURL}
|
||||
}
|
||||
|
||||
// currentTier returns the host's effective tier. ErrSubscriptionNotFound
|
||||
// means "no granting subscription on file" → free. Other DB errors
|
||||
// bubble up.
|
||||
func (e *tierEnforcer) currentTier(ctx context.Context, hostID uuid.UUID) (billing.Tier, error) {
|
||||
if e == nil || e.subs == nil {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
sub, err := e.subs.GetActiveByUser(ctx, hostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, storage.ErrSubscriptionNotFound) {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if !billing.StatusGrantsAccess(sub.Status) {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
tier := billing.Tier(sub.Tier)
|
||||
if !tier.Valid() {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
return tier, nil
|
||||
}
|
||||
|
||||
// allowEventCreate verifies the host's monthly event budget. Returns
|
||||
// true when the request may proceed. On denial it writes a 402 with the
|
||||
// upgrade hint and returns false.
|
||||
func (e *tierEnforcer) allowEventCreate(w http.ResponseWriter, r *http.Request, hostID uuid.UUID) bool {
|
||||
if e == nil || e.subs == nil {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).EventsPerMonth
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountEventsInCurrentMonth(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count events")
|
||||
return false
|
||||
}
|
||||
if used >= limit {
|
||||
e.writePaymentRequired(w, "events_per_month", tier, used, limit,
|
||||
"You've reached your monthly event limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// allowGuestCreate verifies the per-event guest cap. Same shape as
|
||||
// allowEventCreate.
|
||||
func (e *tierEnforcer) allowGuestCreate(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID) bool {
|
||||
if e == nil || e.subs == nil {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).GuestsPerEvent
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count guests")
|
||||
return false
|
||||
}
|
||||
if used >= limit {
|
||||
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
|
||||
"This event has reached the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// allowGuestImport is the CSV-import variant: check the cap against
|
||||
// existing + incoming row count up-front, before we start the
|
||||
// transaction. Dedup may shrink the actual insert count later — that's
|
||||
// OK, we just stay on the safe side.
|
||||
func (e *tierEnforcer) allowGuestImport(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID, incoming int) bool {
|
||||
if e == nil || e.subs == nil || incoming == 0 {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).GuestsPerEvent
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count guests")
|
||||
return false
|
||||
}
|
||||
if used+incoming > limit {
|
||||
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
|
||||
"This import would exceed the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type paymentRequiredBody struct {
|
||||
Error string `json:"error"`
|
||||
Reason string `json:"reason"`
|
||||
Tier string `json:"tier"`
|
||||
Used int `json:"used"`
|
||||
Limit int `json:"limit"`
|
||||
UpgradeURL string `json:"upgrade_url"`
|
||||
}
|
||||
|
||||
func (e *tierEnforcer) writePaymentRequired(w http.ResponseWriter, reason string, tier billing.Tier, used, limit int, msg string) {
|
||||
body := paymentRequiredBody{
|
||||
Error: msg,
|
||||
Reason: reason,
|
||||
Tier: string(tier),
|
||||
Used: used,
|
||||
Limit: limit,
|
||||
UpgradeURL: strings.TrimRight(e.publicBaseURL, "/") + "/dashboard/billing",
|
||||
}
|
||||
writeJSON(w, http.StatusPaymentRequired, body)
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/csvimport"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
// 1 MB cap on uploads. With ~200 bytes per row that's ~5,000 guests —
|
||||
// matches the row cap in csvimport.DefaultMaxRows.
|
||||
csvMaxBytes = 1 << 20
|
||||
)
|
||||
|
||||
type csvImportHandler struct {
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type importResponse struct {
|
||||
Added int `json:"added"`
|
||||
Skipped int `json:"skipped"`
|
||||
SkippedEmails []string `json:"skipped_emails,omitempty"`
|
||||
Errors []csvimport.RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"`
|
||||
}
|
||||
|
||||
type previewResponse struct {
|
||||
Rows []csvimport.Row `json:"rows"`
|
||||
Errors []csvimport.RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/import/preview — parse + validate but don't write.
|
||||
// Used by the frontend to show a "is this what you meant?" table before commit.
|
||||
func (h *csvImportHandler) preview(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
body, ok := readCSVUpload(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
res, err := csvimport.Parse(body, csvimport.Options{})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, previewResponse{
|
||||
Rows: res.Rows,
|
||||
Errors: res.Errors,
|
||||
TotalCount: res.TotalCount,
|
||||
})
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/import — parse, validate, and commit valid rows
|
||||
// in a single transaction. Rows with row-level errors are reported back
|
||||
// but don't prevent the rest from importing.
|
||||
func (h *csvImportHandler) commit(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
body, ok := readCSVUpload(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
parsed, err := csvimport.Parse(body, csvimport.Options{})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Plan enforcement: prevent an import that, even if perfectly
|
||||
// dedup-free, would exceed the per-event guest cap. We check upfront
|
||||
// (current count + parsed-row count) and reject with 402 if it'd
|
||||
// overflow. False positives on dedup-heavy CSVs are acceptable —
|
||||
// host can dedupe and re-upload.
|
||||
if !h.enforcer.allowGuestImport(w, r, hostID, eventID, len(parsed.Rows)) {
|
||||
return
|
||||
}
|
||||
|
||||
rows := make([]storage.BulkImportRow, 0, len(parsed.Rows))
|
||||
for _, r := range parsed.Rows {
|
||||
rows = append(rows, storage.BulkImportRow{
|
||||
Name: r.Name,
|
||||
Email: r.Email,
|
||||
Phone: r.Phone,
|
||||
PlusOnes: r.PlusOnes,
|
||||
})
|
||||
}
|
||||
|
||||
res, err := h.guests.BulkImportGuests(r.Context(), eventID, rows)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "import failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, importResponse{
|
||||
Added: res.Added,
|
||||
Skipped: res.Skipped,
|
||||
SkippedEmails: res.SkippedEmails,
|
||||
Errors: parsed.Errors,
|
||||
TotalCount: parsed.TotalCount,
|
||||
})
|
||||
}
|
||||
|
||||
// GET /events/{id}/guests/import/template — download a sample CSV. Auth is
|
||||
// applied at the route level; ownership is verified so an attacker can't
|
||||
// probe the existence of an event by hitting this endpoint.
|
||||
func (h *csvImportHandler) template(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-import-template.csv"`)
|
||||
_, _ = w.Write([]byte(csvimport.TemplateCSV))
|
||||
}
|
||||
|
||||
// readCSVUpload returns the multipart file body (capped at csvMaxBytes) or
|
||||
// writes an error and returns (nil, false). Accepted shapes:
|
||||
//
|
||||
// - multipart/form-data with a "file" field (the drag-drop UI uses this)
|
||||
// - any other Content-Type — the raw body is treated as the CSV (curl-friendly)
|
||||
func readCSVUpload(w http.ResponseWriter, r *http.Request) (io.ReadCloser, bool) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, csvMaxBytes)
|
||||
if err := r.ParseMultipartForm(csvMaxBytes); err == nil && r.MultipartForm != nil {
|
||||
files := r.MultipartForm.File["file"]
|
||||
if len(files) == 0 {
|
||||
writeError(w, http.StatusBadRequest, `form field "file" is required`)
|
||||
return nil, false
|
||||
}
|
||||
f, err := files[0].Open()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "cannot read uploaded file")
|
||||
return nil, false
|
||||
}
|
||||
return f, true
|
||||
} else if err != nil && !errors.Is(err, http.ErrNotMultipart) {
|
||||
writeError(w, http.StatusBadRequest, "invalid multipart body")
|
||||
return nil, false
|
||||
}
|
||||
// Fall through: raw body as CSV.
|
||||
return r.Body, true
|
||||
}
|
||||
+39
-28
@@ -15,11 +15,11 @@ import (
|
||||
)
|
||||
|
||||
type eventHandler struct {
|
||||
repo *storage.EventRepo
|
||||
repo *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type createEventRequest struct {
|
||||
HostID string `json:"host_id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
EventDate time.Time `json:"event_date"`
|
||||
@@ -32,6 +32,14 @@ type createEventRequest struct {
|
||||
var slugRe = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
|
||||
|
||||
func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !h.enforcer.allowEventCreate(w, r, hostID) {
|
||||
return
|
||||
}
|
||||
|
||||
var req createEventRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
@@ -51,12 +59,6 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hostID, err := uuid.Parse(req.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
|
||||
return
|
||||
}
|
||||
|
||||
status := domain.EventStatus(req.Status)
|
||||
if status == "" {
|
||||
status = domain.EventStatusDraft
|
||||
@@ -89,37 +91,31 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *eventHandler) get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ev, err := h.repo.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
ev, ok := requireEventOwner(w, r, h.repo, id, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ev)
|
||||
}
|
||||
|
||||
func (h *eventHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
limit := atoiOr(q.Get("limit"), 50)
|
||||
offset := atoiOr(q.Get("offset"), 0)
|
||||
|
||||
var hostID uuid.UUID
|
||||
if v := q.Get("host_id"); v != "" {
|
||||
parsed, err := uuid.Parse(v)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
|
||||
return
|
||||
}
|
||||
hostID = parsed
|
||||
}
|
||||
|
||||
events, err := h.repo.List(r.Context(), hostID, limit, offset)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to list events")
|
||||
@@ -146,6 +142,10 @@ type updateEventRequest struct {
|
||||
}
|
||||
|
||||
func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
@@ -180,7 +180,7 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
params.Status = &s
|
||||
}
|
||||
|
||||
ev, err := h.repo.Update(r.Context(), id, params)
|
||||
ev, err := h.repo.Update(r.Context(), id, hostID, params)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrEventNotFound):
|
||||
@@ -196,11 +196,15 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.repo.Delete(r.Context(), id); err != nil {
|
||||
if err := h.repo.Delete(r.Context(), id, hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
@@ -212,7 +216,14 @@ func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func parseIDParam(w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, bool) {
|
||||
raw := r.PathValue(name)
|
||||
return parseRawUUID(w, name, r.PathValue(name))
|
||||
}
|
||||
|
||||
func parseRawUUID(w http.ResponseWriter, name, raw string) (uuid.UUID, bool) {
|
||||
if raw == "" {
|
||||
writeError(w, http.StatusBadRequest, name+" is required")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, name+" must be a valid uuid")
|
||||
|
||||
+107
-9
@@ -10,8 +10,9 @@ import (
|
||||
)
|
||||
|
||||
type guestHandler struct {
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type createGuestRequest struct {
|
||||
@@ -24,16 +25,18 @@ type createGuestRequest struct {
|
||||
}
|
||||
|
||||
func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := h.events.Get(r.Context(), eventID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
if !h.enforcer.allowGuestCreate(w, r, hostID, eventID) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -67,11 +70,106 @@ func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, g)
|
||||
}
|
||||
|
||||
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
type updateGuestRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Email *string `json:"email"`
|
||||
Phone *string `json:"phone"`
|
||||
PlusOnes *int `json:"plus_ones"`
|
||||
}
|
||||
|
||||
// PATCH /events/{id}/guests/{guest_id} — patch a guest's contact info.
|
||||
// Fields omitted from the body are left untouched. Empty strings for
|
||||
// email/phone clear those columns to NULL.
|
||||
func (h *guestHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req updateGuestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Name != nil && *req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name cannot be empty")
|
||||
return
|
||||
}
|
||||
if req.PlusOnes != nil && *req.PlusOnes < 0 {
|
||||
writeError(w, http.StatusBadRequest, "plus_ones must be >= 0")
|
||||
return
|
||||
}
|
||||
|
||||
g, err := h.guests.Update(r.Context(), eventID, guestID, storage.UpdateGuestParams{
|
||||
Name: req.Name,
|
||||
Email: req.Email,
|
||||
Phone: req.Phone,
|
||||
PlusOnes: req.PlusOnes,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to update guest")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, g)
|
||||
}
|
||||
|
||||
// DELETE /events/{id}/guests/{guest_id} — remove a guest from an event.
|
||||
// Cascade-deletes their token, rsvp, access logs, notifications.
|
||||
func (h *guestHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
if err := h.guests.Delete(r.Context(), eventID, guestID); err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete guest")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
limit := atoiOr(q.Get("limit"), 100)
|
||||
offset := atoiOr(q.Get("offset"), 0)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type meHandler struct {
|
||||
users *storage.UserRepo
|
||||
}
|
||||
|
||||
// GET /me — returns the authenticated user. Used by the frontend to bootstrap
|
||||
// after a page reload (with a fresh access token from /auth/refresh).
|
||||
func (h *meHandler) get(w http.ResponseWriter, r *http.Request) {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthenticated")
|
||||
return
|
||||
}
|
||||
u, err := h.users.GetByID(r.Context(), uid)
|
||||
if err != nil {
|
||||
if err == domain.ErrUserNotFound {
|
||||
writeError(w, http.StatusUnauthorized, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, u)
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// privacyHandler holds the GDPR-style "your data, your choice" endpoints:
|
||||
// data export, account deletion, and terms-acceptance recording.
|
||||
type privacyHandler struct {
|
||||
logger *slog.Logger
|
||||
users *storage.UserRepo
|
||||
events *storage.EventRepo
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
rsvps *storage.RSVPRepo
|
||||
access *storage.AccessLogRepo
|
||||
notifs *storage.DB // raw pool access for the export queries
|
||||
refresh *storage.RefreshTokenRepo
|
||||
}
|
||||
|
||||
// DataExport is the shape of the JSON the host downloads from
|
||||
// GET /me/data-export. We don't paginate or stream — for the scale
|
||||
// GuestGuard hosts have, a single response is reasonable. If a host
|
||||
// ever has 100k+ access logs we'll switch to async + email-a-link.
|
||||
type DataExport struct {
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Format string `json:"format"`
|
||||
User *domain.User `json:"user"`
|
||||
Events []*domain.Event `json:"events"`
|
||||
Guests []*domain.Guest `json:"guests"`
|
||||
Tokens []exportedToken `json:"tokens"`
|
||||
RSVPs []exportedRSVP `json:"rsvps"`
|
||||
AccessLogs []exportedAccess `json:"access_logs"`
|
||||
Notifs []exportedNotif `json:"notifications"`
|
||||
}
|
||||
|
||||
type exportedToken struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type exportedRSVP struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
Response string `json:"response"`
|
||||
PlusOnes int `json:"plus_ones"`
|
||||
SubmittedAt time.Time `json:"submitted_at"`
|
||||
}
|
||||
type exportedAccess struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
RiskScore *int `json:"risk_score,omitempty"`
|
||||
Flagged bool `json:"flagged"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type exportedNotif struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
Channel string `json:"channel"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// GET /me/data-export — returns every record the system holds about the
|
||||
// authenticated user. The Content-Disposition header makes browsers
|
||||
// offer a download rather than rendering inline.
|
||||
func (h *privacyHandler) dataExport(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.users.GetByID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
|
||||
export := DataExport{
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Format: "guestguard.v1",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Events the user hosts.
|
||||
events, err := h.events.List(r.Context(), hostID, 1000, 0)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load events")
|
||||
return
|
||||
}
|
||||
export.Events = events
|
||||
|
||||
// For each event, pull guests + tokens + rsvps + access_logs + notifications.
|
||||
for _, ev := range events {
|
||||
guests, err := h.guests.ListByEvent(r.Context(), ev.ID, 5000, 0)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guests")
|
||||
return
|
||||
}
|
||||
export.Guests = append(export.Guests, guests...)
|
||||
|
||||
for _, g := range guests {
|
||||
// Token (at most one per guest, but query as a list for symmetry).
|
||||
if err := h.appendTokens(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: tokens", "err", err)
|
||||
}
|
||||
if err := h.appendRSVPs(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: rsvps", "err", err)
|
||||
}
|
||||
if err := h.appendAccess(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: access", "err", err)
|
||||
}
|
||||
if err := h.appendNotifs(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: notifs", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-data-export.json"`)
|
||||
writeJSON(w, http.StatusOK, export)
|
||||
}
|
||||
|
||||
func (h *privacyHandler) appendTokens(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, expires_at, status, created_at
|
||||
FROM tokens WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var t exportedToken
|
||||
if err := rows.Scan(&t.ID, &t.GuestID, &t.ExpiresAt, &t.Status, &t.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Tokens = append(out.Tokens, t)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendRSVPs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, response::text, plus_ones, submitted_at
|
||||
FROM rsvps WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var r exportedRSVP
|
||||
if err := rows.Scan(&r.ID, &r.GuestID, &r.Response, &r.PlusOnes, &r.SubmittedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.RSVPs = append(out.RSVPs, r)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendAccess(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, risk_score, flagged, created_at
|
||||
FROM access_logs WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var a exportedAccess
|
||||
var rs *int
|
||||
if err := rows.Scan(&a.ID, &a.GuestID, &rs, &a.Flagged, &a.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
a.RiskScore = rs
|
||||
out.AccessLogs = append(out.AccessLogs, a)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendNotifs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, channel::text, type::text, status::text, created_at
|
||||
FROM notifications WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var n exportedNotif
|
||||
if err := rows.Scan(&n.ID, &n.GuestID, &n.Channel, &n.Type, &n.Status, &n.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Notifs = append(out.Notifs, n)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// DELETE /me — soft-deletes the host's account. All sessions are
|
||||
// revoked immediately. A hard delete happens via a separate cron 30
|
||||
// days later (TBD ops work). The user is logged out from all devices
|
||||
// as a side effect of revoking the refresh tokens.
|
||||
func (h *privacyHandler) deleteMe(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.users.SoftDelete(r.Context(), hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
writeError(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete account")
|
||||
return
|
||||
}
|
||||
// Best-effort: revoke refresh tokens so other sessions log out too.
|
||||
// Failure here is logged but doesn't roll back the soft-delete — the
|
||||
// access tokens (JWT) will still expire on their own ~15 minute TTL.
|
||||
if err := h.refresh.RevokeAllForUser(r.Context(), hostID); err != nil {
|
||||
h.logger.Warn("delete-me: revoke refresh tokens", "err", err, "user_id", hostID)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /me/accept-terms — records that the authenticated user accepts
|
||||
// the current ToS + privacy policy. Idempotent. Used by both the
|
||||
// onboarding gate (existing accounts created before T&C were enforced)
|
||||
// and any future "we updated our terms" re-acceptance flow.
|
||||
func (h *privacyHandler) acceptTerms(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.users.AcceptTerms(r.Context(), hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
writeError(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to record acceptance")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "accepted"})
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ipKey is the rate-limit key for endpoints scoped by source IP only
|
||||
// (e.g. POST /auth/signup). XFF/X-Real-IP are honoured because in the
|
||||
// homelab the API sits behind Traefik.
|
||||
func ipKey(r *http.Request) string {
|
||||
return clientIP(r)
|
||||
}
|
||||
|
||||
// pathKey returns a path-parameter as the rate-limit key — used for the
|
||||
// token-scoped endpoints so an attacker brute-forcing a single token is
|
||||
// limited regardless of the IPs they rotate through.
|
||||
func pathKey(name string) KeyFunc {
|
||||
return func(r *http.Request) string {
|
||||
return r.PathValue(name)
|
||||
}
|
||||
}
|
||||
|
||||
// userIDKey extracts the authenticated user id from the request context.
|
||||
// Returns "" when the route isn't behind requireAuth, in which case the
|
||||
// middleware bypasses (fail-open) — the route's own auth layer handles
|
||||
// rejection.
|
||||
func userIDKey(r *http.Request) string {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return uid.String()
|
||||
}
|
||||
|
||||
// KeyFunc mirrors ratelimit.KeyFunc so call sites don't have to import the
|
||||
// inner package.
|
||||
type KeyFunc = func(r *http.Request) string
|
||||
+264
-43
@@ -5,63 +5,167 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
"github.com/alchemistkay/guestguard/internal/ratelimit"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
logger *slog.Logger
|
||||
db *storage.DB
|
||||
hub *Hub
|
||||
users *userHandler
|
||||
events *eventHandler
|
||||
guests *guestHandler
|
||||
tokens *tokenHandler
|
||||
rsvps *rsvpHandler
|
||||
activity *activityHandler
|
||||
ws *wsHandler
|
||||
health *healthHandler
|
||||
logger *slog.Logger
|
||||
db *storage.DB
|
||||
hub *Hub
|
||||
authH *authHandler
|
||||
me *meHandler
|
||||
events *eventHandler
|
||||
guests *guestHandler
|
||||
tokens *tokenHandler
|
||||
rsvps *rsvpHandler
|
||||
activity *activityHandler
|
||||
ws *wsHandler
|
||||
wsTicket *wsTicketHandler
|
||||
health *healthHandler
|
||||
signer *auth.JWTSigner
|
||||
limiter *ratelimit.Limiter
|
||||
unsub *unsubscribeHandler
|
||||
webhooks *webhookHandler
|
||||
csv *csvImportHandler
|
||||
billing *billingHandler
|
||||
stripeWH *stripeWebhookHandler
|
||||
privacy *privacyHandler
|
||||
}
|
||||
|
||||
type ServerDeps struct {
|
||||
Logger *slog.Logger
|
||||
DB *storage.DB
|
||||
Hub *Hub
|
||||
AccessPublisher accessPublisher
|
||||
RSVPPublisher rsvpPublisher
|
||||
FraudScorer fraudScorer
|
||||
TokenTTL time.Duration
|
||||
AccessPublisher accessPublisher
|
||||
RSVPPublisher rsvpPublisher
|
||||
InvitationPublisher invitationPublisher
|
||||
FraudScorer fraudScorer
|
||||
TokenTTL time.Duration
|
||||
|
||||
// Auth
|
||||
JWTSecret string
|
||||
JWTIssuer string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
PublicBaseURL string
|
||||
RefreshCookieDomain string
|
||||
RefreshCookieSecure bool
|
||||
EmailSender auth.EmailSender
|
||||
WSTicketTTL time.Duration
|
||||
|
||||
// Rate limiting / abuse controls
|
||||
Redis *redis.Client
|
||||
LoginLockoutMax int // failed attempts before account lockout (default 5)
|
||||
LoginFailWindow time.Duration // counter TTL (default 15 min)
|
||||
|
||||
// Notifications / unsubscribe
|
||||
NotificationRepo *notification.Repo
|
||||
SuppressionRepo *notification.SuppressionRepo
|
||||
UnsubscribeSigner *notification.UnsubscribeSigner
|
||||
|
||||
// Billing (Block F). Nil StripeClient leaves billing disabled — the
|
||||
// system still boots and runs, all users sit on the free tier with
|
||||
// its limits enforced; /billing/* returns 503.
|
||||
StripeClient *billing.Client
|
||||
}
|
||||
|
||||
func NewServer(deps ServerDeps) *Server {
|
||||
func NewServer(deps ServerDeps) (*Server, error) {
|
||||
eventRepo := storage.NewEventRepo(deps.DB)
|
||||
guestRepo := storage.NewGuestRepo(deps.DB)
|
||||
tokenRepo := storage.NewTokenRepo(deps.DB)
|
||||
rsvpRepo := storage.NewRSVPRepo(deps.DB)
|
||||
accessRepo := storage.NewAccessLogRepo(deps.DB)
|
||||
userRepo := storage.NewUserRepo(deps.DB)
|
||||
verifRepo := storage.NewEmailVerificationRepo(deps.DB)
|
||||
resetRepo := storage.NewPasswordResetRepo(deps.DB)
|
||||
refreshRepo := storage.NewRefreshTokenRepo(deps.DB)
|
||||
subRepo := storage.NewSubscriptionRepo(deps.DB)
|
||||
enforcer := newTierEnforcer(subRepo, deps.PublicBaseURL)
|
||||
|
||||
signer, err := auth.NewJWTSigner(deps.JWTSecret, deps.AccessTokenTTL, deps.JWTIssuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hasher := auth.NewPasswordHasher()
|
||||
|
||||
emails := deps.EmailSender
|
||||
if emails == nil {
|
||||
emails = auth.LogEmailSender{Logger: deps.Logger}
|
||||
}
|
||||
|
||||
hub := deps.Hub
|
||||
if hub == nil {
|
||||
hub = NewHub(deps.Logger)
|
||||
}
|
||||
|
||||
wsTicketTTL := deps.WSTicketTTL
|
||||
if wsTicketTTL <= 0 {
|
||||
wsTicketTTL = 60 * time.Second
|
||||
}
|
||||
wsTickets := newWSTicketStore(wsTicketTTL)
|
||||
|
||||
var limiter *ratelimit.Limiter
|
||||
var lockout *auth.LockoutTracker
|
||||
if deps.Redis != nil {
|
||||
limiter = ratelimit.New(deps.Redis, "gg:rl")
|
||||
lockoutMax := deps.LoginLockoutMax
|
||||
if lockoutMax <= 0 {
|
||||
lockoutMax = 5
|
||||
}
|
||||
failWindow := deps.LoginFailWindow
|
||||
if failWindow <= 0 {
|
||||
failWindow = 15 * time.Minute
|
||||
}
|
||||
lockout = auth.NewLockoutTracker(deps.Redis, lockoutMax, failWindow)
|
||||
}
|
||||
|
||||
authH := newAuthHandler(authHandlerDeps{
|
||||
Logger: deps.Logger,
|
||||
Users: userRepo,
|
||||
Verifications: verifRepo,
|
||||
Resets: resetRepo,
|
||||
Refreshes: refreshRepo,
|
||||
Hasher: hasher,
|
||||
Signer: signer,
|
||||
Emails: emails,
|
||||
Lockout: lockout,
|
||||
Limiter: limiter,
|
||||
PublicBaseURL: deps.PublicBaseURL,
|
||||
EmailVerificationTTL: deps.EmailVerificationTTL,
|
||||
PasswordResetTTL: deps.PasswordResetTTL,
|
||||
RefreshTTL: deps.RefreshTokenTTL,
|
||||
CookieDomain: deps.RefreshCookieDomain,
|
||||
CookieSecure: deps.RefreshCookieSecure,
|
||||
})
|
||||
|
||||
return &Server{
|
||||
logger: deps.Logger,
|
||||
db: deps.DB,
|
||||
hub: hub,
|
||||
users: &userHandler{repo: userRepo},
|
||||
events: &eventHandler{repo: eventRepo},
|
||||
guests: &guestHandler{guests: guestRepo, events: eventRepo},
|
||||
authH: authH,
|
||||
me: &meHandler{users: userRepo},
|
||||
events: &eventHandler{repo: eventRepo, enforcer: enforcer},
|
||||
guests: &guestHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
|
||||
tokens: &tokenHandler{
|
||||
logger: deps.Logger,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
events: eventRepo,
|
||||
accessLogs: accessRepo,
|
||||
gen: auth.NewGenerator(),
|
||||
ttl: deps.TokenTTL,
|
||||
pub: deps.AccessPublisher,
|
||||
logger: deps.Logger,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
events: eventRepo,
|
||||
users: userRepo,
|
||||
accessLogs: accessRepo,
|
||||
gen: auth.NewGenerator(),
|
||||
ttl: deps.TokenTTL,
|
||||
pub: deps.AccessPublisher,
|
||||
invitations: deps.InvitationPublisher,
|
||||
publicBaseURL: deps.PublicBaseURL,
|
||||
},
|
||||
rsvps: &rsvpHandler{
|
||||
logger: deps.Logger,
|
||||
@@ -78,9 +182,46 @@ func NewServer(deps ServerDeps) *Server {
|
||||
rsvps: rsvpRepo,
|
||||
accessLogs: accessRepo,
|
||||
},
|
||||
ws: &wsHandler{logger: deps.Logger, hub: hub},
|
||||
health: &healthHandler{pool: deps.DB.Pool},
|
||||
}
|
||||
ws: &wsHandler{logger: deps.Logger, hub: hub, tickets: wsTickets},
|
||||
wsTicket: &wsTicketHandler{tickets: wsTickets, events: eventRepo},
|
||||
health: &healthHandler{pool: deps.DB.Pool},
|
||||
signer: signer,
|
||||
limiter: limiter,
|
||||
unsub: &unsubscribeHandler{
|
||||
logger: deps.Logger,
|
||||
signer: deps.UnsubscribeSigner,
|
||||
suppress: deps.SuppressionRepo,
|
||||
},
|
||||
webhooks: &webhookHandler{
|
||||
logger: deps.Logger,
|
||||
notifs: deps.NotificationRepo,
|
||||
suppress: deps.SuppressionRepo,
|
||||
},
|
||||
csv: &csvImportHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
|
||||
billing: &billingHandler{
|
||||
logger: deps.Logger,
|
||||
stripe: deps.StripeClient,
|
||||
users: userRepo,
|
||||
subscriptions: subRepo,
|
||||
publicBaseURL: deps.PublicBaseURL,
|
||||
},
|
||||
stripeWH: &stripeWebhookHandler{
|
||||
logger: deps.Logger,
|
||||
stripe: deps.StripeClient,
|
||||
subs: subRepo,
|
||||
},
|
||||
privacy: &privacyHandler{
|
||||
logger: deps.Logger,
|
||||
users: userRepo,
|
||||
events: eventRepo,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
rsvps: rsvpRepo,
|
||||
access: accessRepo,
|
||||
notifs: deps.DB,
|
||||
refresh: refreshRepo,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Hub() *Hub { return s.hub }
|
||||
@@ -91,25 +232,104 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.HandleFunc("GET /health", s.health.live)
|
||||
mux.HandleFunc("GET /health/ready", s.health.ready)
|
||||
|
||||
mux.HandleFunc("POST /users", s.users.upsert)
|
||||
// Per-route rate limiters (no-op when Redis isn't wired).
|
||||
authed := requireAuth(s.signer)
|
||||
rl := func(name string, limit int, window time.Duration, keyFn KeyFunc, h http.Handler) http.Handler {
|
||||
if s.limiter == nil {
|
||||
return h
|
||||
}
|
||||
return s.limiter.Middleware(
|
||||
ratelimit.Rule{Name: name, Limit: limit, Window: window},
|
||||
keyFn,
|
||||
s.logger,
|
||||
)(h)
|
||||
}
|
||||
|
||||
mux.HandleFunc("POST /events", s.events.create)
|
||||
mux.HandleFunc("GET /events", s.events.list)
|
||||
mux.HandleFunc("GET /events/{id}", s.events.get)
|
||||
mux.HandleFunc("PATCH /events/{id}", s.events.update)
|
||||
mux.HandleFunc("DELETE /events/{id}", s.events.delete)
|
||||
// Anonymous auth endpoints — POST /auth/login + /auth/forgot-password
|
||||
// rate-limit inside the handler (key includes the email body field).
|
||||
mux.Handle("POST /auth/signup",
|
||||
rl("auth_signup", 5, time.Hour, ipKey, http.HandlerFunc(s.authH.signup)))
|
||||
mux.HandleFunc("POST /auth/login", s.authH.login)
|
||||
mux.HandleFunc("POST /auth/refresh", s.authH.refresh)
|
||||
mux.HandleFunc("POST /auth/logout", s.authH.logout)
|
||||
mux.HandleFunc("POST /auth/verify-email", s.authH.verifyEmail)
|
||||
mux.HandleFunc("POST /auth/forgot-password", s.authH.forgotPassword)
|
||||
mux.HandleFunc("POST /auth/reset-password", s.authH.resetPassword)
|
||||
|
||||
mux.HandleFunc("POST /events/{id}/guests", s.guests.create)
|
||||
mux.HandleFunc("GET /events/{id}/guests", s.guests.list)
|
||||
mux.Handle("GET /me", authed(http.HandlerFunc(s.me.get)))
|
||||
mux.Handle("POST /auth/ws-ticket", authed(http.HandlerFunc(s.wsTicket.issue)))
|
||||
|
||||
mux.HandleFunc("GET /events/{id}/activity", s.activity.list)
|
||||
// Privacy / GDPR-style endpoints — host can export their data,
|
||||
// delete their account, and record terms acceptance from the
|
||||
// onboarding gate.
|
||||
mux.Handle("GET /me/data-export", authed(http.HandlerFunc(s.privacy.dataExport)))
|
||||
mux.Handle("DELETE /me", authed(http.HandlerFunc(s.privacy.deleteMe)))
|
||||
mux.Handle("POST /me/accept-terms", authed(http.HandlerFunc(s.privacy.acceptTerms)))
|
||||
|
||||
mux.HandleFunc("POST /events/{id}/guests/{guest_id}/tokens", s.tokens.issue)
|
||||
mux.HandleFunc("GET /access/{token}", s.tokens.access)
|
||||
mux.HandleFunc("POST /rsvp/{token}", s.rsvps.submit)
|
||||
// Host-facing event/guest/token writes are limited by user_id.
|
||||
mux.Handle("POST /events",
|
||||
authed(rl("events_create", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.events.create))))
|
||||
mux.Handle("GET /events", authed(http.HandlerFunc(s.events.list)))
|
||||
mux.Handle("GET /events/{id}", authed(http.HandlerFunc(s.events.get)))
|
||||
mux.Handle("PATCH /events/{id}", authed(http.HandlerFunc(s.events.update)))
|
||||
mux.Handle("DELETE /events/{id}", authed(http.HandlerFunc(s.events.delete)))
|
||||
|
||||
mux.Handle("POST /events/{id}/guests",
|
||||
authed(rl("guests_create", 1000, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.create))))
|
||||
mux.Handle("GET /events/{id}/guests", authed(http.HandlerFunc(s.guests.list)))
|
||||
mux.Handle("PATCH /events/{id}/guests/{guest_id}",
|
||||
authed(rl("guests_update", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.update))))
|
||||
mux.Handle("DELETE /events/{id}/guests/{guest_id}",
|
||||
authed(rl("guests_delete", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.delete))))
|
||||
|
||||
// CSV import (Block E). Preview is cheap (no DB writes), so we keep
|
||||
// its budget separate from commit's daily-row-add limit.
|
||||
mux.Handle("POST /events/{id}/guests/import/preview",
|
||||
authed(rl("guests_import_preview", 30, time.Hour, userIDKey, http.HandlerFunc(s.csv.preview))))
|
||||
mux.Handle("POST /events/{id}/guests/import",
|
||||
authed(rl("guests_import_commit", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.csv.commit))))
|
||||
mux.Handle("GET /events/{id}/guests/import/template", authed(http.HandlerFunc(s.csv.template)))
|
||||
|
||||
mux.Handle("GET /events/{id}/activity", authed(http.HandlerFunc(s.activity.list)))
|
||||
|
||||
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens",
|
||||
authed(rl("tokens_issue", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.issue))))
|
||||
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens/rotate",
|
||||
authed(rl("tokens_rotate", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.rotate))))
|
||||
mux.Handle("POST /events/{id}/guests/invitations/bulk",
|
||||
authed(rl("tokens_bulk", 10, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.bulkIssue))))
|
||||
|
||||
// Guest-facing endpoints — rate-limited by the access token in the URL
|
||||
// path so an attacker hammering a single invitation is slowed regardless
|
||||
// of their source IP.
|
||||
mux.Handle("GET /access/{token}",
|
||||
rl("access", 60, time.Hour, pathKey("token"), http.HandlerFunc(s.tokens.access)))
|
||||
mux.Handle("POST /rsvp/{token}",
|
||||
rl("rsvp", 10, time.Hour, pathKey("token"), http.HandlerFunc(s.rsvps.submit)))
|
||||
|
||||
// WebSocket endpoint authenticates via single-use ticket on the query
|
||||
// string (see POST /auth/ws-ticket).
|
||||
mux.HandleFunc("GET /ws/events/{id}", s.ws.handle)
|
||||
|
||||
// Unsubscribe (signed token, no auth required — links live in emails).
|
||||
mux.HandleFunc("GET /unsubscribe/{token}", s.unsub.preview)
|
||||
mux.HandleFunc("POST /unsubscribe/{token}", s.unsub.confirm)
|
||||
|
||||
// Provider webhooks. Signature verification is enforced in the handler
|
||||
// once GG_TWILIO_AUTH_TOKEN / GG_SES_WEBHOOK_SECRET are set.
|
||||
mux.HandleFunc("POST /webhooks/twilio/status", s.webhooks.twilio)
|
||||
mux.HandleFunc("POST /webhooks/ses/notifications", s.webhooks.ses)
|
||||
|
||||
// Billing (Block F). /billing/status is safe for everyone — returns
|
||||
// free tier defaults when Stripe is unconfigured or the user has no
|
||||
// subscription, so the frontend's plan page always has something to
|
||||
// render. The action endpoints (checkout, portal) return 503 in dev
|
||||
// without Stripe credentials.
|
||||
mux.Handle("GET /billing/status", authed(http.HandlerFunc(s.billing.status)))
|
||||
mux.Handle("POST /billing/checkout-session", authed(http.HandlerFunc(s.billing.checkoutSession)))
|
||||
mux.Handle("POST /billing/portal", authed(http.HandlerFunc(s.billing.portalSession)))
|
||||
mux.HandleFunc("POST /webhooks/stripe", s.stripeWH.handle)
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusNotFound, "not found")
|
||||
})
|
||||
@@ -128,9 +348,10 @@ func corsMiddleware(next http.Handler) http.Handler {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Device-Fingerprint")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Device-Fingerprint")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/stripe/stripe-go/v82"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// stripeWebhookHandler accepts and verifies Stripe events, then
|
||||
// projects subscription lifecycle changes onto the subscriptions table.
|
||||
// We track only what middleware needs to decide access — tier + status +
|
||||
// period bounds. Invoice events (payment failed / succeeded) are logged
|
||||
// for observability; dunning automation lands in Block F3.
|
||||
type stripeWebhookHandler struct {
|
||||
logger *slog.Logger
|
||||
stripe *billing.Client
|
||||
subs *storage.SubscriptionRepo
|
||||
}
|
||||
|
||||
// POST /webhooks/stripe — signature-verified Stripe event sink.
|
||||
func (h *stripeWebhookHandler) handle(w http.ResponseWriter, r *http.Request) {
|
||||
if h.stripe == nil || !h.stripe.Enabled() {
|
||||
// Not configured on this instance — reject so a misrouted event
|
||||
// isn't silently swallowed. Stripe will retry which is harmless.
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "read body")
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
event, err := h.stripe.VerifyWebhook(body, r.Header.Get("Stripe-Signature"))
|
||||
if err != nil {
|
||||
h.logger.Warn("stripe webhook signature failed", "err", err)
|
||||
writeError(w, http.StatusBadRequest, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "customer.subscription.created", "customer.subscription.updated":
|
||||
h.applySubscription(r, event)
|
||||
case "customer.subscription.deleted":
|
||||
h.applySubscriptionDeleted(r, event)
|
||||
case "invoice.payment_succeeded":
|
||||
// Clear past_due if Stripe says payment caught up. Most flows are
|
||||
// already covered by the subscription.updated event Stripe also
|
||||
// fires — this is belt-and-braces.
|
||||
h.applySubscription(r, event)
|
||||
case "invoice.payment_failed":
|
||||
h.logger.Warn("stripe invoice payment failed", "event_id", event.ID)
|
||||
// Subscription.status will flip to past_due via the
|
||||
// subscription.updated event Stripe fires alongside.
|
||||
default:
|
||||
h.logger.Debug("stripe event ignored", "type", event.Type)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// applySubscription patches the subscriptions row keyed by Stripe
|
||||
// customer id. Best-effort — failures here are logged but don't NACK
|
||||
// the webhook (Stripe would retry forever and the row would never
|
||||
// converge).
|
||||
func (h *stripeWebhookHandler) applySubscription(r *http.Request, event stripe.Event) {
|
||||
var sub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
// Some invoice events carry an Invoice payload — try to extract
|
||||
// the subscription id from there and short-circuit on status.
|
||||
h.logger.Debug("stripe webhook: not a subscription payload", "type", event.Type, "err", err)
|
||||
return
|
||||
}
|
||||
if sub.Customer == nil || sub.Customer.ID == "" {
|
||||
h.logger.Warn("stripe webhook: subscription has no customer", "subscription", sub.ID)
|
||||
return
|
||||
}
|
||||
|
||||
tier := tierFromSubscription(&sub)
|
||||
status := string(sub.Status)
|
||||
cancelAtPeriodEnd := sub.CancelAtPeriodEnd
|
||||
|
||||
// As of API 2024-10-28, current_period_end lives on the subscription
|
||||
// item, not the subscription. We pick the earliest item's end — for
|
||||
// single-item subscriptions (our case) that's the canonical one.
|
||||
var periodEnd *time.Time
|
||||
for _, item := range sub.Items.Data {
|
||||
if item.CurrentPeriodEnd > 0 {
|
||||
t := time.Unix(item.CurrentPeriodEnd, 0).UTC()
|
||||
if periodEnd == nil || t.Before(*periodEnd) {
|
||||
periodEnd = &t
|
||||
}
|
||||
}
|
||||
}
|
||||
subID := sub.ID
|
||||
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
|
||||
StripeSubscriptionID: &subID,
|
||||
Tier: stringPtr(string(tier)),
|
||||
Status: &status,
|
||||
CurrentPeriodEnd: periodEnd,
|
||||
CancelAtPeriodEnd: &cancelAtPeriodEnd,
|
||||
}); err != nil {
|
||||
h.logger.Error("stripe webhook: update subscription failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *stripeWebhookHandler) applySubscriptionDeleted(r *http.Request, event stripe.Event) {
|
||||
var sub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
h.logger.Warn("stripe webhook: bad deleted payload", "err", err)
|
||||
return
|
||||
}
|
||||
if sub.Customer == nil {
|
||||
return
|
||||
}
|
||||
status := "canceled"
|
||||
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
|
||||
Status: &status,
|
||||
}); err != nil {
|
||||
h.logger.Error("stripe webhook: mark canceled failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tierFromSubscription inspects the Stripe price metadata to figure out
|
||||
// which GuestGuard tier this subscription corresponds to. We read a
|
||||
// price-level metadata key `gg_tier` (set in the Stripe dashboard when
|
||||
// you create the Price). Fallback: free.
|
||||
func tierFromSubscription(sub *stripe.Subscription) billing.Tier {
|
||||
if sub == nil || len(sub.Items.Data) == 0 {
|
||||
return billing.TierFree
|
||||
}
|
||||
for _, item := range sub.Items.Data {
|
||||
if item.Price == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := item.Price.Metadata["gg_tier"]; ok {
|
||||
t := billing.Tier(v)
|
||||
if t.Valid() {
|
||||
return t
|
||||
}
|
||||
}
|
||||
// Heuristic fallback for tests / unconfigured prices: look at the
|
||||
// recurring interval and amount tier.
|
||||
if item.Price.Recurring != nil && item.Price.UnitAmount >= 19900 {
|
||||
return billing.TierBusiness
|
||||
}
|
||||
if item.Price.Recurring != nil && item.Price.UnitAmount >= 4900 {
|
||||
return billing.TierPro
|
||||
}
|
||||
}
|
||||
return billing.TierFree
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string { return &s }
|
||||
+329
-14
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -20,25 +21,38 @@ type accessPublisher interface {
|
||||
PublishAccessAttempted(ctx context.Context, evt natspub.AccessAttempted) error
|
||||
}
|
||||
|
||||
type invitationPublisher interface {
|
||||
PublishInvitationSend(ctx context.Context, evt natspub.InvitationSend) error
|
||||
}
|
||||
|
||||
type tokenHandler struct {
|
||||
logger *slog.Logger
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
events *storage.EventRepo
|
||||
accessLogs *storage.AccessLogRepo
|
||||
gen *auth.Generator
|
||||
ttl time.Duration
|
||||
pub accessPublisher
|
||||
logger *slog.Logger
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
events *storage.EventRepo
|
||||
users *storage.UserRepo
|
||||
accessLogs *storage.AccessLogRepo
|
||||
gen *auth.Generator
|
||||
ttl time.Duration
|
||||
pub accessPublisher
|
||||
invitations invitationPublisher
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
type issueTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
Meta *domain.Token `json:"meta"`
|
||||
Token string `json:"token"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
Meta *domain.Token `json:"meta"`
|
||||
InvitationQueued bool `json:"invitation_queued"`
|
||||
InvitationLink string `json:"invitation_link"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/{guest_id}/tokens — issue a token for the guest.
|
||||
func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
@@ -47,6 +61,10 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
guest, err := h.guests.Get(r.Context(), guestID)
|
||||
if err != nil {
|
||||
@@ -78,10 +96,307 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
link := h.invitationLink(raw)
|
||||
invitationQueued := h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
|
||||
|
||||
writeJSON(w, http.StatusCreated, issueTokenResponse{
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
InvitationQueued: invitationQueued,
|
||||
InvitationLink: link,
|
||||
})
|
||||
}
|
||||
|
||||
// queueInvitation publishes an invitation.send event so the notifier can
|
||||
// dispatch a branded email. Best-effort: if any step fails we log and
|
||||
// return false rather than failing the whole token-issue request — the
|
||||
// host still has the raw URL in the response and can re-trigger sending.
|
||||
func (h *tokenHandler) queueInvitation(
|
||||
ctx context.Context,
|
||||
event *domain.Event,
|
||||
guest *domain.Guest,
|
||||
tk *domain.Token,
|
||||
hostID uuid.UUID,
|
||||
rawToken string,
|
||||
) bool {
|
||||
if h.invitations == nil {
|
||||
return false
|
||||
}
|
||||
if guest.Email == nil || *guest.Email == "" {
|
||||
// Phone-only / nameless guests get no email — host shares the link
|
||||
// manually. Show that on the UI so it's not a silent surprise.
|
||||
return false
|
||||
}
|
||||
hostName := ""
|
||||
if h.users != nil {
|
||||
if host, err := h.users.GetByID(ctx, hostID); err == nil && host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
}
|
||||
evt := natspub.InvitationSend{
|
||||
EventID: event.ID,
|
||||
GuestID: guest.ID,
|
||||
TokenID: tk.ID,
|
||||
GuestName: guest.Name,
|
||||
GuestEmail: *guest.Email,
|
||||
HostName: hostName,
|
||||
EventName: event.Name,
|
||||
Venue: event.Venue,
|
||||
EventDate: event.EventDate,
|
||||
Link: h.invitationLink(rawToken),
|
||||
IssuedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := h.invitations.PublishInvitationSend(ctx, evt); err != nil {
|
||||
h.logger.Warn("publish invitation.send (continuing)", "err", err, "guest_id", guest.ID)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// invitationLink renders the public RSVP URL the guest clicks from their
|
||||
// inbox. publicBaseURL is the externally-reachable host (set via
|
||||
// GG_PUBLIC_BASE_URL); access via /rsvp/<token> is intentional — the
|
||||
// frontend page rsvp/[token].vue catches the raw token.
|
||||
func (h *tokenHandler) invitationLink(raw string) string {
|
||||
base := h.publicBaseURL
|
||||
if base == "" {
|
||||
base = "http://localhost:3000"
|
||||
}
|
||||
return base + "/rsvp/" + raw
|
||||
}
|
||||
|
||||
// bulkIssueRequest is the optional JSON body for the bulk-invite call.
|
||||
// An empty body (or missing GuestIDs) means "every guest on the event
|
||||
// who doesn't already have a token".
|
||||
type bulkIssueRequest struct {
|
||||
GuestIDs []string `json:"guest_ids"`
|
||||
}
|
||||
|
||||
type bulkIssueItemError struct {
|
||||
GuestID string `json:"guest_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// bulkIssueToken is one minted invitation. The raw token is returned so
|
||||
// the host's UI can offer a "copy link" affordance after a bulk send
|
||||
// (especially for guests with no email on file) without making another
|
||||
// round-trip. Same data the per-guest issue endpoint already exposes,
|
||||
// scoped to the host who owns the event.
|
||||
type bulkIssueToken struct {
|
||||
GuestID string `json:"guest_id"`
|
||||
Token string `json:"token"`
|
||||
InvitationQueued bool `json:"invitation_queued"`
|
||||
InvitationLink string `json:"invitation_link"`
|
||||
}
|
||||
|
||||
type bulkIssueResponse struct {
|
||||
Issued int `json:"issued"`
|
||||
Queued int `json:"queued"`
|
||||
SkippedExisting int `json:"skipped_existing"`
|
||||
SkippedNoEmail int `json:"skipped_no_email"`
|
||||
Tokens []bulkIssueToken `json:"tokens,omitempty"`
|
||||
Errors []bulkIssueItemError `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/invitations/bulk — generate tokens for every
|
||||
// eligible guest (or the explicit subset) on the event and queue an
|
||||
// invitation email for those with an address. Best-effort: any per-guest
|
||||
// error is reported in the response and doesn't abort the rest.
|
||||
func (h *tokenHandler) bulkIssue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req bulkIssueRequest
|
||||
if r.ContentLength > 0 {
|
||||
// Body is optional; only decode when something was actually sent.
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var onlyIDs []uuid.UUID
|
||||
if len(req.GuestIDs) > 0 {
|
||||
onlyIDs = make([]uuid.UUID, 0, len(req.GuestIDs))
|
||||
for _, raw := range req.GuestIDs {
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid guest id: "+raw)
|
||||
return
|
||||
}
|
||||
onlyIDs = append(onlyIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
guests, err := h.guests.ListGuestsForInvitation(r.Context(), eventID, onlyIDs)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guests")
|
||||
return
|
||||
}
|
||||
|
||||
hostName := ""
|
||||
if h.users != nil {
|
||||
if host, err := h.users.GetByID(r.Context(), hostID); err == nil && host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
}
|
||||
|
||||
resp := bulkIssueResponse{}
|
||||
for _, g := range guests {
|
||||
if g.HasToken {
|
||||
resp.SkippedExisting++
|
||||
continue
|
||||
}
|
||||
raw, hash, err := h.gen.Generate()
|
||||
if err != nil {
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "mint token failed"})
|
||||
continue
|
||||
}
|
||||
tk, err := h.tokens.Create(r.Context(), storage.CreateTokenParams{
|
||||
GuestID: g.ID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(h.ttl),
|
||||
})
|
||||
if err != nil {
|
||||
// Likely a race against the unique constraint (someone else
|
||||
// issued in parallel) — surface but don't fail the batch.
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: err.Error()})
|
||||
continue
|
||||
}
|
||||
resp.Issued++
|
||||
link := h.invitationLink(raw)
|
||||
tokenInfo := bulkIssueToken{
|
||||
GuestID: g.ID.String(),
|
||||
Token: raw,
|
||||
InvitationLink: link,
|
||||
}
|
||||
|
||||
if g.Email == "" {
|
||||
resp.SkippedNoEmail++
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
evt := natspub.InvitationSend{
|
||||
EventID: event.ID,
|
||||
GuestID: g.ID,
|
||||
TokenID: tk.ID,
|
||||
GuestName: g.Name,
|
||||
GuestEmail: g.Email,
|
||||
HostName: hostName,
|
||||
EventName: event.Name,
|
||||
Venue: event.Venue,
|
||||
EventDate: event.EventDate,
|
||||
Link: h.invitationLink(raw),
|
||||
IssuedAt: time.Now().UTC(),
|
||||
}
|
||||
if h.invitations == nil {
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
if err := h.invitations.PublishInvitationSend(r.Context(), evt); err != nil {
|
||||
h.logger.Warn("publish invitation.send (bulk, continuing)", "err", err, "guest_id", g.ID)
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "publish failed"})
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
resp.Queued++
|
||||
tokenInfo.InvitationQueued = true
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type rotateTokenRequest struct {
|
||||
// SendEmail asks the notifier to re-deliver the invitation. False
|
||||
// means "just give me a fresh link" — typical for phone-only guests
|
||||
// where the host shares the new URL via SMS.
|
||||
SendEmail bool `json:"send_email"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/{guest_id}/tokens/rotate — invalidate the
|
||||
// guest's existing invitation link and mint a fresh one. Optionally
|
||||
// re-publishes invitation.send so the notifier re-delivers via email.
|
||||
// The old URL stops working immediately.
|
||||
func (h *tokenHandler) rotate(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
guest, err := h.guests.Get(r.Context(), guestID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guest")
|
||||
return
|
||||
}
|
||||
if guest.EventID != eventID {
|
||||
writeError(w, http.StatusNotFound, "guest not found in event")
|
||||
return
|
||||
}
|
||||
|
||||
var req rotateTokenRequest
|
||||
if r.ContentLength > 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
raw, hash, err := h.gen.Generate()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to generate token")
|
||||
return
|
||||
}
|
||||
tk, err := h.tokens.RotateForGuest(r.Context(), storage.CreateTokenParams{
|
||||
GuestID: guestID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(h.ttl),
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("rotate token", "err", err, "guest_id", guestID)
|
||||
writeError(w, http.StatusInternalServerError, "failed to rotate token")
|
||||
return
|
||||
}
|
||||
|
||||
link := h.invitationLink(raw)
|
||||
invitationQueued := false
|
||||
if req.SendEmail {
|
||||
invitationQueued = h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, issueTokenResponse{
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
InvitationQueued: invitationQueued,
|
||||
InvitationLink: link,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
)
|
||||
|
||||
type unsubscribeHandler struct {
|
||||
logger *slog.Logger
|
||||
signer *notification.UnsubscribeSigner
|
||||
suppress *notification.SuppressionRepo
|
||||
}
|
||||
|
||||
// GET /unsubscribe/{token} — surface the email address that the token
|
||||
// belongs to so the frontend can show a confirmation page. Honoured even
|
||||
// before the user clicks "Confirm" so they see what's being unsubscribed.
|
||||
func (h *unsubscribeHandler) preview(w http.ResponseWriter, r *http.Request) {
|
||||
if h.signer == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
|
||||
return
|
||||
}
|
||||
email, err := h.signer.Verify(r.PathValue("token"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"email": email})
|
||||
}
|
||||
|
||||
// POST /unsubscribe/{token} — add the email to the suppression list.
|
||||
// Idempotent: clicking the link twice keeps the existing entry.
|
||||
func (h *unsubscribeHandler) confirm(w http.ResponseWriter, r *http.Request) {
|
||||
if h.signer == nil || h.suppress == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
|
||||
return
|
||||
}
|
||||
email, err := h.signer.Verify(r.PathValue("token"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
|
||||
return
|
||||
}
|
||||
if err := h.suppress.Add(r.Context(), email, "user clicked unsubscribe", notification.SuppressionUser); err != nil {
|
||||
h.logger.Error("add suppression", "err", err, "email", email)
|
||||
writeError(w, http.StatusInternalServerError, "failed to unsubscribe")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "unsubscribed", "email": email})
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type userHandler struct {
|
||||
repo *storage.UserRepo
|
||||
}
|
||||
|
||||
type upsertUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// POST /users — idempotent: returns the existing user if the email already
|
||||
// exists, creates one otherwise. This keeps the demo flow simple without
|
||||
// requiring real auth.
|
||||
func (h *userHandler) upsert(w http.ResponseWriter, r *http.Request) {
|
||||
var req upsertUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "email is invalid")
|
||||
return
|
||||
}
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.repo.Create(r.Context(), req.Email, req.Name)
|
||||
if err == nil {
|
||||
writeJSON(w, http.StatusCreated, u)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, domain.ErrEmailTaken) {
|
||||
existing, getErr := h.repo.GetByEmail(r.Context(), req.Email)
|
||||
if getErr != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, existing)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
)
|
||||
|
||||
// webhookHandler accepts provider status notifications and reflects them
|
||||
// onto the notifications table + suppression list.
|
||||
//
|
||||
// Signature verification is intentionally a TODO until the user provisions
|
||||
// real Twilio + SES creds — verifying against test fixtures alone would
|
||||
// give a false sense of security. The endpoint is therefore *not* exposed
|
||||
// publicly until the deployment is ready.
|
||||
type webhookHandler struct {
|
||||
logger *slog.Logger
|
||||
notifs *notification.Repo
|
||||
suppress *notification.SuppressionRepo
|
||||
}
|
||||
|
||||
// POST /webhooks/twilio/status — Twilio status callback (form-encoded).
|
||||
// Fields we care about: MessageSid, MessageStatus (sent|delivered|
|
||||
// undelivered|failed), ErrorCode, To.
|
||||
func (h *webhookHandler) twilio(w http.ResponseWriter, r *http.Request) {
|
||||
if h.notifs == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// TODO(blockD2): verify X-Twilio-Signature with GG_TWILIO_AUTH_TOKEN.
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid form")
|
||||
return
|
||||
}
|
||||
sid := r.PostForm.Get("MessageSid")
|
||||
status := r.PostForm.Get("MessageStatus")
|
||||
if sid == "" || status == "" {
|
||||
writeError(w, http.StatusBadRequest, "missing MessageSid / MessageStatus")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
switch status {
|
||||
case "delivered":
|
||||
_ = h.notifs.MarkDelivered(ctx, sid)
|
||||
case "undelivered", "failed":
|
||||
_ = h.notifs.MarkBounce(ctx, sid, "permanent")
|
||||
}
|
||||
h.logger.Info("twilio status callback", "sid", sid, "status", status)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /webhooks/ses/notifications — SNS-delivered SES notification (JSON).
|
||||
// Handles the two shapes SES uses: bounce + complaint events. Each event
|
||||
// carries the messageId we stored in provider_message_id and an array of
|
||||
// affected recipients.
|
||||
func (h *webhookHandler) ses(w http.ResponseWriter, r *http.Request) {
|
||||
if h.notifs == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// TODO(blockD2): verify SNS signature using the cert URL field.
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "read body")
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"Type"` // "Notification" | "SubscriptionConfirmation"
|
||||
Message string `json:"Message"` // stringified JSON for Notification
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json envelope")
|
||||
return
|
||||
}
|
||||
if envelope.Type == "SubscriptionConfirmation" {
|
||||
// Confirmed by visiting SubscribeURL — manual op-side step.
|
||||
h.logger.Info("ses subscription confirmation received (manual confirm required)")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if envelope.Message == "" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
var inner struct {
|
||||
NotificationType string `json:"notificationType"` // "Bounce" | "Complaint" | "Delivery"
|
||||
Mail struct {
|
||||
MessageID string `json:"messageId"`
|
||||
} `json:"mail"`
|
||||
Bounce struct {
|
||||
BounceType string `json:"bounceType"` // "Permanent" | "Transient"
|
||||
BouncedRecipients []struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
} `json:"bouncedRecipients"`
|
||||
} `json:"bounce"`
|
||||
Complaint struct {
|
||||
ComplainedRecipients []struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
} `json:"complainedRecipients"`
|
||||
} `json:"complaint"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(envelope.Message), &inner); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid inner json")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
switch inner.NotificationType {
|
||||
case "Bounce":
|
||||
bt := "transient"
|
||||
if inner.Bounce.BounceType == "Permanent" {
|
||||
bt = "permanent"
|
||||
}
|
||||
_ = h.notifs.MarkBounce(ctx, inner.Mail.MessageID, bt)
|
||||
if h.suppress != nil && bt == "permanent" {
|
||||
for _, rcp := range inner.Bounce.BouncedRecipients {
|
||||
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses permanent bounce", notification.SuppressionBounce)
|
||||
}
|
||||
}
|
||||
case "Complaint":
|
||||
_ = h.notifs.MarkComplaint(ctx, inner.Mail.MessageID)
|
||||
if h.suppress != nil {
|
||||
for _, rcp := range inner.Complaint.ComplainedRecipients {
|
||||
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses complaint", notification.SuppressionComplaint)
|
||||
}
|
||||
}
|
||||
case "Delivery":
|
||||
_ = h.notifs.MarkDelivered(ctx, inner.Mail.MessageID)
|
||||
}
|
||||
h.logger.Info("ses notification", "type", inner.NotificationType, "message_id", inner.Mail.MessageID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Compile-time check that ctx is unused in package — silences linter on
|
||||
// some Go versions when the file would otherwise import context only for
|
||||
// the handler signatures.
|
||||
var _ = context.Background
|
||||
@@ -0,0 +1,51 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type wsTicketHandler struct {
|
||||
tickets *wsTicketStore
|
||||
events *storage.EventRepo
|
||||
}
|
||||
|
||||
type wsTicketResponse struct {
|
||||
Ticket string `json:"ticket"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// POST /auth/ws-ticket — requireAuth-protected; body { "event_id": "<uuid>" }.
|
||||
// Returns a single-use ticket valid for ~60 seconds. The frontend appends it
|
||||
// as `?ticket=…` on the WebSocket URL.
|
||||
func (h *wsTicketHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
EventID string `json:"event_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
eventID, ok := parseRawUUID(w, "event_id", req.EventID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tok, exp, err := h.tickets.Mint(hostID, eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to mint ticket")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, wsTicketResponse{Ticket: tok, ExpiresAt: exp})
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// wsTicketStore mints short-lived single-use tickets that authorise a
|
||||
// WebSocket handshake. The plan calls this option 3 in Block B: cookies
|
||||
// don't reach the WS handshake on cross-origin setups and a JWT in the URL
|
||||
// would leak to logs; a one-shot ticket sidesteps both.
|
||||
//
|
||||
// Block B keeps this in-process. When the API runs more than one replica
|
||||
// this needs to move to Redis (Block C territory).
|
||||
type wsTicketStore struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]wsTicketEntry
|
||||
ttl time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type wsTicketEntry struct {
|
||||
userID uuid.UUID
|
||||
eventID uuid.UUID
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newWSTicketStore(ttl time.Duration) *wsTicketStore {
|
||||
return &wsTicketStore{
|
||||
entries: make(map[string]wsTicketEntry),
|
||||
ttl: ttl,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Mint returns a fresh URL-safe ticket bound to userID + eventID.
|
||||
func (s *wsTicketStore) Mint(userID, eventID uuid.UUID) (string, time.Time, error) {
|
||||
buf := make([]byte, 24)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
tok := base64.RawURLEncoding.EncodeToString(buf)
|
||||
exp := s.now().Add(s.ttl)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked()
|
||||
s.entries[tok] = wsTicketEntry{userID: userID, eventID: eventID, expiresAt: exp}
|
||||
return tok, exp, nil
|
||||
}
|
||||
|
||||
// Consume removes the ticket and returns the bound (userID, eventID) if it
|
||||
// was valid. A ticket is single-use — replaying it fails.
|
||||
func (s *wsTicketStore) Consume(token string) (uuid.UUID, uuid.UUID, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
entry, ok := s.entries[token]
|
||||
if !ok {
|
||||
return uuid.Nil, uuid.Nil, false
|
||||
}
|
||||
delete(s.entries, token)
|
||||
if s.now().After(entry.expiresAt) {
|
||||
return uuid.Nil, uuid.Nil, false
|
||||
}
|
||||
return entry.userID, entry.eventID, true
|
||||
}
|
||||
|
||||
// sweepLocked drops expired entries opportunistically. Cheap because we
|
||||
// usually only hold dozens of tickets at a time.
|
||||
func (s *wsTicketStore) sweepLocked() {
|
||||
now := s.now()
|
||||
for k, v := range s.entries {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(s.entries, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
+23
-3
@@ -89,16 +89,36 @@ func (h *Hub) remove(eventID uuid.UUID, s *subscriber) {
|
||||
}
|
||||
|
||||
type wsHandler struct {
|
||||
logger *slog.Logger
|
||||
hub *Hub
|
||||
logger *slog.Logger
|
||||
hub *Hub
|
||||
tickets *wsTicketStore
|
||||
}
|
||||
|
||||
// GET /ws/events/{id} — dashboard live feed for one event.
|
||||
// GET /ws/events/{id}?ticket=... — dashboard live feed for one event.
|
||||
//
|
||||
// The handshake is authorised by a single-use ticket minted via
|
||||
// POST /auth/ws-ticket (option 3 from the Block B plan). The ticket binds
|
||||
// the connecting user to a specific event_id; we reject if either is
|
||||
// missing or doesn't match the URL path.
|
||||
func (h *wsHandler) handle(w http.ResponseWriter, r *http.Request) {
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
rawTicket := r.URL.Query().Get("ticket")
|
||||
if rawTicket == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing ticket")
|
||||
return
|
||||
}
|
||||
_, ticketEventID, valid := h.tickets.Consume(rawTicket)
|
||||
if !valid {
|
||||
writeError(w, http.StatusUnauthorized, "invalid or expired ticket")
|
||||
return
|
||||
}
|
||||
if ticketEventID != eventID {
|
||||
writeError(w, http.StatusForbidden, "ticket does not match event")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
// In dev the frontend runs on a different origin (localhost:3000 → localhost:8080).
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// EmailSender delivers transactional auth emails (verification, reset).
|
||||
// Block A ships LogSender so dev environments work without Twilio/SES.
|
||||
// Block D replaces this with a real SES-backed sender.
|
||||
type EmailSender interface {
|
||||
SendVerification(ctx context.Context, to, name, link string) error
|
||||
SendPasswordReset(ctx context.Context, to, name, link string) error
|
||||
}
|
||||
|
||||
type LogEmailSender struct {
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
func (l LogEmailSender) SendVerification(_ context.Context, to, name, link string) error {
|
||||
l.Logger.Info("auth email (stub): verification",
|
||||
"to", to, "name", name, "link", link,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l LogEmailSender) SendPasswordReset(_ context.Context, to, name, link string) error {
|
||||
l.Logger.Info("auth email (stub): password reset",
|
||||
"to", to, "name", name, "link", link,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidJWT = errors.New("invalid token")
|
||||
ErrExpiredJWT = errors.New("token expired")
|
||||
)
|
||||
|
||||
type AccessClaims struct {
|
||||
UserID uuid.UUID `json:"sub_uuid"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTSigner struct {
|
||||
secret []byte
|
||||
ttl time.Duration
|
||||
issuer string
|
||||
parser *jwt.Parser
|
||||
}
|
||||
|
||||
func NewJWTSigner(secret string, ttl time.Duration, issuer string) (*JWTSigner, error) {
|
||||
if len(secret) < 32 {
|
||||
return nil, fmt.Errorf("jwt secret must be at least 32 bytes")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return nil, fmt.Errorf("jwt ttl must be positive")
|
||||
}
|
||||
return &JWTSigner{
|
||||
secret: []byte(secret),
|
||||
ttl: ttl,
|
||||
issuer: issuer,
|
||||
parser: jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithExpirationRequired(),
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *JWTSigner) Issue(userID uuid.UUID, now time.Time) (string, time.Time, error) {
|
||||
exp := now.Add(s.ttl)
|
||||
claims := AccessClaims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: s.issuer,
|
||||
Subject: userID.String(),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)),
|
||||
ExpiresAt: jwt.NewNumericDate(exp),
|
||||
ID: uuid.NewString(),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := tok.SignedString(s.secret)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return signed, exp, nil
|
||||
}
|
||||
|
||||
func (s *JWTSigner) Parse(raw string) (*AccessClaims, error) {
|
||||
claims := &AccessClaims{}
|
||||
tok, err := s.parser.ParseWithClaims(raw, claims, func(t *jwt.Token) (any, error) {
|
||||
return s.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrExpiredJWT
|
||||
}
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
if claims.UserID == uuid.Nil {
|
||||
// Fallback for tokens that only carry Subject (defensive — we always
|
||||
// set UserID on issue).
|
||||
parsed, perr := uuid.Parse(claims.Subject)
|
||||
if perr != nil {
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
claims.UserID = parsed
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *JWTSigner) TTL() time.Duration { return s.ttl }
|
||||
@@ -0,0 +1,82 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const testSecret = "test-secret-must-be-at-least-32-bytes-long-xx"
|
||||
|
||||
func TestJWTRoundTrip(t *testing.T) {
|
||||
s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test")
|
||||
if err != nil {
|
||||
t.Fatalf("signer: %v", err)
|
||||
}
|
||||
uid := uuid.New()
|
||||
tok, exp, err := s.Issue(uid, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("issue: %v", err)
|
||||
}
|
||||
if tok == "" {
|
||||
t.Fatal("empty token")
|
||||
}
|
||||
if time.Until(exp) <= 0 {
|
||||
t.Fatalf("expiry in past: %v", exp)
|
||||
}
|
||||
claims, err := s.Parse(tok)
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if claims.UserID != uid {
|
||||
t.Fatalf("user mismatch: got %s want %s", claims.UserID, uid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTExpired(t *testing.T) {
|
||||
s, err := NewJWTSigner(testSecret, 1*time.Second, "guestguard-test")
|
||||
if err != nil {
|
||||
t.Fatalf("signer: %v", err)
|
||||
}
|
||||
tok, _, err := s.Issue(uuid.New(), time.Now().Add(-1*time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("issue: %v", err)
|
||||
}
|
||||
if _, err := s.Parse(tok); !errors.Is(err, ErrExpiredJWT) {
|
||||
t.Fatalf("expected ErrExpiredJWT, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTTamper(t *testing.T) {
|
||||
s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test")
|
||||
if err != nil {
|
||||
t.Fatalf("signer: %v", err)
|
||||
}
|
||||
tok, _, _ := s.Issue(uuid.New(), time.Now())
|
||||
// Flip a character in the signature segment.
|
||||
tampered := tok[:len(tok)-1] + "a"
|
||||
if tampered == tok {
|
||||
tampered = tok[:len(tok)-1] + "b"
|
||||
}
|
||||
if _, err := s.Parse(tampered); !errors.Is(err, ErrInvalidJWT) {
|
||||
t.Fatalf("expected ErrInvalidJWT, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTSecretTooShort(t *testing.T) {
|
||||
if _, err := NewJWTSigner("short", time.Minute, "x"); err == nil {
|
||||
t.Fatal("expected error for short secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpaqueTokenHashStable(t *testing.T) {
|
||||
raw, hash, err := NewOpaqueToken()
|
||||
if err != nil {
|
||||
t.Fatalf("mint: %v", err)
|
||||
}
|
||||
if got := HashOpaque(raw); got != hash {
|
||||
t.Fatalf("hash mismatch: got %s want %s", got, hash)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// LockoutTracker counts consecutive failed logins per email and trips a
|
||||
// lockout flag after a threshold. The lock is keyed by user_id (once we
|
||||
// know it from the email) so that resetting the password — which we do via
|
||||
// /auth/reset-password — can clear it cleanly.
|
||||
//
|
||||
// Why two keys? The failure counter must work even when the email maps to
|
||||
// no user (otherwise an attacker probing addresses just gets unlimited
|
||||
// tries). The lock flag only exists once we've matched an actual account.
|
||||
type LockoutTracker struct {
|
||||
client *redis.Client
|
||||
threshold int
|
||||
window time.Duration // how long failures linger before counters reset
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewLockoutTracker(client *redis.Client, threshold int, window time.Duration) *LockoutTracker {
|
||||
return &LockoutTracker{
|
||||
client: client,
|
||||
threshold: threshold,
|
||||
window: window,
|
||||
prefix: "auth",
|
||||
}
|
||||
}
|
||||
|
||||
func (t *LockoutTracker) failKey(email string) string {
|
||||
h := sha256.Sum256([]byte(strings.ToLower(strings.TrimSpace(email))))
|
||||
return fmt.Sprintf("%s:login_fail:%s", t.prefix, hex.EncodeToString(h[:]))
|
||||
}
|
||||
|
||||
func (t *LockoutTracker) lockKey(uid uuid.UUID) string {
|
||||
return fmt.Sprintf("%s:locked:%s", t.prefix, uid.String())
|
||||
}
|
||||
|
||||
// IsLocked reports whether the given user's account is currently locked.
|
||||
func (t *LockoutTracker) IsLocked(ctx context.Context, uid uuid.UUID) (bool, error) {
|
||||
if t == nil || t.client == nil {
|
||||
return false, nil
|
||||
}
|
||||
v, err := t.client.Exists(ctx, t.lockKey(uid)).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return v > 0, nil
|
||||
}
|
||||
|
||||
// RecordFailure increments the failure counter for the email and, if it
|
||||
// crosses the threshold, sets the lock flag for the given user id.
|
||||
// Returns (locked, error). A nil userID is fine — the counter still ticks
|
||||
// up so probing nonexistent accounts is also rate-limited.
|
||||
func (t *LockoutTracker) RecordFailure(ctx context.Context, email string, userID *uuid.UUID) (bool, error) {
|
||||
if t == nil || t.client == nil {
|
||||
return false, nil
|
||||
}
|
||||
key := t.failKey(email)
|
||||
n, err := t.client.Incr(ctx, key).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if n == 1 {
|
||||
_ = t.client.Expire(ctx, key, t.window).Err()
|
||||
}
|
||||
if int(n) >= t.threshold && userID != nil {
|
||||
// Keep the lock until password reset clears it. 7-day fallback TTL
|
||||
// so a permanently abandoned account doesn't pile up forever.
|
||||
if err := t.client.Set(ctx, t.lockKey(*userID), "1", 7*24*time.Hour).Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ClearForUser drops both the lock flag and any in-flight failure counter
|
||||
// for the user's email. Called from /auth/reset-password.
|
||||
func (t *LockoutTracker) ClearForUser(ctx context.Context, uid uuid.UUID, email string) error {
|
||||
if t == nil || t.client == nil {
|
||||
return nil
|
||||
}
|
||||
pipe := t.client.Pipeline()
|
||||
pipe.Del(ctx, t.lockKey(uid))
|
||||
pipe.Del(ctx, t.failKey(email))
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearOnSuccess drops only the failure counter — used after a successful
|
||||
// login to forgive prior typos.
|
||||
func (t *LockoutTracker) ClearOnSuccess(ctx context.Context, email string) {
|
||||
if t == nil || t.client == nil {
|
||||
return
|
||||
}
|
||||
_ = t.client.Del(ctx, t.failKey(email)).Err()
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const bcryptCost = 12
|
||||
|
||||
var (
|
||||
ErrPasswordMismatch = errors.New("password mismatch")
|
||||
ErrPasswordTooShort = errors.New("password must be at least 8 characters")
|
||||
ErrPasswordTooLong = errors.New("password must be at most 72 characters")
|
||||
)
|
||||
|
||||
type PasswordHasher struct {
|
||||
cost int
|
||||
}
|
||||
|
||||
func NewPasswordHasher() *PasswordHasher {
|
||||
return &PasswordHasher{cost: bcryptCost}
|
||||
}
|
||||
|
||||
func (h *PasswordHasher) Hash(raw string) (string, error) {
|
||||
if err := ValidatePassword(raw); err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := bcrypt.GenerateFromPassword([]byte(raw), h.cost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (h *PasswordHasher) Verify(hash, raw string) error {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(raw)); err != nil {
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return ErrPasswordMismatch
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePassword enforces a minimum length and the bcrypt-imposed maximum.
|
||||
// bcrypt silently truncates inputs over 72 bytes, which would let a user set
|
||||
// a 100-character password and successfully log in with the first 72; reject
|
||||
// at the boundary instead.
|
||||
func ValidatePassword(raw string) error {
|
||||
if len(raw) < 8 {
|
||||
return ErrPasswordTooShort
|
||||
}
|
||||
if len(raw) > 72 {
|
||||
return ErrPasswordTooLong
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordHasherRoundTrip(t *testing.T) {
|
||||
h := NewPasswordHasher()
|
||||
hash, err := h.Hash("correct-horse-battery-staple")
|
||||
if err != nil {
|
||||
t.Fatalf("hash: %v", err)
|
||||
}
|
||||
if hash == "" {
|
||||
t.Fatal("empty hash")
|
||||
}
|
||||
if err := h.Verify(hash, "correct-horse-battery-staple"); err != nil {
|
||||
t.Fatalf("verify correct: %v", err)
|
||||
}
|
||||
if err := h.Verify(hash, "wrong"); !errors.Is(err, ErrPasswordMismatch) {
|
||||
t.Fatalf("verify wrong: got %v want ErrPasswordMismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pw string
|
||||
want error
|
||||
}{
|
||||
{"too short", "1234567", ErrPasswordTooShort},
|
||||
{"min length ok", "12345678", nil},
|
||||
{"too long", strings.Repeat("a", 73), ErrPasswordTooLong},
|
||||
{"max length ok", strings.Repeat("a", 72), nil},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := ValidatePassword(tc.pw); !errors.Is(got, tc.want) {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// NewOpaqueToken returns a 32-byte URL-safe random token plus its SHA-256 hex
|
||||
// digest. The raw value is shown once (in a link); only the digest is stored.
|
||||
func NewOpaqueToken() (raw, hash string, err error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
raw = base64.RawURLEncoding.EncodeToString(buf)
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
hash = hex.EncodeToString(sum[:])
|
||||
return raw, hash, nil
|
||||
}
|
||||
|
||||
func HashOpaque(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/stripe/stripe-go/v82"
|
||||
"github.com/stripe/stripe-go/v82/billingportal/session"
|
||||
csession "github.com/stripe/stripe-go/v82/checkout/session"
|
||||
"github.com/stripe/stripe-go/v82/customer"
|
||||
"github.com/stripe/stripe-go/v82/webhook"
|
||||
)
|
||||
|
||||
// Client wraps the Stripe SDK with the subset of calls the API needs.
|
||||
// Concentrates env-var reads + price-ID lookups in one place so handlers
|
||||
// stay focused on HTTP, not Stripe plumbing.
|
||||
type Client struct {
|
||||
secretKey string
|
||||
webhookSecret string
|
||||
prices map[Tier]string
|
||||
}
|
||||
|
||||
// Config is the env-derived configuration the Client needs.
|
||||
type Config struct {
|
||||
SecretKey string
|
||||
WebhookSecret string
|
||||
PriceProMonthly string
|
||||
PriceBusiness string
|
||||
}
|
||||
|
||||
// NewClient validates required fields and returns a configured client.
|
||||
// Returns (nil, nil) when SecretKey is empty — callers treat that as
|
||||
// "billing disabled" and degrade gracefully (free tier for everyone, no
|
||||
// /billing/* endpoints exposed).
|
||||
func NewClient(cfg Config) (*Client, error) {
|
||||
if cfg.SecretKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
stripe.Key = cfg.SecretKey
|
||||
|
||||
c := &Client{
|
||||
secretKey: cfg.SecretKey,
|
||||
webhookSecret: cfg.WebhookSecret,
|
||||
prices: map[Tier]string{
|
||||
TierPro: cfg.PriceProMonthly,
|
||||
TierBusiness: cfg.PriceBusiness,
|
||||
},
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Enabled reports whether the client was constructed with a Stripe key.
|
||||
func (c *Client) Enabled() bool { return c != nil && c.secretKey != "" }
|
||||
|
||||
// PriceFor returns the Stripe Price ID for a tier or an error if it
|
||||
// hasn't been configured — checkout will fail loudly rather than
|
||||
// silently send the customer to an empty checkout page.
|
||||
func (c *Client) PriceFor(tier Tier) (string, error) {
|
||||
id, ok := c.prices[tier]
|
||||
if !ok || id == "" {
|
||||
return "", fmt.Errorf("billing: no Stripe price configured for tier %q", tier)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// CreateOrGetCustomer returns the Stripe customer id for the given
|
||||
// (user_id, email). If `existingID` is non-empty we trust it; otherwise
|
||||
// we create a new Stripe customer and let the caller persist the id.
|
||||
func (c *Client) CreateOrGetCustomer(userID, email, name, existingID string) (string, error) {
|
||||
if !c.Enabled() {
|
||||
return "", errors.New("billing: disabled")
|
||||
}
|
||||
if existingID != "" {
|
||||
return existingID, nil
|
||||
}
|
||||
params := &stripe.CustomerParams{
|
||||
Email: stripe.String(email),
|
||||
Name: stripe.String(name),
|
||||
Metadata: map[string]string{"gg_user_id": userID},
|
||||
}
|
||||
cust, err := customer.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stripe customer create: %w", err)
|
||||
}
|
||||
return cust.ID, nil
|
||||
}
|
||||
|
||||
// CheckoutSessionParams collects the inputs CreateCheckoutSession needs.
|
||||
// Keeping them in a struct so future fields (coupon codes, trial periods,
|
||||
// referral metadata) drop in without breaking callers.
|
||||
type CheckoutSessionParams struct {
|
||||
CustomerID string
|
||||
PriceID string
|
||||
SuccessURL string
|
||||
CancelURL string
|
||||
}
|
||||
|
||||
// CreateCheckoutSession returns the URL the frontend redirects the user
|
||||
// to. Subscription mode — recurring billing for Pro/Business.
|
||||
func (c *Client) CreateCheckoutSession(p CheckoutSessionParams) (string, error) {
|
||||
if !c.Enabled() {
|
||||
return "", errors.New("billing: disabled")
|
||||
}
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
Customer: stripe.String(p.CustomerID),
|
||||
SuccessURL: stripe.String(p.SuccessURL),
|
||||
CancelURL: stripe.String(p.CancelURL),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{{
|
||||
Price: stripe.String(p.PriceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
}},
|
||||
AllowPromotionCodes: stripe.Bool(true),
|
||||
}
|
||||
sess, err := csession.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stripe checkout session: %w", err)
|
||||
}
|
||||
return sess.URL, nil
|
||||
}
|
||||
|
||||
// CreatePortalSession returns a URL to the Stripe-hosted customer
|
||||
// portal so the user can manage payment methods, cancel, view invoices.
|
||||
func (c *Client) CreatePortalSession(customerID, returnURL string) (string, error) {
|
||||
if !c.Enabled() {
|
||||
return "", errors.New("billing: disabled")
|
||||
}
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(customerID),
|
||||
ReturnURL: stripe.String(returnURL),
|
||||
}
|
||||
sess, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stripe portal session: %w", err)
|
||||
}
|
||||
return sess.URL, nil
|
||||
}
|
||||
|
||||
// VerifyWebhook validates the Stripe signature header and returns the
|
||||
// parsed event. Refuses to verify when no webhook secret is configured —
|
||||
// no shared secret means anyone can POST forged events, so the route
|
||||
// should reject everything in that case (the caller does that check).
|
||||
//
|
||||
// We pass IgnoreAPIVersionMismatch because Stripe accounts can be on a
|
||||
// newer API version than the SDK we're built against. Event payloads
|
||||
// are designed to be forward-compatible — the SDK warns about the skew
|
||||
// but the deserialised event is still safe to use. Strict matching
|
||||
// would mean we'd have to upgrade the SDK in lockstep with whatever
|
||||
// Stripe rolls out, defeating the point of having an SDK.
|
||||
func (c *Client) VerifyWebhook(body []byte, sigHeader string) (stripe.Event, error) {
|
||||
if c.webhookSecret == "" {
|
||||
return stripe.Event{}, errors.New("billing: no webhook secret configured")
|
||||
}
|
||||
return webhook.ConstructEventWithOptions(body, sigHeader, c.webhookSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
// Package billing models GuestGuard's subscription tiers and Stripe
|
||||
// integration. Plan limits live here so handler + middleware layers don't
|
||||
// hard-code numbers — change a value, restart the API, and the cap moves.
|
||||
package billing
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Tier is the user's current subscription tier. Stored as text in the
|
||||
// subscriptions table.
|
||||
type Tier string
|
||||
|
||||
const (
|
||||
TierFree Tier = "free"
|
||||
TierPro Tier = "pro"
|
||||
TierBusiness Tier = "business"
|
||||
)
|
||||
|
||||
func (t Tier) Valid() bool {
|
||||
switch t {
|
||||
case TierFree, TierPro, TierBusiness:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Limits enforces what a tier may do. -1 means unlimited.
|
||||
type Limits struct {
|
||||
EventsPerMonth int
|
||||
GuestsPerEvent int
|
||||
}
|
||||
|
||||
// TierLimits is the canonical plan-limits table. Matches docs/TIER1_PLAN.md
|
||||
// Block F pricing (placeholder until market validation).
|
||||
var TierLimits = map[Tier]Limits{
|
||||
TierFree: {EventsPerMonth: 1, GuestsPerEvent: 50},
|
||||
TierPro: {EventsPerMonth: 10, GuestsPerEvent: 1000},
|
||||
TierBusiness: {EventsPerMonth: -1, GuestsPerEvent: 5000},
|
||||
}
|
||||
|
||||
// LimitsFor returns the limits for a tier, defaulting to Free for unknown
|
||||
// strings — defensive, so a typo or future-tier in the DB never grants
|
||||
// unlimited access.
|
||||
func LimitsFor(t Tier) Limits {
|
||||
if l, ok := TierLimits[t]; ok {
|
||||
return l
|
||||
}
|
||||
return TierLimits[TierFree]
|
||||
}
|
||||
|
||||
// StatusGrantsAccess returns true if a subscription with this status
|
||||
// should be treated as paid. Mirrors the unique-index predicate in
|
||||
// 0005_billing.up.sql.
|
||||
func StatusGrantsAccess(status string) bool {
|
||||
switch status {
|
||||
case "active", "past_due", "trialing":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LimitError describes a denied-by-policy outcome. The handler layer
|
||||
// turns this into a 402 with a JSON body the frontend uses to render
|
||||
// the upgrade modal.
|
||||
type LimitError struct {
|
||||
Reason string
|
||||
Tier Tier
|
||||
Limit int
|
||||
Used int
|
||||
}
|
||||
|
||||
func (e *LimitError) Error() string {
|
||||
return fmt.Sprintf("billing: %s (tier=%s used=%d limit=%d)", e.Reason, e.Tier, e.Used, e.Limit)
|
||||
}
|
||||
@@ -12,11 +12,56 @@ type Config struct {
|
||||
HTTPAddr string
|
||||
DatabaseURL string
|
||||
NATSURL string
|
||||
RedisAddr string
|
||||
FraudGRPCAddr string
|
||||
FraudGRPCTimeout time.Duration
|
||||
ShutdownTimeout time.Duration
|
||||
TokenSecret string
|
||||
TokenTTL time.Duration
|
||||
|
||||
// Auth
|
||||
JWTSecret string
|
||||
JWTIssuer string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
PublicBaseURL string
|
||||
RefreshCookieDomain string
|
||||
RefreshCookieSecure bool
|
||||
|
||||
// Notifications (Block D). Empty values leave the log-stub adapter in
|
||||
// place so the service still boots without AWS / Twilio creds.
|
||||
SESRegion string
|
||||
SESFromEmail string
|
||||
SESFromName string
|
||||
SESConfigurationSet string
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFromEmail string
|
||||
SMTPFromName string
|
||||
SMTPTLS string // "starttls" | "implicit" | "none" (default starttls)
|
||||
|
||||
ResendAPIKey string
|
||||
ResendFromEmail string
|
||||
ResendFromName string
|
||||
|
||||
TwilioAccountSID string
|
||||
TwilioAuthToken string
|
||||
TwilioFromNumber string
|
||||
|
||||
// Billing (Block F). Empty StripeSecretKey leaves billing disabled —
|
||||
// all users get free-tier limits, /billing/* returns 503. Lets the
|
||||
// service boot in dev without Stripe creds.
|
||||
StripeSecretKey string
|
||||
StripeWebhookSecret string
|
||||
StripePricePro string // Stripe Price ID for Pro monthly
|
||||
StripePriceBusiness string // Stripe Price ID for Business monthly
|
||||
|
||||
UnsubscribeSecret string // HMAC key for signing unsubscribe links
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -25,11 +70,50 @@ func Load() (*Config, error) {
|
||||
HTTPAddr: getenv("GG_HTTP_ADDR", ":8080"),
|
||||
DatabaseURL: getenv("GG_DATABASE_URL", "postgres://guestguard:guestguard@localhost:5432/guestguard?sslmode=disable"),
|
||||
NATSURL: getenv("GG_NATS_URL", "nats://localhost:4222"),
|
||||
RedisAddr: getenv("GG_REDIS_ADDR", "localhost:6379"),
|
||||
FraudGRPCAddr: getenv("GG_FRAUD_GRPC_ADDR", "fraud-engine:9091"),
|
||||
FraudGRPCTimeout: getenvDuration("GG_FRAUD_GRPC_TIMEOUT", 250*time.Millisecond),
|
||||
ShutdownTimeout: getenvDuration("GG_SHUTDOWN_TIMEOUT", 15*time.Second),
|
||||
TokenSecret: os.Getenv("GG_TOKEN_SECRET"),
|
||||
TokenTTL: getenvDuration("GG_TOKEN_TTL", 30*24*time.Hour),
|
||||
|
||||
JWTSecret: os.Getenv("GG_JWT_SECRET"),
|
||||
JWTIssuer: getenv("GG_JWT_ISSUER", "guestguard"),
|
||||
AccessTokenTTL: getenvDuration("GG_ACCESS_TOKEN_TTL", 15*time.Minute),
|
||||
RefreshTokenTTL: getenvDuration("GG_REFRESH_TOKEN_TTL", 30*24*time.Hour),
|
||||
EmailVerificationTTL: getenvDuration("GG_EMAIL_VERIFICATION_TTL", 24*time.Hour),
|
||||
PasswordResetTTL: getenvDuration("GG_PASSWORD_RESET_TTL", 1*time.Hour),
|
||||
PublicBaseURL: getenv("GG_PUBLIC_BASE_URL", "http://localhost:3000"),
|
||||
RefreshCookieDomain: os.Getenv("GG_REFRESH_COOKIE_DOMAIN"),
|
||||
RefreshCookieSecure: getenvBool("GG_REFRESH_COOKIE_SECURE", false),
|
||||
|
||||
SESRegion: getenv("GG_SES_REGION", "us-east-1"),
|
||||
SESFromEmail: os.Getenv("GG_SES_FROM_EMAIL"),
|
||||
SESFromName: getenv("GG_SES_FROM_NAME", "GuestGuard"),
|
||||
SESConfigurationSet: os.Getenv("GG_SES_CONFIGURATION_SET"),
|
||||
|
||||
SMTPHost: os.Getenv("GG_SMTP_HOST"),
|
||||
SMTPPort: getenvInt("GG_SMTP_PORT", 587),
|
||||
SMTPUsername: os.Getenv("GG_SMTP_USERNAME"),
|
||||
SMTPPassword: os.Getenv("GG_SMTP_PASSWORD"),
|
||||
SMTPFromEmail: os.Getenv("GG_SMTP_FROM_EMAIL"),
|
||||
SMTPFromName: getenv("GG_SMTP_FROM_NAME", "GuestGuard"),
|
||||
SMTPTLS: getenv("GG_SMTP_TLS", "starttls"),
|
||||
|
||||
ResendAPIKey: os.Getenv("GG_RESEND_API_KEY"),
|
||||
ResendFromEmail: os.Getenv("GG_RESEND_FROM_EMAIL"),
|
||||
ResendFromName: getenv("GG_RESEND_FROM_NAME", "GuestGuard"),
|
||||
|
||||
TwilioAccountSID: os.Getenv("GG_TWILIO_ACCOUNT_SID"),
|
||||
TwilioAuthToken: os.Getenv("GG_TWILIO_AUTH_TOKEN"),
|
||||
TwilioFromNumber: os.Getenv("GG_TWILIO_FROM_NUMBER"),
|
||||
|
||||
StripeSecretKey: os.Getenv("GG_STRIPE_SECRET_KEY"),
|
||||
StripeWebhookSecret: os.Getenv("GG_STRIPE_WEBHOOK_SECRET"),
|
||||
StripePricePro: os.Getenv("GG_STRIPE_PRICE_PRO"),
|
||||
StripePriceBusiness: os.Getenv("GG_STRIPE_PRICE_BUSINESS"),
|
||||
|
||||
UnsubscribeSecret: os.Getenv("GG_UNSUBSCRIBE_SECRET"),
|
||||
}
|
||||
|
||||
if cfg.Env == "production" && cfg.TokenSecret == "" {
|
||||
@@ -39,9 +123,52 @@ func Load() (*Config, error) {
|
||||
cfg.TokenSecret = "dev-only-insecure-secret-change-me"
|
||||
}
|
||||
|
||||
if cfg.Env == "production" && cfg.JWTSecret == "" {
|
||||
return nil, fmt.Errorf("GG_JWT_SECRET is required in production")
|
||||
}
|
||||
if cfg.JWTSecret == "" {
|
||||
cfg.JWTSecret = "dev-only-insecure-jwt-secret-change-me-32+bytes"
|
||||
}
|
||||
if len(cfg.JWTSecret) < 32 {
|
||||
return nil, fmt.Errorf("GG_JWT_SECRET must be at least 32 bytes")
|
||||
}
|
||||
|
||||
if cfg.UnsubscribeSecret == "" {
|
||||
// Same dev fallback shape as the other secrets — production refuses
|
||||
// to boot without it.
|
||||
if cfg.Env == "production" {
|
||||
return nil, fmt.Errorf("GG_UNSUBSCRIBE_SECRET is required in production")
|
||||
}
|
||||
cfg.UnsubscribeSecret = "dev-only-insecure-unsubscribe-secret-change-me"
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getenvInt(key string, fallback int) int {
|
||||
v, ok := os.LookupEnv(key)
|
||||
if !ok || v == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getenvBool(key string, fallback bool) bool {
|
||||
v, ok := os.LookupEnv(key)
|
||||
if !ok || v == "" {
|
||||
return fallback
|
||||
}
|
||||
b, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func getenv(key, fallback string) string {
|
||||
if v, ok := os.LookupEnv(key); ok && v != "" {
|
||||
return v
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
// Package csvimport parses a guest-list CSV into structured rows, with
|
||||
// tolerant header detection (Excel, Numbers, Google Sheets variants) and
|
||||
// per-row validation. Streaming-friendly so a 5,000-row import doesn't
|
||||
// load the entire file into a slice before we know if column 1 is junk.
|
||||
package csvimport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// Row is a single validated guest. Empty Email / Phone are allowed (a
|
||||
// phone-only or name-only guest is valid per the plan).
|
||||
type Row struct {
|
||||
Name string
|
||||
Email string
|
||||
Phone string
|
||||
PlusOnes int
|
||||
}
|
||||
|
||||
// RowError flags one row with the human-readable reason it can't be
|
||||
// imported. The line number is 1-based and matches the source CSV
|
||||
// (header counts as line 1, first data row is line 2) so the frontend
|
||||
// can highlight the offending row.
|
||||
type RowError struct {
|
||||
Row int `json:"row"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// Result is the outcome of one parse pass.
|
||||
type Result struct {
|
||||
Rows []Row `json:"rows,omitempty"`
|
||||
Errors []RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"` // total data rows seen (excluding header)
|
||||
}
|
||||
|
||||
// Options tune limits + behaviour.
|
||||
type Options struct {
|
||||
MaxRows int // hard cap; rows beyond MaxRows return an error instead of being silently dropped
|
||||
}
|
||||
|
||||
const DefaultMaxRows = 5000
|
||||
|
||||
// Strict E.164: optional leading +, then a non-zero leading digit (country
|
||||
// codes never start with 0), followed by 6–14 more digits — total 7–15
|
||||
// significant digits. Spaces / dashes / parens are tolerated by stripping
|
||||
// before validation, but local-format numbers like "0244…" or "07700…"
|
||||
// are rejected here so the host fixes them at upload time rather than at
|
||||
// WhatsApp-send time.
|
||||
var phoneRe = regexp.MustCompile(`^\+?[1-9][0-9]{6,14}$`)
|
||||
|
||||
// Parse reads a CSV from r and returns the parsed result. Encoding is
|
||||
// auto-detected: UTF-8 with or without BOM, plus UTF-16 LE/BE BOMs
|
||||
// (commonly produced by Mac Numbers exports).
|
||||
func Parse(r io.Reader, opt Options) (*Result, error) {
|
||||
max := opt.MaxRows
|
||||
if max <= 0 {
|
||||
max = DefaultMaxRows
|
||||
}
|
||||
|
||||
rd, err := decodingReader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
csvr := csv.NewReader(rd)
|
||||
csvr.FieldsPerRecord = -1 // tolerate ragged rows; we re-validate column count ourselves
|
||||
csvr.TrimLeadingSpace = true
|
||||
|
||||
header, err := csvr.Read()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil, errors.New("csv is empty")
|
||||
}
|
||||
return nil, fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
cols, err := detectColumns(header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &Result{Rows: make([]Row, 0, 64)}
|
||||
lineNo := 1 // header was line 1
|
||||
for {
|
||||
rec, err := csvr.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
lineNo++
|
||||
if err != nil {
|
||||
out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: fmt.Sprintf("malformed csv: %v", err)})
|
||||
continue
|
||||
}
|
||||
out.TotalCount++
|
||||
if out.TotalCount > max {
|
||||
return nil, fmt.Errorf("import exceeds maximum of %d rows", max)
|
||||
}
|
||||
|
||||
// Skip fully-empty rows silently — these appear at the end of
|
||||
// Excel exports a lot.
|
||||
if rowEmpty(rec) {
|
||||
out.TotalCount-- // don't count it
|
||||
continue
|
||||
}
|
||||
|
||||
row, rerr := buildRow(rec, cols)
|
||||
if rerr != "" {
|
||||
out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: rerr})
|
||||
continue
|
||||
}
|
||||
out.Rows = append(out.Rows, row)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func rowEmpty(rec []string) bool {
|
||||
for _, v := range rec {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// decodingReader strips a UTF-8 BOM and decodes UTF-16 LE/BE when their
|
||||
// BOM is present, returning a UTF-8 reader. Other byte orders fall through
|
||||
// as raw UTF-8.
|
||||
func decodingReader(r io.Reader) (*bufio.Reader, error) {
|
||||
br := bufio.NewReader(r)
|
||||
bom, err := br.Peek(3)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case len(bom) >= 3 && bom[0] == 0xEF && bom[1] == 0xBB && bom[2] == 0xBF:
|
||||
_, _ = br.Discard(3)
|
||||
return br, nil
|
||||
case len(bom) >= 2 && bom[0] == 0xFF && bom[1] == 0xFE:
|
||||
_, _ = br.Discard(2)
|
||||
dec := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()
|
||||
return bufio.NewReader(transform.NewReader(br, dec)), nil
|
||||
case len(bom) >= 2 && bom[0] == 0xFE && bom[1] == 0xFF:
|
||||
_, _ = br.Discard(2)
|
||||
dec := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder()
|
||||
return bufio.NewReader(transform.NewReader(br, dec)), nil
|
||||
}
|
||||
return br, nil
|
||||
}
|
||||
|
||||
// columnSet records which column index each known field lives in. -1 means
|
||||
// the column was not supplied; only Name is mandatory.
|
||||
type columnSet struct {
|
||||
name, email, phone, plusOnes int
|
||||
}
|
||||
|
||||
func detectColumns(header []string) (columnSet, error) {
|
||||
cs := columnSet{name: -1, email: -1, phone: -1, plusOnes: -1}
|
||||
for i, raw := range header {
|
||||
key := normaliseHeader(raw)
|
||||
switch key {
|
||||
case "name", "guestname", "fullname":
|
||||
cs.name = i
|
||||
case "email", "emailaddress", "e-mail":
|
||||
cs.email = i
|
||||
case "phone", "telephone", "mobile", "phonenumber":
|
||||
cs.phone = i
|
||||
case "plusones", "plus1", "plus-one", "plus-ones", "+1", "guests", "additionalguests":
|
||||
cs.plusOnes = i
|
||||
}
|
||||
}
|
||||
if cs.name < 0 {
|
||||
return cs, fmt.Errorf("required column 'name' not found in header: %v", header)
|
||||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func normaliseHeader(s string) string {
|
||||
s = strings.ToLower(strings.TrimSpace(s))
|
||||
// Drop spaces + underscores. Keep `+`, `-` so "+1" / "plus-one" still
|
||||
// match exactly.
|
||||
return strings.NewReplacer(" ", "", "_", "").Replace(s)
|
||||
}
|
||||
|
||||
func buildRow(rec []string, cs columnSet) (Row, string) {
|
||||
get := func(i int) string {
|
||||
if i < 0 || i >= len(rec) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(rec[i])
|
||||
}
|
||||
row := Row{
|
||||
Name: get(cs.name),
|
||||
Email: strings.ToLower(get(cs.email)),
|
||||
Phone: get(cs.phone),
|
||||
}
|
||||
if row.Name == "" {
|
||||
return row, "name is required"
|
||||
}
|
||||
|
||||
if row.Email != "" {
|
||||
if _, err := mail.ParseAddress(row.Email); err != nil {
|
||||
return row, "invalid email"
|
||||
}
|
||||
}
|
||||
if row.Phone != "" {
|
||||
stripped := stripPhone(row.Phone)
|
||||
if !phoneRe.MatchString(stripped) {
|
||||
return row, "phone must be in international format with country code (e.g. +447700900123) — local numbers starting with 0 won't work for SMS or WhatsApp"
|
||||
}
|
||||
// Normalise: ensure stored form always starts with "+".
|
||||
if !strings.HasPrefix(stripped, "+") {
|
||||
stripped = "+" + stripped
|
||||
}
|
||||
row.Phone = stripped
|
||||
}
|
||||
if raw := get(cs.plusOnes); raw != "" {
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 0 {
|
||||
return row, "plus_ones must be a non-negative integer"
|
||||
}
|
||||
row.PlusOnes = n
|
||||
}
|
||||
return row, ""
|
||||
}
|
||||
|
||||
var phoneStripper = strings.NewReplacer(" ", "", "-", "", "(", "", ")", "", " ", "")
|
||||
|
||||
func stripPhone(s string) string {
|
||||
return phoneStripper.Replace(s)
|
||||
}
|
||||
|
||||
// TemplateCSV is the sample file served at /events/{id}/guests/import/template.
|
||||
// Phone numbers MUST include the country code (e.g. +44 for UK, +233 for
|
||||
// Ghana). Local-format numbers like "0244..." or "07700..." will be
|
||||
// rejected at upload — the sample below shows the expected shape.
|
||||
const TemplateCSV = "name,email,phone,plus_ones\n" +
|
||||
"Alex Doe,alex@example.com,+447700900123,1\n" +
|
||||
"Sam Patel,sam@example.com,,0\n" +
|
||||
"Jordan Lee,,+15551234567,2\n" +
|
||||
"Mira Patel,mira@example.com,+233244123456,0\n"
|
||||
@@ -0,0 +1,157 @@
|
||||
package csvimport
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHappyPath(t *testing.T) {
|
||||
in := `name,email,phone,plus_ones
|
||||
Alex Doe,alex@example.com,+447700900123,1
|
||||
Sam Patel,SAM@example.com,,0
|
||||
Jordan Lee,,+1 (555) 123-4567,2
|
||||
`
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if got, want := len(r.Rows), 3; got != want {
|
||||
t.Fatalf("rows: got %d want %d (errors=%+v)", got, want, r.Errors)
|
||||
}
|
||||
if r.Rows[1].Email != "sam@example.com" {
|
||||
t.Errorf("email not lowercased: %q", r.Rows[1].Email)
|
||||
}
|
||||
if r.Rows[2].Phone != "+15551234567" {
|
||||
t.Errorf("phone not stripped: %q", r.Rows[2].Phone)
|
||||
}
|
||||
if r.Rows[0].Phone != "+447700900123" {
|
||||
t.Errorf("phone should keep leading +: %q", r.Rows[0].Phone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePhoneNormalisedToPlus(t *testing.T) {
|
||||
// E.164 without explicit "+" is accepted and normalised to include one.
|
||||
in := "name,phone\nAlex,447700900123\n"
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 || r.Rows[0].Phone != "+447700900123" {
|
||||
t.Fatalf("expected normalised phone, got: rows=%+v errors=%+v", r.Rows, r.Errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePhoneRejectsLocalFormat(t *testing.T) {
|
||||
// Local UK / GH style numbers (leading 0, no country code) must be
|
||||
// rejected — they break SMS routing and WhatsApp click-to-chat.
|
||||
in := `name,phone
|
||||
UK Local,07700900123
|
||||
GH Local,0244123456
|
||||
With Plus And Zero,+0244123456
|
||||
`
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Errors) != 3 {
|
||||
t.Fatalf("expected 3 errors for local-format phones, got %d: %+v", len(r.Errors), r.Errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHeaderVariants(t *testing.T) {
|
||||
cases := []string{
|
||||
"Name,Email,Phone,Plus Ones\nMira,m@x.com,,1\n",
|
||||
"Guest Name,E-Mail,Telephone,+1\nMira,m@x.com,,1\n",
|
||||
"full_name,email_address,mobile,plusones\nMira,m@x.com,,1\n",
|
||||
}
|
||||
for i, in := range cases {
|
||||
t.Run(string(rune('a'+i)), func(t *testing.T) {
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 {
|
||||
t.Fatalf("rows=%d errors=%+v", len(r.Rows), r.Errors)
|
||||
}
|
||||
if r.Rows[0].PlusOnes != 1 {
|
||||
t.Errorf("plusones not detected: %+v", r.Rows[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRowValidation(t *testing.T) {
|
||||
in := `name,email,phone,plus_ones
|
||||
,a@x.com,,1
|
||||
Valid Guest,not-an-email,,0
|
||||
Phone Person,,abc,0
|
||||
Negative,,,-1
|
||||
Email Only,e@x.com,,
|
||||
`
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Errors) != 4 {
|
||||
t.Fatalf("expected 4 errors, got %d: %+v", len(r.Errors), r.Errors)
|
||||
}
|
||||
// "Email Only" row is valid: name present, email parses, phone+plus blank.
|
||||
if len(r.Rows) != 1 || r.Rows[0].Name != "Email Only" {
|
||||
t.Fatalf("expected 1 valid row, got %+v", r.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUTF8BOM(t *testing.T) {
|
||||
in := "\xEF\xBB\xBFname,email\nMira,m@x.com\n"
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" {
|
||||
t.Fatalf("BOM not stripped: %+v", r.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUTF16LE(t *testing.T) {
|
||||
// "name,email\nMira,m@x.com\n" in UTF-16 LE with BOM.
|
||||
in := []byte{0xFF, 0xFE}
|
||||
for _, r := range "name,email\nMira,m@x.com\n" {
|
||||
in = append(in, byte(r), byte(r>>8))
|
||||
}
|
||||
r, err := Parse(strings.NewReader(string(in)), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" {
|
||||
t.Fatalf("UTF-16 LE not decoded: %+v", r.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEmptyTrailingRows(t *testing.T) {
|
||||
in := "name,email\nMira,m@x.com\n,,\n\n,\n"
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if r.TotalCount != 1 {
|
||||
t.Fatalf("trailing blanks counted: TotalCount=%d", r.TotalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMissingNameHeader(t *testing.T) {
|
||||
in := "email,phone\na@x.com,\n"
|
||||
if _, err := Parse(strings.NewReader(in), Options{}); err == nil {
|
||||
t.Fatal("expected error for missing name column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMaxRows(t *testing.T) {
|
||||
var b strings.Builder
|
||||
b.WriteString("name\n")
|
||||
for i := 0; i < 11; i++ {
|
||||
b.WriteString("X\n")
|
||||
}
|
||||
if _, err := Parse(strings.NewReader(b.String()), Options{MaxRows: 10}); err == nil {
|
||||
t.Fatal("expected error when exceeding MaxRows")
|
||||
}
|
||||
}
|
||||
+26
-7
@@ -8,14 +8,33 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PasswordHash string `json:"-"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
|
||||
DeletedAt *time.Time `json:"-"`
|
||||
TermsAcceptedAt *time.Time `json:"terms_accepted_at,omitempty"`
|
||||
PrivacyPolicyAcceptedAt *time.Time `json:"privacy_policy_accepted_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TermsAccepted reports whether the user has accepted both the terms
|
||||
// of service and the privacy policy. Both must be present for the user
|
||||
// to use the dashboard once enforcement is enabled.
|
||||
func (u *User) TermsAccepted() bool {
|
||||
return u != nil && u.TermsAcceptedAt != nil && u.PrivacyPolicyAcceptedAt != nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrEmailTaken = errors.New("email already in use")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrEmailTaken = errors.New("email already in use")
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
ErrAuthTokenNotFound = errors.New("auth token not found")
|
||||
ErrAuthTokenConsumed = errors.New("auth token already used")
|
||||
ErrAuthTokenExpired = errors.New("auth token expired")
|
||||
ErrRefreshTokenRevoked = errors.New("refresh token revoked")
|
||||
ErrAccountLocked = errors.New("account locked due to too many failed login attempts")
|
||||
)
|
||||
|
||||
@@ -92,6 +92,15 @@ func (c *Client) PublishRSVPConfirmed(ctx context.Context, evt RSVPConfirmed) er
|
||||
return c.publishJSON(ctx, SubjectRSVPConfirmed, evt, evt.RSVPID)
|
||||
}
|
||||
|
||||
func (c *Client) PublishInvitationSend(ctx context.Context, evt InvitationSend) error {
|
||||
if evt.IssuedAt.IsZero() {
|
||||
evt.IssuedAt = time.Now().UTC()
|
||||
}
|
||||
// Dedup by token id — re-issuing a token (currently disallowed by the
|
||||
// unique constraint, but defensive) won't double-send the email.
|
||||
return c.publishJSON(ctx, SubjectInvitationSend, evt, evt.TokenID)
|
||||
}
|
||||
|
||||
func (c *Client) publishJSON(ctx context.Context, subject string, payload any, dedupeKey uuid.UUID) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -38,3 +38,20 @@ type RSVPConfirmed struct {
|
||||
RiskScore *int `json:"risk_score,omitempty"`
|
||||
SubmittedAt time.Time `json:"submitted_at"`
|
||||
}
|
||||
|
||||
// InvitationSend asks the notifier to dispatch a guest invitation email.
|
||||
// Carries everything the email template needs so the worker doesn't have
|
||||
// to re-fetch event/guest details from Postgres on every send.
|
||||
type InvitationSend struct {
|
||||
EventID uuid.UUID `json:"event_id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
GuestName string `json:"guest_name"`
|
||||
GuestEmail string `json:"guest_email"`
|
||||
HostName string `json:"host_name"`
|
||||
EventName string `json:"event_name"`
|
||||
Venue string `json:"venue,omitempty"`
|
||||
EventDate time.Time `json:"event_date"`
|
||||
Link string `json:"link"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package natspub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
)
|
||||
|
||||
type InvitationSendHandler func(ctx context.Context, evt InvitationSend) error
|
||||
|
||||
type InvitationSendSubscriber struct {
|
||||
logger *slog.Logger
|
||||
consumer jetstream.Consumer
|
||||
handler InvitationSendHandler
|
||||
}
|
||||
|
||||
func NewInvitationSendSubscriber(
|
||||
ctx context.Context,
|
||||
c *Client,
|
||||
durable string,
|
||||
handler InvitationSendHandler,
|
||||
logger *slog.Logger,
|
||||
) (*InvitationSendSubscriber, error) {
|
||||
cons, err := c.js.CreateOrUpdateConsumer(ctx, StreamName, jetstream.ConsumerConfig{
|
||||
Durable: durable,
|
||||
Name: durable,
|
||||
FilterSubject: SubjectInvitationSend,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
MaxDeliver: 5,
|
||||
AckWait: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create consumer %s: %w", durable, err)
|
||||
}
|
||||
return &InvitationSendSubscriber{logger: logger, consumer: cons, handler: handler}, nil
|
||||
}
|
||||
|
||||
func (s *InvitationSendSubscriber) Start(ctx context.Context) (jetstream.ConsumeContext, error) {
|
||||
cc, err := s.consumer.Consume(func(msg jetstream.Msg) {
|
||||
var evt InvitationSend
|
||||
if err := json.Unmarshal(msg.Data(), &evt); err != nil {
|
||||
s.logger.Error("decode invitation.send", "err", err)
|
||||
_ = msg.Term()
|
||||
return
|
||||
}
|
||||
hctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.handler(hctx, evt); err != nil {
|
||||
s.logger.Error("handle invitation.send", "err", err)
|
||||
_ = msg.NakWithDelay(5 * time.Second)
|
||||
return
|
||||
}
|
||||
_ = msg.Ack()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("consume: %w", err)
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sesv2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sesv2/types"
|
||||
)
|
||||
|
||||
// SESConfig is the surface area for picking an SES sender. ConfigurationSet
|
||||
// is optional but recommended in production — it's where bounce + complaint
|
||||
// SNS topics get wired so the webhook handler has something to consume.
|
||||
type SESConfig struct {
|
||||
Region string
|
||||
FromEmail string
|
||||
FromName string
|
||||
ConfigurationSet string
|
||||
PublicBaseURL string // for unsubscribe links in templates
|
||||
}
|
||||
|
||||
// SESEmailSender sends transactional emails (verification + reset for the
|
||||
// auth flows, plus invitation/confirmation/reminder for guests) via Amazon
|
||||
// SESv2. The same client serves both audiences so callers don't end up
|
||||
// with two SES configurations to maintain.
|
||||
type SESEmailSender struct {
|
||||
client *sesv2.Client
|
||||
tpls *Templates
|
||||
from string
|
||||
configSet *string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewSESEmailSender returns a configured SES sender, or an error if the
|
||||
// AWS SDK can't bootstrap. The caller typically constructs this once at
|
||||
// startup and reuses it for the lifetime of the process.
|
||||
func NewSESEmailSender(ctx context.Context, cfg SESConfig, tpls *Templates) (*SESEmailSender, error) {
|
||||
if cfg.FromEmail == "" {
|
||||
return nil, fmt.Errorf("ses: FromEmail required")
|
||||
}
|
||||
if cfg.Region == "" {
|
||||
cfg.Region = "us-east-1"
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(cfg.Region))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ses: load aws config: %w", err)
|
||||
}
|
||||
client := sesv2.NewFromConfig(awsCfg)
|
||||
|
||||
from := cfg.FromEmail
|
||||
if cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
|
||||
}
|
||||
var cs *string
|
||||
if cfg.ConfigurationSet != "" {
|
||||
cs = aws.String(cfg.ConfigurationSet)
|
||||
}
|
||||
|
||||
return &SESEmailSender{
|
||||
client: client,
|
||||
tpls: tpls,
|
||||
from: from,
|
||||
configSet: cs,
|
||||
baseURL: strings.TrimRight(cfg.PublicBaseURL, "/"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- auth.EmailSender implementation ---
|
||||
|
||||
// SendVerification renders the verification template and posts it to SES.
|
||||
func (s *SESEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Verify your GuestGuard email",
|
||||
TmplVerification, map[string]any{
|
||||
"Name": name,
|
||||
"Link": link,
|
||||
})
|
||||
}
|
||||
|
||||
// SendPasswordReset renders the reset template and posts it to SES.
|
||||
func (s *SESEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Reset your GuestGuard password",
|
||||
TmplPasswordReset, map[string]any{
|
||||
"Name": name,
|
||||
"Link": link,
|
||||
"ExpiryHumane": "1 hour",
|
||||
})
|
||||
}
|
||||
|
||||
// SendGuest is used by the notifier worker for invitation / confirmation /
|
||||
// reminder emails — anything addressed at a guest.
|
||||
func (s *SESEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error) {
|
||||
return s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func (s *SESEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error {
|
||||
_, err := s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SESEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
html, text, err := s.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, err := s.client.SendEmail(ctx, &sesv2.SendEmailInput{
|
||||
FromEmailAddress: aws.String(s.from),
|
||||
Destination: &types.Destination{ToAddresses: []string{to}},
|
||||
ConfigurationSetName: s.configSet,
|
||||
Content: &types.EmailContent{
|
||||
Simple: &types.Message{
|
||||
Subject: &types.Content{Data: aws.String(subject), Charset: aws.String("UTF-8")},
|
||||
Body: &types.Body{
|
||||
Html: &types.Content{Data: aws.String(html), Charset: aws.String("UTF-8")},
|
||||
Text: &types.Content{Data: aws.String(text), Charset: aws.String("UTF-8")},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ses: send: %w", err)
|
||||
}
|
||||
if out.MessageId == nil {
|
||||
return "", nil
|
||||
}
|
||||
return *out.MessageId, nil
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// EmailBackend names the chosen email delivery channel for telemetry +
|
||||
// startup logging. Mostly a debugging aid — code paths don't branch on
|
||||
// this value.
|
||||
type EmailBackend string
|
||||
|
||||
const (
|
||||
BackendResend EmailBackend = "resend"
|
||||
BackendSMTP EmailBackend = "smtp"
|
||||
BackendSES EmailBackend = "ses"
|
||||
BackendLog EmailBackend = "log"
|
||||
)
|
||||
|
||||
// EmailSenderConfig collects every email-related env var so the picker
|
||||
// has a single, ordered place to decide which backend wins. Priority is
|
||||
// Resend > SMTP > SES > Log — the first one with non-empty creds is used.
|
||||
type EmailSenderConfig struct {
|
||||
Resend ResendConfig
|
||||
SMTP SMTPConfig
|
||||
SES SESConfig
|
||||
}
|
||||
|
||||
// CombinedEmailSender satisfies both the auth.EmailSender interface (for
|
||||
// verification + reset emails) and GuestEmailDispatcher (for invitation,
|
||||
// confirmation, reminder). One concrete value handles both audiences so
|
||||
// callers don't end up with two configurations.
|
||||
type CombinedEmailSender interface {
|
||||
SendVerification(ctx context.Context, to, name, link string) error
|
||||
SendPasswordReset(ctx context.Context, to, name, link string) error
|
||||
SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error)
|
||||
}
|
||||
|
||||
// PickEmailSender returns the configured email sender + which backend was
|
||||
// chosen. Falls back to a logger stub if nothing is configured, so the
|
||||
// service stays bootable in stripped-down dev environments.
|
||||
func PickEmailSender(ctx context.Context, cfg EmailSenderConfig, tpls *Templates, logger *slog.Logger) (CombinedEmailSender, EmailBackend, error) {
|
||||
switch {
|
||||
case cfg.Resend.APIKey != "":
|
||||
s, err := NewResendEmailSender(cfg.Resend, tpls)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return s, BackendResend, nil
|
||||
case cfg.SMTP.Host != "":
|
||||
s, err := NewSMTPEmailSender(cfg.SMTP, tpls)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return s, BackendSMTP, nil
|
||||
case cfg.SES.FromEmail != "":
|
||||
s, err := NewSESEmailSender(ctx, cfg.SES, tpls)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return s, BackendSES, nil
|
||||
}
|
||||
return &logCombinedSender{logger: logger, tpls: tpls}, BackendLog, nil
|
||||
}
|
||||
|
||||
// logCombinedSender is the dev fallback. Verification + reset emails come
|
||||
// through as structured log lines (preserving the Block A behaviour);
|
||||
// guest emails get rendered + dumped so engineers can eyeball the output.
|
||||
type logCombinedSender struct {
|
||||
logger *slog.Logger
|
||||
tpls *Templates
|
||||
}
|
||||
|
||||
func (l *logCombinedSender) SendVerification(_ context.Context, to, name, link string) error {
|
||||
l.logger.Info("auth email (stub): verification", "to", to, "name", name, "link", link)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logCombinedSender) SendPasswordReset(_ context.Context, to, name, link string) error {
|
||||
l.logger.Info("auth email (stub): password reset", "to", to, "name", name, "link", link)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logCombinedSender) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
_, text, err := l.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
l.logger.Info("guest email (stub)",
|
||||
"to", to, "subject", subject, "template", string(name), "text_body", text,
|
||||
)
|
||||
return "log:" + string(name), nil
|
||||
}
|
||||
@@ -64,12 +64,13 @@ func NewRepo(db *storage.DB) *Repo {
|
||||
}
|
||||
|
||||
type RecordParams struct {
|
||||
GuestID uuid.UUID
|
||||
Channel Channel
|
||||
Type Type
|
||||
Status Status
|
||||
ProviderID string
|
||||
Error string
|
||||
GuestID uuid.UUID
|
||||
Channel Channel
|
||||
Type Type
|
||||
Status Status
|
||||
ProviderID string // human-friendly id (e.g. "log:xyz")
|
||||
ProviderMessageID string // provider's message id (Twilio SID, SES MessageId)
|
||||
Error string
|
||||
}
|
||||
|
||||
func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
@@ -77,6 +78,10 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
if p.ProviderID != "" {
|
||||
providerID = &p.ProviderID
|
||||
}
|
||||
var providerMsgID *string
|
||||
if p.ProviderMessageID != "" {
|
||||
providerMsgID = &p.ProviderMessageID
|
||||
}
|
||||
var errStr *string
|
||||
if p.Error != "" {
|
||||
errStr = &p.Error
|
||||
@@ -90,14 +95,15 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
|
||||
const q = `
|
||||
INSERT INTO notifications (guest_id, channel, type, status, provider_id,
|
||||
attempts, last_attempt, delivered_at, error)
|
||||
VALUES ($1, $2, $3, $4, $5, 1, now(), $6, $7)
|
||||
provider_message_id, attempts, last_attempt,
|
||||
delivered_at, error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 1, now(), $7, $8)
|
||||
RETURNING id
|
||||
`
|
||||
var id uuid.UUID
|
||||
err := r.pool.QueryRow(ctx, q,
|
||||
p.GuestID, string(p.Channel), string(p.Type), string(p.Status),
|
||||
providerID, deliveredAt, errStr,
|
||||
providerID, providerMsgID, deliveredAt, errStr,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("record notification: %w", err)
|
||||
@@ -105,6 +111,35 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// MarkBounce records a bounce on the notification row identified by the
|
||||
// provider's message id. Called from webhook handlers.
|
||||
func (r *Repo) MarkBounce(ctx context.Context, providerMessageID, bounceType string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE notifications
|
||||
SET status = 'bounced', bounce_type = $2, error = COALESCE(error, '')
|
||||
WHERE provider_message_id = $1
|
||||
`, providerMessageID, bounceType)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkComplaint records a complaint (spam report) for the same row.
|
||||
func (r *Repo) MarkComplaint(ctx context.Context, providerMessageID string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE notifications SET complained = TRUE WHERE provider_message_id = $1
|
||||
`, providerMessageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkDelivered moves a row from 'sent' to 'delivered' when the provider's
|
||||
// delivery status webhook fires.
|
||||
func (r *Repo) MarkDelivered(ctx context.Context, providerMessageID string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE notifications SET status = 'delivered', delivered_at = now()
|
||||
WHERE provider_message_id = $1 AND status NOT IN ('bounced','failed')
|
||||
`, providerMessageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// LogSender pretends to send and just logs. Useful for Phase 3 demos and
|
||||
// tests; concrete providers (Twilio/SES) plug in later.
|
||||
type LogSender struct{}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ResendConfig configures the Resend HTTP sender. APIKey is required;
|
||||
// FromEmail is the sending address on a domain you've verified in the
|
||||
// Resend dashboard.
|
||||
type ResendConfig struct {
|
||||
APIKey string
|
||||
FromEmail string
|
||||
FromName string
|
||||
|
||||
// HTTPClient overrides the default client (mostly so tests can point
|
||||
// at httptest.Server). Leave nil for production.
|
||||
HTTPClient *http.Client
|
||||
BaseURL string // overrideable for tests; defaults to https://api.resend.com
|
||||
}
|
||||
|
||||
// ResendEmailSender posts emails through https://api.resend.com/emails.
|
||||
// Implements auth.EmailSender + GuestEmailDispatcher.
|
||||
type ResendEmailSender struct {
|
||||
cfg ResendConfig
|
||||
tpls *Templates
|
||||
from string
|
||||
client *http.Client
|
||||
url string
|
||||
}
|
||||
|
||||
func NewResendEmailSender(cfg ResendConfig, tpls *Templates) (*ResendEmailSender, error) {
|
||||
if cfg.APIKey == "" {
|
||||
return nil, errors.New("resend: APIKey required")
|
||||
}
|
||||
if cfg.FromEmail == "" {
|
||||
return nil, errors.New("resend: FromEmail required")
|
||||
}
|
||||
from := cfg.FromEmail
|
||||
if cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
|
||||
}
|
||||
cli := cfg.HTTPClient
|
||||
if cli == nil {
|
||||
cli = &http.Client{Timeout: 15 * time.Second}
|
||||
}
|
||||
base := cfg.BaseURL
|
||||
if base == "" {
|
||||
base = "https://api.resend.com"
|
||||
}
|
||||
return &ResendEmailSender{cfg: cfg, tpls: tpls, from: from, client: cli, url: base + "/emails"}, nil
|
||||
}
|
||||
|
||||
// --- auth.EmailSender ---
|
||||
|
||||
func (s *ResendEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
|
||||
_, err := s.sendTemplated(ctx, to, "Verify your GuestGuard email",
|
||||
TmplVerification, map[string]any{"Name": name, "Link": link})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ResendEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
|
||||
_, err := s.sendTemplated(ctx, to, "Reset your GuestGuard password",
|
||||
TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"})
|
||||
return err
|
||||
}
|
||||
|
||||
// --- GuestEmailDispatcher ---
|
||||
|
||||
func (s *ResendEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
return s.sendTemplated(ctx, to, subject, name, data)
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
type resendRequest struct {
|
||||
From string `json:"from"`
|
||||
To []string `json:"to"`
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type resendResponse struct {
|
||||
ID string `json:"id"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ResendEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
html, text, err := s.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(resendRequest{
|
||||
From: s.from,
|
||||
To: []string{to},
|
||||
Subject: subject,
|
||||
HTML: html,
|
||||
Text: text,
|
||||
})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.cfg.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resend: do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
if resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("resend: status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
var parsed resendResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return "", fmt.Errorf("resend: parse: %w", err)
|
||||
}
|
||||
return parsed.ID, nil
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResendSenderHappyPath(t *testing.T) {
|
||||
tpls, err := NewTemplates()
|
||||
if err != nil {
|
||||
t.Fatalf("NewTemplates: %v", err)
|
||||
}
|
||||
|
||||
var gotPath, gotAuth, gotContentType string
|
||||
var gotBody map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotContentType = r.Header.Get("Content-Type")
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(b, &gotBody)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"id": "resend-abc-123"})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
s, err := NewResendEmailSender(ResendConfig{
|
||||
APIKey: "secret-key",
|
||||
FromEmail: "no-reply@example.test",
|
||||
FromName: "GuestGuard",
|
||||
BaseURL: srv.URL,
|
||||
}, tpls)
|
||||
if err != nil {
|
||||
t.Fatalf("NewResendEmailSender: %v", err)
|
||||
}
|
||||
|
||||
id, err := s.SendGuest(context.Background(), "to@example.test", "You're invited",
|
||||
TmplInvitation, map[string]any{
|
||||
"GuestName": "Mira", "HostName": "Kay", "EventName": "Beach Day",
|
||||
"Link": "https://example.test/rsvp/x",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendGuest: %v", err)
|
||||
}
|
||||
if id != "resend-abc-123" {
|
||||
t.Fatalf("provider id: got %q want %q", id, "resend-abc-123")
|
||||
}
|
||||
if gotPath != "/emails" {
|
||||
t.Errorf("path: got %q want /emails", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer secret-key" {
|
||||
t.Errorf("auth: got %q", gotAuth)
|
||||
}
|
||||
if gotContentType != "application/json" {
|
||||
t.Errorf("content-type: got %q", gotContentType)
|
||||
}
|
||||
if gotBody["from"] != "GuestGuard <no-reply@example.test>" {
|
||||
t.Errorf("from: got %v", gotBody["from"])
|
||||
}
|
||||
if !strings.Contains(gotBody["html"].(string), "Beach Day") {
|
||||
t.Errorf("html missing event name: %v", gotBody["html"])
|
||||
}
|
||||
if !strings.Contains(gotBody["text"].(string), "Mira") {
|
||||
t.Errorf("text missing guest name: %v", gotBody["text"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResendSenderErrorPropagates(t *testing.T) {
|
||||
tpls, _ := NewTemplates()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"message":"invalid api key"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
s, err := NewResendEmailSender(ResendConfig{
|
||||
APIKey: "bad", FromEmail: "x@y.test", BaseURL: srv.URL,
|
||||
}, tpls)
|
||||
if err != nil {
|
||||
t.Fatalf("NewResendEmailSender: %v", err)
|
||||
}
|
||||
if err := s.SendVerification(context.Background(), "z@y.test", "Z", "https://link"); err == nil {
|
||||
t.Fatal("expected error on 401")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Router dispatches an OutboundMessage to the channel-appropriate sender.
|
||||
// The notifier worker holds one Router and stays oblivious to whether the
|
||||
// concrete backend is SES, Twilio, or a logger stub.
|
||||
type Router struct {
|
||||
email Sender
|
||||
sms Sender
|
||||
}
|
||||
|
||||
func NewRouter(email, sms Sender) *Router {
|
||||
return &Router{email: email, sms: sms}
|
||||
}
|
||||
|
||||
func (r *Router) Send(ctx context.Context, msg OutboundMessage) (string, error) {
|
||||
switch msg.Channel {
|
||||
case ChannelEmail:
|
||||
if r.email == nil {
|
||||
return "", fmt.Errorf("no email sender configured")
|
||||
}
|
||||
return r.email.Send(ctx, msg)
|
||||
case ChannelSMS:
|
||||
if r.sms == nil {
|
||||
return "", fmt.Errorf("no sms sender configured")
|
||||
}
|
||||
return r.sms.Send(ctx, msg)
|
||||
}
|
||||
return "", fmt.Errorf("router: unknown channel %q", msg.Channel)
|
||||
}
|
||||
|
||||
// LogGuestEmailDispatcher is the dev-mode dispatcher that renders the
|
||||
// templated email and logs both bodies. Useful for local docker-compose
|
||||
// before SES is configured — gives engineers the rendered output without
|
||||
// needing a real inbox.
|
||||
type LogGuestEmailDispatcher struct {
|
||||
logger *slog.Logger
|
||||
tpls *Templates
|
||||
}
|
||||
|
||||
func NewLogGuestEmailDispatcher(logger *slog.Logger, tpls *Templates) *LogGuestEmailDispatcher {
|
||||
return &LogGuestEmailDispatcher{logger: logger, tpls: tpls}
|
||||
}
|
||||
|
||||
func (d *LogGuestEmailDispatcher) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
_, text, err := d.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := "log:" + uuid.NewString()
|
||||
d.logger.Info("guest email (stub)",
|
||||
"to", to,
|
||||
"subject", subject,
|
||||
"template", string(name),
|
||||
"provider_message_id", id,
|
||||
"text_body", text,
|
||||
)
|
||||
return id, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GuestEmailDispatcher is the abstraction the notifier uses to send a
|
||||
// templated email to a single guest. Both SES (production) and Log (dev)
|
||||
// satisfy it.
|
||||
type GuestEmailDispatcher interface {
|
||||
SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error)
|
||||
}
|
||||
|
||||
// EmailSender is the notification.Sender for ChannelEmail. It maps the
|
||||
// generic OutboundMessage shape used by the notifier worker onto our
|
||||
// templated SES / Log path, and honours the suppression list (a guest who
|
||||
// unsubscribed must never receive another email).
|
||||
type EmailSender struct {
|
||||
dispatcher GuestEmailDispatcher
|
||||
suppression *SuppressionRepo
|
||||
}
|
||||
|
||||
func NewEmailSender(d GuestEmailDispatcher, s *SuppressionRepo) *EmailSender {
|
||||
return &EmailSender{dispatcher: d, suppression: s}
|
||||
}
|
||||
|
||||
func (e *EmailSender) Send(ctx context.Context, msg OutboundMessage) (string, error) {
|
||||
if msg.GuestID == uuid.Nil {
|
||||
return "", errors.New("missing guest id")
|
||||
}
|
||||
if msg.Channel != ChannelEmail {
|
||||
return "", fmt.Errorf("EmailSender does not handle channel %q", msg.Channel)
|
||||
}
|
||||
to, _ := msg.Metadata["to"].(string)
|
||||
if to == "" {
|
||||
return "", errors.New("email recipient missing from metadata.to")
|
||||
}
|
||||
|
||||
if e.suppression != nil {
|
||||
yep, err := e.suppression.IsSuppressed(ctx, to)
|
||||
if err != nil {
|
||||
// On lookup error, fail safe by NOT sending — better than
|
||||
// emailing an unsubscribed address.
|
||||
return "", fmt.Errorf("suppression lookup: %w", err)
|
||||
}
|
||||
if yep {
|
||||
return "suppressed:" + to, nil
|
||||
}
|
||||
}
|
||||
|
||||
tmpl := templateForType(msg.Type)
|
||||
if tmpl == "" {
|
||||
return "", fmt.Errorf("no template for type %q", msg.Type)
|
||||
}
|
||||
subject := msg.Subject
|
||||
if subject == "" {
|
||||
subject = defaultSubject(msg.Type, msg.Metadata)
|
||||
}
|
||||
data := map[string]any{}
|
||||
for k, v := range msg.Metadata {
|
||||
data[k] = v
|
||||
}
|
||||
return e.dispatcher.SendGuest(ctx, to, subject, tmpl, data)
|
||||
}
|
||||
|
||||
func templateForType(t Type) TemplateName {
|
||||
switch t {
|
||||
case TypeInvitation:
|
||||
return TmplInvitation
|
||||
case TypeConfirmation:
|
||||
return TmplConfirmation
|
||||
case TypeReminder:
|
||||
return TmplReminder
|
||||
case TypeVerification:
|
||||
return TmplVerification
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func defaultSubject(t Type, meta map[string]any) string {
|
||||
event, _ := meta["EventName"].(string)
|
||||
switch t {
|
||||
case TypeInvitation:
|
||||
if event != "" {
|
||||
return "You're invited — " + event
|
||||
}
|
||||
return "You're invited"
|
||||
case TypeConfirmation:
|
||||
if event != "" {
|
||||
return "RSVP confirmed — " + event
|
||||
}
|
||||
return "RSVP confirmed"
|
||||
case TypeReminder:
|
||||
if event != "" {
|
||||
return "Reminder: " + event
|
||||
}
|
||||
return "Reminder"
|
||||
}
|
||||
return "GuestGuard"
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/twilio/twilio-go"
|
||||
openapi "github.com/twilio/twilio-go/rest/api/v2010"
|
||||
)
|
||||
|
||||
// TwilioConfig configures the SMS sender. AccountSID + AuthToken pair
|
||||
// authenticate the REST client; FromNumber is the verified Twilio number.
|
||||
type TwilioConfig struct {
|
||||
AccountSID string
|
||||
AuthToken string
|
||||
FromNumber string
|
||||
MaxAttempts int
|
||||
}
|
||||
|
||||
// TwilioSender implements notification.Sender for ChannelSMS. Retries on
|
||||
// transient errors with exponential backoff (1s, 5s, 30s, 5m, 30m). Twilio
|
||||
// surfaces a numeric ErrorCode for permanent failures (e.g. 21610
|
||||
// unsubscribed, 21408 disabled region) — those return immediately.
|
||||
type TwilioSender struct {
|
||||
client *twilio.RestClient
|
||||
from string
|
||||
maxAttempt int
|
||||
}
|
||||
|
||||
func NewTwilioSender(cfg TwilioConfig) (*TwilioSender, error) {
|
||||
if cfg.AccountSID == "" || cfg.AuthToken == "" || cfg.FromNumber == "" {
|
||||
return nil, errors.New("twilio: AccountSID / AuthToken / FromNumber are required")
|
||||
}
|
||||
cli := twilio.NewRestClientWithParams(twilio.ClientParams{
|
||||
Username: cfg.AccountSID,
|
||||
Password: cfg.AuthToken,
|
||||
})
|
||||
max := cfg.MaxAttempts
|
||||
if max <= 0 {
|
||||
max = 5
|
||||
}
|
||||
return &TwilioSender{client: cli, from: cfg.FromNumber, maxAttempt: max}, nil
|
||||
}
|
||||
|
||||
func (t *TwilioSender) Send(ctx context.Context, msg OutboundMessage) (string, error) {
|
||||
if msg.GuestID == uuid.Nil {
|
||||
return "", errors.New("missing guest id")
|
||||
}
|
||||
if msg.Channel != ChannelSMS {
|
||||
return "", fmt.Errorf("TwilioSender does not handle channel %q", msg.Channel)
|
||||
}
|
||||
to, _ := msg.Metadata["phone"].(string)
|
||||
if to == "" {
|
||||
return "", errors.New("sms recipient missing from metadata.phone")
|
||||
}
|
||||
body := msg.Body
|
||||
if body == "" {
|
||||
body = msg.Subject
|
||||
}
|
||||
if body == "" {
|
||||
return "", errors.New("sms body is empty")
|
||||
}
|
||||
|
||||
params := &openapi.CreateMessageParams{}
|
||||
params.SetTo(to)
|
||||
params.SetFrom(t.from)
|
||||
params.SetBody(body)
|
||||
|
||||
backoff := []time.Duration{0, time.Second, 5 * time.Second, 30 * time.Second, 5 * time.Minute, 30 * time.Minute}
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < t.maxAttempt; attempt++ {
|
||||
if attempt < len(backoff) && backoff[attempt] > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(backoff[attempt]):
|
||||
}
|
||||
}
|
||||
resp, err := t.client.Api.CreateMessage(params)
|
||||
if err == nil && resp != nil && resp.Sid != nil {
|
||||
return *resp.Sid, nil
|
||||
}
|
||||
lastErr = err
|
||||
if !isTwilioRetryable(err) {
|
||||
return "", fmt.Errorf("twilio: send: %w", err)
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("twilio: send (after %d attempts): %w", t.maxAttempt, lastErr)
|
||||
}
|
||||
|
||||
// isTwilioRetryable returns true for transient failures (network, 5xx).
|
||||
// Twilio's permanent error codes (21xxx range) are not retried.
|
||||
func isTwilioRetryable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
// Cheap heuristic: permanent codes are in the 21xxx range; everything
|
||||
// else (timeouts, 503s, DNS hiccups) is fair game.
|
||||
if strings.Contains(msg, "21610") || strings.Contains(msg, "21408") || strings.Contains(msg, "21211") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,238 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SMTPConfig describes one SMTP relay. Username/Password are optional —
|
||||
// local relays like Mailpit accept anonymous SMTP. TLS modes:
|
||||
//
|
||||
// - "starttls" upgrade after EHLO (most relays, default)
|
||||
// - "implicit" TLS handshake before SMTP (port 465)
|
||||
// - "none" plain socket — only acceptable on a trusted LAN
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromEmail string
|
||||
FromName string
|
||||
TLS string
|
||||
}
|
||||
|
||||
// SMTPEmailSender implements auth.EmailSender (verification/reset) AND
|
||||
// GuestEmailDispatcher (invitation/confirmation/reminder) on top of any
|
||||
// SMTP relay. Used for Mailpit in dev; works against Gmail, Fastmail, etc.
|
||||
// in production if the user prefers plain SMTP over an HTTP API.
|
||||
type SMTPEmailSender struct {
|
||||
cfg SMTPConfig
|
||||
tpls *Templates
|
||||
from string
|
||||
}
|
||||
|
||||
func NewSMTPEmailSender(cfg SMTPConfig, tpls *Templates) (*SMTPEmailSender, error) {
|
||||
if cfg.Host == "" {
|
||||
return nil, errors.New("smtp: Host required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
cfg.Port = 587
|
||||
}
|
||||
if cfg.FromEmail == "" {
|
||||
return nil, errors.New("smtp: FromEmail required")
|
||||
}
|
||||
if cfg.TLS == "" {
|
||||
cfg.TLS = "starttls"
|
||||
}
|
||||
from := cfg.FromEmail
|
||||
if cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
|
||||
}
|
||||
return &SMTPEmailSender{cfg: cfg, tpls: tpls, from: from}, nil
|
||||
}
|
||||
|
||||
// --- auth.EmailSender ---
|
||||
|
||||
func (s *SMTPEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Verify your GuestGuard email",
|
||||
TmplVerification, map[string]any{"Name": name, "Link": link})
|
||||
}
|
||||
|
||||
func (s *SMTPEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Reset your GuestGuard password",
|
||||
TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"})
|
||||
}
|
||||
|
||||
// --- GuestEmailDispatcher ---
|
||||
|
||||
func (s *SMTPEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
return s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func (s *SMTPEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error {
|
||||
_, err := s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SMTPEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
html, text, err := s.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
msgID := generateMessageID(s.cfg.FromEmail)
|
||||
body := buildMIMEMessage(mimeMessage{
|
||||
MessageID: msgID,
|
||||
From: s.from,
|
||||
To: to,
|
||||
Subject: subject,
|
||||
Text: text,
|
||||
HTML: html,
|
||||
})
|
||||
if err := s.dial(ctx, []string{to}, body); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return msgID, nil
|
||||
}
|
||||
|
||||
func (s *SMTPEmailSender) dial(ctx context.Context, to []string, body []byte) error {
|
||||
addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(s.cfg.Port))
|
||||
d := &net.Dialer{Timeout: 10 * time.Second}
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
d.Deadline = deadline
|
||||
}
|
||||
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
)
|
||||
switch strings.ToLower(s.cfg.TLS) {
|
||||
case "implicit":
|
||||
conn, err = tls.DialWithDialer(d, "tcp", addr, &tls.Config{ServerName: s.cfg.Host})
|
||||
default:
|
||||
conn, err = d.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp: dial: %w", err)
|
||||
}
|
||||
|
||||
c, err := smtp.NewClient(conn, s.cfg.Host)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("smtp: new client: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if strings.ToLower(s.cfg.TLS) == "starttls" {
|
||||
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||
if err := c.StartTLS(&tls.Config{ServerName: s.cfg.Host}); err != nil {
|
||||
return fmt.Errorf("smtp: starttls: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.cfg.Username != "" {
|
||||
auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host)
|
||||
if err := c.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp: auth: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Mail(s.cfg.FromEmail); err != nil {
|
||||
return fmt.Errorf("smtp: MAIL FROM: %w", err)
|
||||
}
|
||||
for _, rcpt := range to {
|
||||
if err := c.Rcpt(rcpt); err != nil {
|
||||
return fmt.Errorf("smtp: RCPT TO %s: %w", rcpt, err)
|
||||
}
|
||||
}
|
||||
w, err := c.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp: DATA: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
_ = w.Close()
|
||||
return fmt.Errorf("smtp: write body: %w", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return fmt.Errorf("smtp: close body: %w", err)
|
||||
}
|
||||
return c.Quit()
|
||||
}
|
||||
|
||||
type mimeMessage struct {
|
||||
MessageID string
|
||||
From string
|
||||
To string
|
||||
Subject string
|
||||
Text string
|
||||
HTML string
|
||||
}
|
||||
|
||||
// buildMIMEMessage assembles an RFC 5322 message with a multipart/alternative
|
||||
// body so receiving clients pick HTML when they can and fall back to text
|
||||
// otherwise.
|
||||
func buildMIMEMessage(m mimeMessage) []byte {
|
||||
boundary := randomBoundary()
|
||||
var b strings.Builder
|
||||
b.WriteString("Message-ID: <" + m.MessageID + ">\r\n")
|
||||
b.WriteString("Date: " + time.Now().UTC().Format(time.RFC1123Z) + "\r\n")
|
||||
b.WriteString("From: " + m.From + "\r\n")
|
||||
b.WriteString("To: " + m.To + "\r\n")
|
||||
b.WriteString("Subject: " + m.Subject + "\r\n")
|
||||
b.WriteString("MIME-Version: 1.0\r\n")
|
||||
b.WriteString("Content-Type: multipart/alternative; boundary=" + boundary + "\r\n")
|
||||
b.WriteString("\r\n")
|
||||
|
||||
b.WriteString("--" + boundary + "\r\n")
|
||||
b.WriteString("Content-Type: text/plain; charset=UTF-8\r\n")
|
||||
b.WriteString("Content-Transfer-Encoding: 8bit\r\n")
|
||||
b.WriteString("\r\n")
|
||||
b.WriteString(m.Text)
|
||||
if !strings.HasSuffix(m.Text, "\n") {
|
||||
b.WriteString("\r\n")
|
||||
}
|
||||
|
||||
b.WriteString("--" + boundary + "\r\n")
|
||||
b.WriteString("Content-Type: text/html; charset=UTF-8\r\n")
|
||||
b.WriteString("Content-Transfer-Encoding: 8bit\r\n")
|
||||
b.WriteString("\r\n")
|
||||
b.WriteString(m.HTML)
|
||||
if !strings.HasSuffix(m.HTML, "\n") {
|
||||
b.WriteString("\r\n")
|
||||
}
|
||||
|
||||
b.WriteString("--" + boundary + "--\r\n")
|
||||
return []byte(b.String())
|
||||
}
|
||||
|
||||
func randomBoundary() string {
|
||||
var buf [16]byte
|
||||
_, _ = rand.Read(buf[:])
|
||||
return "gg=" + hex.EncodeToString(buf[:])
|
||||
}
|
||||
|
||||
func generateMessageID(from string) string {
|
||||
var buf [12]byte
|
||||
_, _ = rand.Read(buf[:])
|
||||
domain := "guestguard.local"
|
||||
if at := strings.LastIndexByte(from, '@'); at >= 0 && at+1 < len(from) {
|
||||
domain = from[at+1:]
|
||||
}
|
||||
return hex.EncodeToString(buf[:]) + "@" + domain
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildMIMEMessageStructure(t *testing.T) {
|
||||
body := buildMIMEMessage(mimeMessage{
|
||||
MessageID: "abc@example.test",
|
||||
From: "GuestGuard <no-reply@example.test>",
|
||||
To: "to@example.test",
|
||||
Subject: "Verify your GuestGuard email",
|
||||
Text: "Hi Mira, please verify.",
|
||||
HTML: "<p>Hi Mira, please verify.</p>",
|
||||
})
|
||||
s := string(body)
|
||||
checks := []string{
|
||||
"Message-ID: <abc@example.test>",
|
||||
"From: GuestGuard <no-reply@example.test>",
|
||||
"To: to@example.test",
|
||||
"Subject: Verify your GuestGuard email",
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: multipart/alternative; boundary=",
|
||||
"Content-Type: text/plain; charset=UTF-8",
|
||||
"Content-Type: text/html; charset=UTF-8",
|
||||
"Hi Mira, please verify.",
|
||||
"<p>Hi Mira, please verify.</p>",
|
||||
}
|
||||
for _, want := range checks {
|
||||
if !strings.Contains(s, want) {
|
||||
t.Errorf("MIME body missing %q\n---\n%s", want, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageIDIncludesDomain(t *testing.T) {
|
||||
id := generateMessageID("no-reply@example.test")
|
||||
if !strings.HasSuffix(id, "@example.test") {
|
||||
t.Fatalf("message id has wrong domain: %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageIDFallback(t *testing.T) {
|
||||
id := generateMessageID("not-an-email")
|
||||
if !strings.HasSuffix(id, "@guestguard.local") {
|
||||
t.Fatalf("expected fallback domain: %s", id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// SuppressionSource categorises why an address landed on the suppression
|
||||
// list. Bounces + complaints come from provider webhooks; "user" is set
|
||||
// when a guest clicks an unsubscribe link.
|
||||
type SuppressionSource string
|
||||
|
||||
const (
|
||||
SuppressionBounce SuppressionSource = "bounce"
|
||||
SuppressionComplaint SuppressionSource = "complaint"
|
||||
SuppressionManual SuppressionSource = "manual"
|
||||
SuppressionUser SuppressionSource = "user"
|
||||
)
|
||||
|
||||
// SuppressionRepo manages the unsubscribes table — a flat suppression list
|
||||
// of email addresses that must never receive another email regardless of
|
||||
// notification type.
|
||||
type SuppressionRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewSuppressionRepo(db *storage.DB) *SuppressionRepo {
|
||||
return &SuppressionRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
// IsSuppressed returns true if `email` is on the suppression list.
|
||||
// Empty / unparseable addresses are treated as not-suppressed; the caller
|
||||
// is expected to validate before sending.
|
||||
func (r *SuppressionRepo) IsSuppressed(ctx context.Context, email string) (bool, error) {
|
||||
email = normaliseEmail(email)
|
||||
if email == "" {
|
||||
return false, nil
|
||||
}
|
||||
var exists bool
|
||||
err := r.pool.QueryRow(ctx,
|
||||
`SELECT EXISTS (SELECT 1 FROM unsubscribes WHERE email = $1)`,
|
||||
email,
|
||||
).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// Add records `email` on the suppression list. Idempotent — repeated calls
|
||||
// keep the earliest entry's timestamp.
|
||||
func (r *SuppressionRepo) Add(ctx context.Context, email, reason string, src SuppressionSource) error {
|
||||
email = normaliseEmail(email)
|
||||
if email == "" {
|
||||
return errors.New("empty email")
|
||||
}
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO unsubscribes (email, reason, source)
|
||||
VALUES ($1, NULLIF($2, ''), $3)
|
||||
ON CONFLICT (email) DO NOTHING
|
||||
`, email, reason, string(src))
|
||||
return err
|
||||
}
|
||||
|
||||
// Get returns the suppression record for `email`, or pgx.ErrNoRows if not
|
||||
// found. Mostly used by tests / admin tooling.
|
||||
func (r *SuppressionRepo) Get(ctx context.Context, email string) (string, SuppressionSource, error) {
|
||||
var reason *string
|
||||
var source string
|
||||
err := r.pool.QueryRow(ctx,
|
||||
`SELECT reason, source FROM unsubscribes WHERE email = $1`,
|
||||
normaliseEmail(email),
|
||||
).Scan(&reason, &source)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", err
|
||||
}
|
||||
return "", "", err
|
||||
}
|
||||
r2 := ""
|
||||
if reason != nil {
|
||||
r2 = *reason
|
||||
}
|
||||
return r2, SuppressionSource(source), nil
|
||||
}
|
||||
|
||||
func normaliseEmail(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
htmltemplate "html/template"
|
||||
"io/fs"
|
||||
"path"
|
||||
"strings"
|
||||
texttemplate "text/template"
|
||||
)
|
||||
|
||||
//go:embed templates
|
||||
var templatesFS embed.FS
|
||||
|
||||
// TemplateName names one of the email templates. There must be a
|
||||
// matching `<name>.html` and `<name>.txt` under templates/.
|
||||
type TemplateName string
|
||||
|
||||
const (
|
||||
TmplVerification TemplateName = "verification"
|
||||
TmplPasswordReset TemplateName = "reset"
|
||||
TmplInvitation TemplateName = "invitation"
|
||||
TmplConfirmation TemplateName = "confirmation"
|
||||
TmplReminder TemplateName = "reminder"
|
||||
)
|
||||
|
||||
// Templates renders branded transactional emails for both HTML and
|
||||
// plaintext bodies. Loaded once at construction; thread-safe afterwards.
|
||||
//
|
||||
// Each page-level HTML template gets its own *html/template.Template
|
||||
// holding a copy of the `base.html` partial. This avoids html/template's
|
||||
// per-template contextual-escape pass interfering between pages that all
|
||||
// define a sibling named "body".
|
||||
type Templates struct {
|
||||
html map[string]*htmltemplate.Template // page-name (no ext) -> root
|
||||
text *texttemplate.Template
|
||||
}
|
||||
|
||||
func NewTemplates() (*Templates, error) {
|
||||
root, err := fs.Sub(templatesFS, "templates")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("templates fs: %w", err)
|
||||
}
|
||||
|
||||
baseHTML, err := fs.ReadFile(root, "base.html")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read base.html: %w", err)
|
||||
}
|
||||
|
||||
out := &Templates{
|
||||
html: make(map[string]*htmltemplate.Template),
|
||||
text: texttemplate.New("__root__"),
|
||||
}
|
||||
|
||||
walk := func(p string, d fs.DirEntry, _ error) error {
|
||||
if d == nil || d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
base := path.Base(p)
|
||||
if base == "base.html" {
|
||||
return nil // partial — folded into each page template below
|
||||
}
|
||||
b, err := fs.ReadFile(root, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case strings.HasSuffix(p, ".html"):
|
||||
name := strings.TrimSuffix(base, ".html")
|
||||
t := htmltemplate.New(name)
|
||||
if _, err := t.Parse(string(baseHTML)); err != nil {
|
||||
return fmt.Errorf("parse _base for %s: %w", p, err)
|
||||
}
|
||||
if _, err := t.Parse(string(b)); err != nil {
|
||||
return fmt.Errorf("parse %s: %w", p, err)
|
||||
}
|
||||
out.html[name] = t
|
||||
case strings.HasSuffix(p, ".txt"):
|
||||
if _, err := out.text.New(base).Parse(string(b)); err != nil {
|
||||
return fmt.Errorf("parse %s: %w", p, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := fs.WalkDir(root, ".", walk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Render returns (htmlBody, textBody) for the named template using data.
|
||||
func (t *Templates) Render(name TemplateName, data map[string]any) (htmlBody, textBody string, err error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
if _, ok := data["Subject"]; !ok {
|
||||
data["Subject"] = "GuestGuard"
|
||||
}
|
||||
|
||||
htmlTpl, ok := t.html[string(name)]
|
||||
if !ok {
|
||||
return "", "", fmt.Errorf("unknown html template %q", name)
|
||||
}
|
||||
|
||||
var hBuf, tBuf bytes.Buffer
|
||||
// Each page-root template's entry point is "base" (defined by base.html).
|
||||
if err := htmlTpl.ExecuteTemplate(&hBuf, "base", data); err != nil {
|
||||
return "", "", fmt.Errorf("render html %s: %w", name, err)
|
||||
}
|
||||
if err := t.text.ExecuteTemplate(&tBuf, string(name)+".txt", data); err != nil {
|
||||
return "", "", fmt.Errorf("render text %s: %w", name, err)
|
||||
}
|
||||
return hBuf.String(), tBuf.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
{{define "base"}}<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>{{.Subject}}</title>
|
||||
</head>
|
||||
<body style="margin:0;padding:0;background:#f7f7f8;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,Helvetica,Arial,sans-serif;color:#0f172a;">
|
||||
<table role="presentation" width="100%" cellpadding="0" cellspacing="0" style="background:#f7f7f8;padding:32px 16px;">
|
||||
<tr><td align="center">
|
||||
<table role="presentation" width="560" cellpadding="0" cellspacing="0" style="background:#ffffff;border-radius:12px;overflow:hidden;box-shadow:0 1px 2px rgba(15,23,42,0.05);">
|
||||
<tr>
|
||||
<td style="background:#0a0a0a;padding:20px 28px;">
|
||||
<span style="display:inline-block;width:10px;height:10px;background:#22c55e;border-radius:50%;vertical-align:middle;"></span>
|
||||
<span style="color:#fafafa;font-weight:600;font-size:16px;margin-left:8px;vertical-align:middle;">GuestGuard</span>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:32px 28px 24px;font-size:15px;line-height:1.55;color:#0f172a;">
|
||||
{{block "body" .}}{{end}}
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:0 28px 28px;">
|
||||
<hr style="border:none;border-top:1px solid #e5e7eb;margin:0 0 16px;">
|
||||
<p style="font-size:12px;color:#64748b;margin:0;">
|
||||
You're receiving this because of activity on your GuestGuard account.
|
||||
{{if .UnsubscribeLink}}<br>If you'd rather not get emails like this, <a href="{{.UnsubscribeLink}}" style="color:#16a34a;">unsubscribe here</a>.{{end}}
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td></tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>{{end}}
|
||||
@@ -0,0 +1,11 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">RSVP received</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.GuestName}},</p>
|
||||
<p style="margin:0 0 16px;">Thanks for letting {{.HostName}} know — your RSVP for <strong>{{.EventName}}</strong> is confirmed as <strong>{{.Response}}</strong>{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.</p>
|
||||
{{if or .Venue .EventDate}}
|
||||
<p style="margin:0 0 16px;color:#64748b;font-size:14px;">
|
||||
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
|
||||
</p>
|
||||
{{end}}
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">You'll get a reminder closer to the date. If your plans change, use the same invitation link to update your reply.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,10 @@
|
||||
Hi {{.GuestName}},
|
||||
|
||||
Your RSVP for "{{.EventName}}" is confirmed as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.
|
||||
|
||||
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
|
||||
|
||||
You'll get a reminder closer to the date. If your plans change, use the
|
||||
same invitation link to update your reply.
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,15 @@
|
||||
{{define "body"}}
|
||||
<p style="font-size:13px;letter-spacing:0.2em;text-transform:uppercase;color:#16a34a;margin:0 0 16px;">✦ You're invited</p>
|
||||
<h1 style="font-size:24px;margin:0 0 4px;color:#0a0a0a;">{{.EventName}}</h1>
|
||||
{{if or .Venue .EventDate}}
|
||||
<p style="margin:0 0 24px;color:#64748b;font-size:14px;">
|
||||
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
|
||||
</p>
|
||||
{{end}}
|
||||
<p style="margin:0 0 20px;">Hi {{.GuestName}}, {{.HostName}} would love to know if you can make it. Use the personal link below to RSVP.</p>
|
||||
<p style="margin:0 0 24px;text-align:center;">
|
||||
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">RSVP now</a>
|
||||
</p>
|
||||
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
|
||||
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,11 @@
|
||||
You're invited — {{.EventName}}
|
||||
|
||||
Hi {{.GuestName}},
|
||||
|
||||
{{.HostName}} would love to know if you can make it{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.
|
||||
|
||||
RSVP here:
|
||||
|
||||
{{.Link}}
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,7 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">Reminder: {{.EventName}}</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.GuestName}},</p>
|
||||
<p style="margin:0 0 16px;">Just a quick reminder that <strong>{{.EventName}}</strong> is coming up{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.</p>
|
||||
{{if .Response}}<p style="margin:0 0 16px;">You're down as <strong>{{.Response}}</strong>{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.</p>{{end}}
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">Need to change your plans? Use your invitation link.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,7 @@
|
||||
Hi {{.GuestName}},
|
||||
|
||||
Reminder: {{.EventName}}{{if .EventDate}} — {{.EventDate}}{{end}}{{if .Venue}} at {{.Venue}}{{end}}.
|
||||
|
||||
{{if .Response}}You're down as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.{{end}}
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,11 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">Reset your password</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.Name}},</p>
|
||||
<p style="margin:0 0 20px;">We received a request to reset the password on your GuestGuard account. Tap the button to choose a new one — the link is valid for {{.ExpiryHumane}}.</p>
|
||||
<p style="margin:0 0 24px;text-align:center;">
|
||||
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">Choose a new password</a>
|
||||
</p>
|
||||
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
|
||||
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">If you didn't ask to reset your password, you can ignore this email — your current password is unchanged.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,11 @@
|
||||
Hi {{.Name}},
|
||||
|
||||
We received a request to reset your GuestGuard password. The link is valid
|
||||
for {{.ExpiryHumane}}:
|
||||
|
||||
{{.Link}}
|
||||
|
||||
If you didn't ask to reset your password, you can ignore this email — your
|
||||
current password is unchanged.
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,11 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">Verify your email</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.Name}}, welcome to GuestGuard.</p>
|
||||
<p style="margin:0 0 20px;">To finish setting up your account, please confirm this is your email address.</p>
|
||||
<p style="margin:0 0 24px;text-align:center;">
|
||||
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">Verify email</a>
|
||||
</p>
|
||||
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
|
||||
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">If you didn't sign up for GuestGuard, you can ignore this email.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,9 @@
|
||||
Hi {{.Name}}, welcome to GuestGuard.
|
||||
|
||||
Please verify your email by visiting:
|
||||
|
||||
{{.Link}}
|
||||
|
||||
If you didn't sign up for GuestGuard, you can ignore this email.
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,101 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRenderAllTemplates(t *testing.T) {
|
||||
tpls, err := NewTemplates()
|
||||
if err != nil {
|
||||
t.Fatalf("NewTemplates: %v", err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name TemplateName
|
||||
data map[string]any
|
||||
wantHTML []string // substrings expected in HTML body
|
||||
wantText []string // substrings expected in text body
|
||||
}{
|
||||
{
|
||||
name: TmplVerification,
|
||||
data: map[string]any{
|
||||
"Name": "Kay",
|
||||
"Link": "https://example.test/verify-email?token=x",
|
||||
"Subject": "Verify your GuestGuard email",
|
||||
"UnsubscribeLink": "https://example.test/unsubscribe/abc",
|
||||
},
|
||||
wantHTML: []string{"Verify your email", "Kay", "Verify email", "https://example.test/verify-email?token=x", "unsubscribe here"},
|
||||
wantText: []string{"Hi Kay", "https://example.test/verify-email?token=x"},
|
||||
},
|
||||
{
|
||||
name: TmplPasswordReset,
|
||||
data: map[string]any{
|
||||
"Name": "Kay",
|
||||
"Link": "https://example.test/reset-password/abc",
|
||||
"ExpiryHumane": "1 hour",
|
||||
},
|
||||
wantHTML: []string{"Reset your password", "1 hour", "https://example.test/reset-password/abc"},
|
||||
wantText: []string{"reset your GuestGuard password", "1 hour"},
|
||||
},
|
||||
{
|
||||
name: TmplInvitation,
|
||||
data: map[string]any{
|
||||
"GuestName": "Mira",
|
||||
"HostName": "Kay",
|
||||
"EventName": "Beach Day",
|
||||
"Venue": "Ocean Park",
|
||||
"EventDate": "Sat 14 Jun, 4pm",
|
||||
"Link": "https://example.test/rsvp/tok_x",
|
||||
},
|
||||
wantHTML: []string{"You're invited", "Beach Day", "Mira", "Ocean Park", "RSVP now", "https://example.test/rsvp/tok_x"},
|
||||
wantText: []string{"Beach Day", "Mira", "RSVP here", "https://example.test/rsvp/tok_x"},
|
||||
},
|
||||
{
|
||||
name: TmplConfirmation,
|
||||
data: map[string]any{
|
||||
"GuestName": "Mira",
|
||||
"HostName": "Kay",
|
||||
"EventName": "Beach Day",
|
||||
"Venue": "Ocean Park",
|
||||
"EventDate": "Sat 14 Jun, 4pm",
|
||||
"Response": "attending",
|
||||
"PlusOnes": 2,
|
||||
},
|
||||
wantHTML: []string{"RSVP received", "Beach Day", "attending", "2 plus-ones"},
|
||||
wantText: []string{"Beach Day", "attending", "+2"},
|
||||
},
|
||||
{
|
||||
name: TmplReminder,
|
||||
data: map[string]any{
|
||||
"GuestName": "Mira",
|
||||
"EventName": "Beach Day",
|
||||
"EventDate": "Sat 14 Jun, 4pm",
|
||||
"Venue": "Ocean Park",
|
||||
"Response": "attending",
|
||||
"PlusOnes": 1,
|
||||
},
|
||||
wantHTML: []string{"Reminder", "Beach Day", "Ocean Park", "1 plus-one"},
|
||||
wantText: []string{"Reminder", "Beach Day", "(+1)"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(string(tc.name), func(t *testing.T) {
|
||||
html, text, err := tpls.Render(tc.name, tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("render: %v", err)
|
||||
}
|
||||
for _, s := range tc.wantHTML {
|
||||
if !strings.Contains(html, s) {
|
||||
t.Errorf("html missing %q\n---\n%s", s, html)
|
||||
}
|
||||
}
|
||||
for _, s := range tc.wantText {
|
||||
if !strings.Contains(text, s) {
|
||||
t.Errorf("text missing %q\n---\n%s", s, text)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UnsubscribeSigner mints + verifies tamper-proof unsubscribe tokens. Each
|
||||
// token encodes the email address and an HMAC-SHA256 of the address under
|
||||
// a server-side secret. The token has no TTL — unsubscribe links should
|
||||
// keep working forever.
|
||||
//
|
||||
// Token shape: base64url(email) + "." + base64url(hmac)
|
||||
type UnsubscribeSigner struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func NewUnsubscribeSigner(secret string) *UnsubscribeSigner {
|
||||
return &UnsubscribeSigner{secret: []byte(secret)}
|
||||
}
|
||||
|
||||
// Sign returns a URL-safe token for the email address. Empty input → empty
|
||||
// token (caller should validate input first).
|
||||
func (s *UnsubscribeSigner) Sign(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
email = normaliseEmail(email)
|
||||
mac := hmac.New(sha256.New, s.secret)
|
||||
mac.Write([]byte(email))
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(email)) +
|
||||
"." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// Verify decodes the token and confirms the HMAC matches, returning the
|
||||
// owning email address.
|
||||
func (s *UnsubscribeSigner) Verify(token string) (string, error) {
|
||||
dot := strings.IndexByte(token, '.')
|
||||
if dot < 0 {
|
||||
return "", errors.New("malformed unsubscribe token")
|
||||
}
|
||||
emailB, err := base64.RawURLEncoding.DecodeString(token[:dot])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sigB, err := base64.RawURLEncoding.DecodeString(token[dot+1:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
mac := hmac.New(sha256.New, s.secret)
|
||||
mac.Write(emailB)
|
||||
want := mac.Sum(nil)
|
||||
if !hmac.Equal(sigB, want) {
|
||||
return "", errors.New("unsubscribe signature mismatch")
|
||||
}
|
||||
return string(emailB), nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KeyFunc derives the rate-limit key from a request — e.g. IP, IP+email,
|
||||
// authenticated user id, path token, etc. An empty return value bypasses
|
||||
// the limiter (handy when a key isn't available yet, like an email that
|
||||
// hasn't been parsed out of the body).
|
||||
type KeyFunc func(r *http.Request) string
|
||||
|
||||
// Rule names a single sliding-window budget.
|
||||
type Rule struct {
|
||||
Name string
|
||||
Limit int
|
||||
Window time.Duration
|
||||
}
|
||||
|
||||
// Middleware returns an http middleware that applies `rule` to incoming
|
||||
// requests using `keyFn`. On limit, it writes 429 with a Retry-After header
|
||||
// and a JSON body. If the limiter itself errors, requests are allowed
|
||||
// (fail-open) — degraded protection is better than total outage.
|
||||
func (l *Limiter) Middleware(rule Rule, keyFn KeyFunc, logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := keyFn(r)
|
||||
if key == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
res, err := l.Allow(r.Context(), rule.Name, key, rule.Limit, rule.Window)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("ratelimit error (failing open)",
|
||||
"rule", rule.Name, "err", err)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !res.Allowed {
|
||||
retrySecs := int(res.RetryAfter.Round(time.Second).Seconds())
|
||||
if retrySecs < 1 {
|
||||
retrySecs = 1
|
||||
}
|
||||
w.Header().Set("Retry-After", strconv.Itoa(retrySecs))
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": "rate limit exceeded",
|
||||
"retry_after": retrySecs,
|
||||
})
|
||||
if logger != nil {
|
||||
logger.Info("ratelimit blocked",
|
||||
"rule", rule.Name,
|
||||
"key", key,
|
||||
"retry_after", retrySecs,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
// Package ratelimit implements a sliding-window rate limiter backed by
|
||||
// Redis sorted sets. Each call is atomic via a Lua script: it sweeps
|
||||
// entries older than the window, returns the current count, and (when
|
||||
// under the limit) records the new hit. Block C's "INCR + EXPIRE or a Lua
|
||||
// script for atomicity" requirement.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Result is the outcome of one Allow check.
|
||||
type Result struct {
|
||||
Allowed bool
|
||||
Count int // current count within the window (post-increment when allowed)
|
||||
Limit int
|
||||
RetryAfter time.Duration // populated when Allowed=false
|
||||
}
|
||||
|
||||
// Limiter checks rate-limit budgets against Redis.
|
||||
type Limiter struct {
|
||||
client *redis.Client
|
||||
script *redis.Script
|
||||
prefix string
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// New builds a Limiter against the given Redis client. The prefix namespaces
|
||||
// all keys (defaults to "rl").
|
||||
func New(client *redis.Client, prefix string) *Limiter {
|
||||
if prefix == "" {
|
||||
prefix = "rl"
|
||||
}
|
||||
return &Limiter{
|
||||
client: client,
|
||||
script: redis.NewScript(slidingWindowScript),
|
||||
prefix: prefix,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow consumes one unit of budget under (name, key) against `limit` events
|
||||
// per `window`. Returns Allowed=true and the new count, or Allowed=false
|
||||
// with RetryAfter set to roughly the duration until the oldest hit ages out.
|
||||
func (l *Limiter) Allow(ctx context.Context, name, key string, limit int, window time.Duration) (Result, error) {
|
||||
if limit <= 0 {
|
||||
return Result{}, errors.New("ratelimit: limit must be positive")
|
||||
}
|
||||
if window <= 0 {
|
||||
return Result{}, errors.New("ratelimit: window must be positive")
|
||||
}
|
||||
member, err := randomMember()
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
now := l.now().UnixMilli()
|
||||
windowMS := window.Milliseconds()
|
||||
redisKey := fmt.Sprintf("%s:%s:%s", l.prefix, name, key)
|
||||
|
||||
out, err := l.script.Run(ctx, l.client,
|
||||
[]string{redisKey},
|
||||
now, windowMS, limit, member,
|
||||
).Int64Slice()
|
||||
if err != nil {
|
||||
return Result{}, fmt.Errorf("ratelimit: redis: %w", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
return Result{}, fmt.Errorf("ratelimit: bad lua reply: %v", out)
|
||||
}
|
||||
|
||||
r := Result{
|
||||
Count: int(out[1]),
|
||||
Limit: limit,
|
||||
}
|
||||
if out[0] == 0 {
|
||||
r.Allowed = true
|
||||
} else {
|
||||
r.RetryAfter = time.Duration(out[2]) * time.Millisecond
|
||||
if r.RetryAfter <= 0 {
|
||||
r.RetryAfter = time.Second
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func randomMember() (string, error) {
|
||||
var buf [12]byte
|
||||
if _, err := rand.Read(buf[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf[:]), nil
|
||||
}
|
||||
|
||||
// Sliding-window check + record, atomic in Redis.
|
||||
//
|
||||
// KEYS[1] = bucket key
|
||||
// ARGV[1] = now (unix ms)
|
||||
// ARGV[2] = window (ms)
|
||||
// ARGV[3] = limit
|
||||
// ARGV[4] = unique member to insert when allowed
|
||||
// returns { blocked, count, retryAfterMs }
|
||||
const slidingWindowScript = `
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local limit = tonumber(ARGV[3])
|
||||
local member = ARGV[4]
|
||||
local cutoff = now - window
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
|
||||
local count = redis.call('ZCARD', key)
|
||||
|
||||
if count >= limit then
|
||||
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
||||
local retry = window
|
||||
if oldest[2] then
|
||||
retry = (tonumber(oldest[2]) + window) - now
|
||||
if retry < 1 then retry = 1 end
|
||||
end
|
||||
return {1, count, retry}
|
||||
end
|
||||
|
||||
redis.call('ZADD', key, now, member)
|
||||
redis.call('PEXPIRE', key, window)
|
||||
return {0, count + 1, 0}
|
||||
`
|
||||
@@ -0,0 +1,110 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func newTestLimiter(t *testing.T) (*Limiter, *miniredis.Miniredis) {
|
||||
t.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("miniredis: %v", err)
|
||||
}
|
||||
t.Cleanup(mr.Close)
|
||||
cli := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
t.Cleanup(func() { _ = cli.Close() })
|
||||
return New(cli, "test"), mr
|
||||
}
|
||||
|
||||
func TestLimiterAllowsBelowLimit(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
for i := 1; i <= 3; i++ {
|
||||
r, err := l.Allow(ctx, "signup", "1.2.3.4", 3, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("allow #%d: %v", i, err)
|
||||
}
|
||||
if !r.Allowed {
|
||||
t.Fatalf("hit %d should be allowed, got %+v", i, r)
|
||||
}
|
||||
if r.Count != i {
|
||||
t.Fatalf("hit %d count: got %d want %d", i, r.Count, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterBlocksAtLimit(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := l.Allow(ctx, "signup", "ip", 3, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
r, err := l.Allow(ctx, "signup", "ip", 3, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.Allowed {
|
||||
t.Fatalf("4th hit should be blocked: %+v", r)
|
||||
}
|
||||
if r.RetryAfter <= 0 || r.RetryAfter > time.Minute {
|
||||
t.Fatalf("retry-after out of range: %v", r.RetryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterWindowSlides(t *testing.T) {
|
||||
l, mr := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
// Inject a controllable clock so we can advance time in miniredis +
|
||||
// the limiter consistently.
|
||||
base := time.Unix(1_700_000_000, 0)
|
||||
l.now = func() time.Time { return base }
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
// Slide past the window. miniredis honours TTLs we already set so
|
||||
// FastForward is the trustworthy primitive here.
|
||||
l.now = func() time.Time { return base.Add(2 * time.Minute) }
|
||||
mr.FastForward(2 * time.Minute)
|
||||
|
||||
r, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !r.Allowed {
|
||||
t.Fatalf("expected allow after window: %+v", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterIsolatesKeys(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
// Exhaust budget for one key — others should not be affected.
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
blocked, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if blocked.Allowed {
|
||||
t.Fatal("key a should be blocked")
|
||||
}
|
||||
other, err := l.Allow(ctx, "login", "b@x.com", 2, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !other.Allowed {
|
||||
t.Fatalf("unrelated key b should still be allowed: %+v", other)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
)
|
||||
|
||||
// EmailVerificationRepo manages single-use email verification tokens.
|
||||
type EmailVerificationRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewEmailVerificationRepo(db *DB) *EmailVerificationRepo {
|
||||
return &EmailVerificationRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
func (r *EmailVerificationRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO email_verification_tokens (token_hash, user_id, expires_at)
|
||||
VALUES ($1, $2, $3)
|
||||
`, hash, userID, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// Consume atomically marks the token as used and returns the owning user_id.
|
||||
// Returns ErrAuthTokenNotFound / ErrAuthTokenConsumed / ErrAuthTokenExpired.
|
||||
func (r *EmailVerificationRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) {
|
||||
const q = `
|
||||
UPDATE email_verification_tokens
|
||||
SET consumed_at = now()
|
||||
WHERE token_hash = $1
|
||||
AND consumed_at IS NULL
|
||||
AND expires_at > now()
|
||||
RETURNING user_id
|
||||
`
|
||||
var uid uuid.UUID
|
||||
if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool,
|
||||
"SELECT consumed_at, expires_at FROM email_verification_tokens WHERE token_hash=$1",
|
||||
hash)
|
||||
}
|
||||
return uuid.Nil, err
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// PasswordResetRepo manages single-use password-reset tokens.
|
||||
type PasswordResetRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPasswordResetRepo(db *DB) *PasswordResetRepo {
|
||||
return &PasswordResetRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
func (r *PasswordResetRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO password_reset_tokens (token_hash, user_id, expires_at)
|
||||
VALUES ($1, $2, $3)
|
||||
`, hash, userID, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PasswordResetRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) {
|
||||
const q = `
|
||||
UPDATE password_reset_tokens
|
||||
SET consumed_at = now()
|
||||
WHERE token_hash = $1
|
||||
AND consumed_at IS NULL
|
||||
AND expires_at > now()
|
||||
RETURNING user_id
|
||||
`
|
||||
var uid uuid.UUID
|
||||
if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool,
|
||||
"SELECT consumed_at, expires_at FROM password_reset_tokens WHERE token_hash=$1",
|
||||
hash)
|
||||
}
|
||||
return uuid.Nil, err
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// RefreshTokenRepo manages refresh-token rows. Refresh tokens are rotated:
|
||||
// every refresh issues a new token and revokes the old one, recording the
|
||||
// chain in `replaced_by` so we can detect replay (a revoked token being
|
||||
// presented again triggers a family-wide revocation).
|
||||
type RefreshTokenRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepo(db *DB) *RefreshTokenRepo {
|
||||
return &RefreshTokenRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
type RefreshToken struct {
|
||||
Hash string
|
||||
UserID uuid.UUID
|
||||
ExpiresAt time.Time
|
||||
RevokedAt *time.Time
|
||||
ReplacedBy *string
|
||||
UserAgent string
|
||||
IPAddress *netip.Addr
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type CreateRefreshTokenParams struct {
|
||||
Hash string
|
||||
UserID uuid.UUID
|
||||
ExpiresAt time.Time
|
||||
UserAgent string
|
||||
IPAddress string
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) Create(ctx context.Context, p CreateRefreshTokenParams) error {
|
||||
ip := parseIP(p.IPAddress)
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, NULLIF($4, ''), $5)
|
||||
`, p.Hash, p.UserID, p.ExpiresAt, p.UserAgent, ip)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) Get(ctx context.Context, hash string) (*RefreshToken, error) {
|
||||
const q = `
|
||||
SELECT token_hash, user_id, expires_at, revoked_at, replaced_by,
|
||||
COALESCE(user_agent, ''), host(ip_address), created_at
|
||||
FROM refresh_tokens WHERE token_hash = $1
|
||||
`
|
||||
var rt RefreshToken
|
||||
var ipText *string
|
||||
if err := r.pool.QueryRow(ctx, q, hash).Scan(
|
||||
&rt.Hash, &rt.UserID, &rt.ExpiresAt, &rt.RevokedAt, &rt.ReplacedBy,
|
||||
&rt.UserAgent, &ipText, &rt.CreatedAt,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ipText != nil && *ipText != "" {
|
||||
if addr, err := netip.ParseAddr(*ipText); err == nil {
|
||||
rt.IPAddress = &addr
|
||||
}
|
||||
}
|
||||
return &rt, nil
|
||||
}
|
||||
|
||||
// Rotate atomically (in a transaction) marks the old token revoked and
|
||||
// inserts the new one with replaced_by set. Returns ErrAuthTokenNotFound or
|
||||
// ErrRefreshTokenRevoked if the old token is missing or already revoked.
|
||||
func (r *RefreshTokenRepo) Rotate(ctx context.Context, oldHash string, next CreateRefreshTokenParams) error {
|
||||
tx, err := r.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
var revokedAt *time.Time
|
||||
var userID uuid.UUID
|
||||
var expiresAt time.Time
|
||||
err = tx.QueryRow(ctx, `
|
||||
SELECT user_id, expires_at, revoked_at
|
||||
FROM refresh_tokens WHERE token_hash = $1 FOR UPDATE
|
||||
`, oldHash).Scan(&userID, &expiresAt, &revokedAt)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if revokedAt != nil {
|
||||
// Replay of a revoked refresh token — revoke the entire family.
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
`, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return domain.ErrRefreshTokenRevoked
|
||||
}
|
||||
if time.Now().After(expiresAt) {
|
||||
return domain.ErrAuthTokenExpired
|
||||
}
|
||||
if next.UserID != userID {
|
||||
return errors.New("refresh token user mismatch")
|
||||
}
|
||||
|
||||
ip := parseIP(next.IPAddress)
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, NULLIF($4, ''), $5)
|
||||
`, next.Hash, next.UserID, next.ExpiresAt, next.UserAgent, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now(), replaced_by = $2
|
||||
WHERE token_hash = $1
|
||||
`, oldHash, next.Hash); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) Revoke(ctx context.Context, hash string) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now()
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL
|
||||
`, hash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) RevokeAllForUser(ctx context.Context, userID uuid.UUID) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func parseIP(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
addr, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
func classifyAuthTokenLookup(ctx context.Context, pool *pgxpool.Pool, q, hash string) error {
|
||||
var consumedAt *time.Time
|
||||
var expiresAt time.Time
|
||||
if err := pool.QueryRow(ctx, q, hash).Scan(&consumedAt, &expiresAt); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if consumedAt != nil {
|
||||
return domain.ErrAuthTokenConsumed
|
||||
}
|
||||
if time.Now().After(expiresAt) {
|
||||
return domain.ErrAuthTokenExpired
|
||||
}
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
+30
-12
@@ -78,6 +78,24 @@ func (r *EventRepo) Get(ctx context.Context, id uuid.UUID) (*domain.Event, error
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// GetForHost is the authz-aware variant of Get. It returns ErrEventNotFound
|
||||
// when the event either doesn't exist or doesn't belong to the host — by
|
||||
// merging both cases we avoid leaking existence on cross-tenant lookups.
|
||||
func (r *EventRepo) GetForHost(ctx context.Context, id, hostID uuid.UUID) (*domain.Event, error) {
|
||||
const q = `
|
||||
SELECT id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at
|
||||
FROM events WHERE id = $1 AND host_id = $2
|
||||
`
|
||||
ev, err := scanEvent(r.pool.QueryRow(ctx, q, id, hostID))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrEventNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *EventRepo) List(ctx context.Context, hostID uuid.UUID, limit, offset int) ([]*domain.Event, error) {
|
||||
if limit <= 0 || limit > 200 {
|
||||
limit = 50
|
||||
@@ -132,18 +150,18 @@ type UpdateEventParams struct {
|
||||
Status *domain.EventStatus
|
||||
}
|
||||
|
||||
func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParams) (*domain.Event, error) {
|
||||
func (r *EventRepo) Update(ctx context.Context, id, hostID uuid.UUID, p UpdateEventParams) (*domain.Event, error) {
|
||||
const q = `
|
||||
UPDATE events SET
|
||||
name = COALESCE($2, name),
|
||||
slug = COALESCE($3, slug),
|
||||
event_date = COALESCE($4, event_date),
|
||||
venue = COALESCE($5, venue),
|
||||
max_capacity = COALESCE($6, max_capacity),
|
||||
settings = COALESCE($7, settings),
|
||||
status = COALESCE($8, status),
|
||||
name = COALESCE($3, name),
|
||||
slug = COALESCE($4, slug),
|
||||
event_date = COALESCE($5, event_date),
|
||||
venue = COALESCE($6, venue),
|
||||
max_capacity = COALESCE($7, max_capacity),
|
||||
settings = COALESCE($8, settings),
|
||||
status = COALESCE($9, status),
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND host_id = $2
|
||||
RETURNING id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at
|
||||
`
|
||||
|
||||
@@ -156,7 +174,7 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam
|
||||
settingsJSON = b
|
||||
}
|
||||
|
||||
row := r.pool.QueryRow(ctx, q, id,
|
||||
row := r.pool.QueryRow(ctx, q, id, hostID,
|
||||
p.Name, p.Slug, p.EventDate, p.Venue, p.MaxCapacity, settingsJSON, p.Status,
|
||||
)
|
||||
ev, err := scanEvent(row)
|
||||
@@ -173,8 +191,8 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *EventRepo) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1`, id)
|
||||
func (r *EventRepo) Delete(ctx context.Context, id, hostID uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1 AND host_id = $2`, id, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package storage
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -88,6 +90,87 @@ func (r *GuestRepo) ListByEvent(ctx context.Context, eventID uuid.UUID, limit, o
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// BulkImportRow is one normalised guest in a CSV import batch.
|
||||
type BulkImportRow struct {
|
||||
Name string
|
||||
Email string // empty if absent
|
||||
Phone string // empty if absent
|
||||
PlusOnes int
|
||||
}
|
||||
|
||||
// BulkImportResult reports the outcome of a single import call. The
|
||||
// SkippedEmails slice records the addresses we silently dropped because a
|
||||
// guest already exists on the event — useful for the success summary.
|
||||
type BulkImportResult struct {
|
||||
Added int
|
||||
Skipped int
|
||||
SkippedEmails []string
|
||||
}
|
||||
|
||||
// BulkImportGuests inserts up to len(rows) guest rows into the event in a
|
||||
// single transaction. Rows whose email matches an existing guest on the
|
||||
// event are skipped (idempotent re-imports). Within the batch, duplicate
|
||||
// emails after the first are also skipped. Either the entire batch
|
||||
// commits or none of it does.
|
||||
//
|
||||
// Empty email is treated as "no email" and not deduped — those rows
|
||||
// always insert (the host might be entering phone-only guests).
|
||||
func (r *GuestRepo) BulkImportGuests(ctx context.Context, eventID uuid.UUID, rows []BulkImportRow) (BulkImportResult, error) {
|
||||
res := BulkImportResult{}
|
||||
if len(rows) == 0 {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
tx, err := r.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Fetch existing emails on the event into a set for O(1) dedup.
|
||||
existing := map[string]struct{}{}
|
||||
exRows, err := tx.Query(ctx,
|
||||
`SELECT lower(email) FROM guests WHERE event_id = $1 AND email IS NOT NULL AND email <> ''`,
|
||||
eventID)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("load existing emails: %w", err)
|
||||
}
|
||||
for exRows.Next() {
|
||||
var e string
|
||||
if err := exRows.Scan(&e); err != nil {
|
||||
exRows.Close()
|
||||
return res, err
|
||||
}
|
||||
existing[e] = struct{}{}
|
||||
}
|
||||
exRows.Close()
|
||||
|
||||
const ins = `
|
||||
INSERT INTO guests (event_id, name, email, phone, plus_ones)
|
||||
VALUES ($1, $2, NULLIF($3, ''), NULLIF($4, ''), $5)
|
||||
`
|
||||
for _, row := range rows {
|
||||
email := strings.ToLower(strings.TrimSpace(row.Email))
|
||||
if email != "" {
|
||||
if _, dup := existing[email]; dup {
|
||||
res.Skipped++
|
||||
res.SkippedEmails = append(res.SkippedEmails, email)
|
||||
continue
|
||||
}
|
||||
existing[email] = struct{}{}
|
||||
}
|
||||
if _, err := tx.Exec(ctx, ins, eventID, row.Name, email, row.Phone, row.PlusOnes); err != nil {
|
||||
return BulkImportResult{}, fmt.Errorf("insert guest %q: %w", row.Name, err)
|
||||
}
|
||||
res.Added++
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return BulkImportResult{}, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func scanGuest(s rowScanner) (*domain.Guest, error) {
|
||||
var g domain.Guest
|
||||
err := s.Scan(
|
||||
@@ -100,6 +183,135 @@ func scanGuest(s rowScanner) (*domain.Guest, error) {
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
// UpdateGuestParams patches a guest. Nil fields are left untouched.
|
||||
// An empty string for Email / Phone clears the column to NULL, matching
|
||||
// the frontend "clear this field" UX.
|
||||
type UpdateGuestParams struct {
|
||||
Name *string
|
||||
Email *string
|
||||
Phone *string
|
||||
PlusOnes *int
|
||||
}
|
||||
|
||||
// Update applies the patch to (guestID, eventID). Event scoping in the
|
||||
// WHERE clause prevents a host from patching guests on another host's
|
||||
// event even if they guess the guest_id. Returns ErrGuestNotFound when
|
||||
// the guest doesn't exist on the event.
|
||||
func (r *GuestRepo) Update(ctx context.Context, eventID, guestID uuid.UUID, p UpdateGuestParams) (*domain.Guest, error) {
|
||||
sets := []string{}
|
||||
args := []any{guestID, eventID}
|
||||
add := func(col string, val any) {
|
||||
args = append(args, val)
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", col, len(args)))
|
||||
}
|
||||
if p.Name != nil {
|
||||
add("name", strings.TrimSpace(*p.Name))
|
||||
}
|
||||
if p.Email != nil {
|
||||
if strings.TrimSpace(*p.Email) == "" {
|
||||
sets = append(sets, "email = NULL")
|
||||
} else {
|
||||
add("email", strings.ToLower(strings.TrimSpace(*p.Email)))
|
||||
}
|
||||
}
|
||||
if p.Phone != nil {
|
||||
if strings.TrimSpace(*p.Phone) == "" {
|
||||
sets = append(sets, "phone = NULL")
|
||||
} else {
|
||||
add("phone", strings.TrimSpace(*p.Phone))
|
||||
}
|
||||
}
|
||||
if p.PlusOnes != nil {
|
||||
add("plus_ones", *p.PlusOnes)
|
||||
}
|
||||
if len(sets) == 0 {
|
||||
return r.Get(ctx, guestID)
|
||||
}
|
||||
q := fmt.Sprintf(`
|
||||
UPDATE guests SET %s
|
||||
WHERE id = $1 AND event_id = $2
|
||||
RETURNING id, event_id, name, email, phone, plus_ones, dietary_notes, table_number, created_at
|
||||
`, strings.Join(sets, ", "))
|
||||
g, err := scanGuest(r.pool.QueryRow(ctx, q, args...))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrGuestNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// Delete removes a guest from an event. Cascade-deletes any tokens,
|
||||
// rsvps, access_logs, and notifications tied to the guest. Event scoping
|
||||
// in the WHERE clause stops cross-tenant deletes.
|
||||
func (r *GuestRepo) Delete(ctx context.Context, eventID, guestID uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx,
|
||||
`DELETE FROM guests WHERE id = $1 AND event_id = $2`,
|
||||
guestID, eventID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrGuestNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GuestForInvitation is the minimum data the bulk-invite path needs about
|
||||
// each candidate. Pulled in a single query joined against tokens so the
|
||||
// caller knows up-front who's already received an invitation.
|
||||
type GuestForInvitation struct {
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Email string // empty when the guest has no email on file
|
||||
HasToken bool
|
||||
}
|
||||
|
||||
// ListGuestsForInvitation returns every guest on `eventID`, joined with
|
||||
// the tokens table so the caller can skip guests that already have one.
|
||||
// When `onlyIDs` is non-nil/non-empty, the result is filtered to that
|
||||
// subset (used for explicit selection in the bulk-send UI).
|
||||
func (r *GuestRepo) ListGuestsForInvitation(ctx context.Context, eventID uuid.UUID, onlyIDs []uuid.UUID) ([]GuestForInvitation, error) {
|
||||
var (
|
||||
rows pgx.Rows
|
||||
err error
|
||||
)
|
||||
if len(onlyIDs) == 0 {
|
||||
rows, err = r.pool.Query(ctx, `
|
||||
SELECT g.id, g.name, COALESCE(g.email,''),
|
||||
(t.id IS NOT NULL) AS has_token
|
||||
FROM guests g
|
||||
LEFT JOIN tokens t ON t.guest_id = g.id
|
||||
WHERE g.event_id = $1
|
||||
ORDER BY g.created_at ASC
|
||||
`, eventID)
|
||||
} else {
|
||||
rows, err = r.pool.Query(ctx, `
|
||||
SELECT g.id, g.name, COALESCE(g.email,''),
|
||||
(t.id IS NOT NULL) AS has_token
|
||||
FROM guests g
|
||||
LEFT JOIN tokens t ON t.guest_id = g.id
|
||||
WHERE g.event_id = $1 AND g.id = ANY($2)
|
||||
ORDER BY g.created_at ASC
|
||||
`, eventID, onlyIDs)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []GuestForInvitation
|
||||
for rows.Next() {
|
||||
var g GuestForInvitation
|
||||
if err := rows.Scan(&g.ID, &g.Name, &g.Email, &g.HasToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, g)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GuestWithRSVP is the dashboard view: a guest plus the RSVP submitted
|
||||
// against their token, if any. RSVP fields are nil when no response yet.
|
||||
type GuestWithRSVP struct {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user