feat: ship Tier 1 — auth, authz, rate limits, real notifications, CSV import, billing, backups/DR, privacy
Closes every block in docs/TIER1_PLAN.md from the Claude-scope side. The
homelab / cloud setup steps (SES verification, restore drill, lawyer-
drafted ToS) remain operator-owned but are unblocked.
Block A — Authentication
- Migration 0003: password_hash, email_verified, email_verification_tokens,
password_reset_tokens, refresh_tokens (with replaced_by family chain).
- Bcrypt hasher, HS256 JWT signer, single-use refresh tokens with rotation
+ replay-detection (revokes the family on reuse).
- /auth/signup, /login, /refresh, /logout, /verify-email,
/forgot-password, /reset-password — enumeration-safe.
- requireAuth middleware + GET /me.
- Frontend useAuth/useApi with auto-refresh-on-401, login/signup/verify/
forgot/reset pages, route-guard middleware.
Block B — Authorisation
- EventRepo.GetForHost; Update/Delete scoped by host_id.
- All host routes behind requireAuth + ownership; cross-tenant returns
404 (no enumeration). ?host_id removed.
- WS auth via short-lived single-use tickets (POST /auth/ws-ticket).
- Tests: TestCrossTenantIsolation — 9 probes.
Block C — Rate limiting
- Redis sliding-window via Lua (atomic ZADD+ZCARD+PEXPIRE).
- Per-route limits matching the plan (signup IP, login IP+email, RSVP/
access by token, events/guests/tokens by user_id).
- 429 with Retry-After header and JSON body.
- Auth lockout: 5 failed logins → account locked, only password reset
clears it.
- Frontend: useErrMessage normalises 429 + locked messaging.
Block D — Real notifications
- Migration 0004: provider_message_id, bounce_type, complained columns
+ unsubscribes (CITEXT) suppression table.
- Branded HTML + plaintext templates for verification, reset, invitation,
confirmation, reminder. Per-page templates avoid html/template's
contextual-escape collisions.
- Senders: SESv2, Twilio (SMS), SMTP (Mailpit-friendly), Resend HTTP.
- PickEmailSender priority Resend > SMTP > SES > Log — system boots
cleanly in dev with Mailpit; production flips one env var.
- Webhook endpoints (Twilio status + SES SNS) — bounces add to suppression;
signature verification stubbed pending creds.
- Auto-send: POST /tokens publishes invitation.send; notifier renders +
delivers via the configured backend; suppression list honoured.
- Bulk + per-row invitation flow: POST /events/{id}/guests/invitations/bulk
returns per-guest tokens so phone-only guests can be SMS'd manually.
- Unsubscribe: signed HMAC token (no TTL) + /unsubscribe/[token] page.
- WhatsApp Option A+: wa.me click-to-chat wizard with per-guest progress
tracking, isLikelyE164 validation, edit-from-wizard.
- Token rotate (POST /tokens/rotate) invalidates the old URL — used by
the regenerate-link flow.
- Mailpit added to docker-compose for dev inbox.
Block E — CSV import
- Streaming parser: tolerant header detection, UTF-8 BOM + UTF-16 LE/BE
decoding, row-level validation, 5,000-row cap.
- Strict E.164 phone validation with helpful error message.
- POST /preview + /import + GET /template; preview UI on event page;
atomic per-batch with dedup on existing emails.
Phone capture across UI
- PhoneInput component: country picker (~50 ISO codes) + national input +
live E.164 preview + inline length validation.
- Used in Add Guest and Edit Guest modals. Smart paste-handling extracts
country code from full E.164 strings.
Block F — Billing (Stripe)
- Migration 0005: subscriptions table (user_id → tier/status/period_end +
Stripe customer/sub ids). Partial unique index keeps one granting sub
per user.
- internal/billing: Tier + Limits model (Free 1/50, Pro 10/1000, Business
∞/5000), Stripe SDK wrapper with IgnoreAPIVersionMismatch for newer
account API versions.
- /billing/checkout-session, /billing/portal, /billing/status,
/webhooks/stripe (signature-verified, lifecycle events).
- Tier enforcement: 402 on POST /events, /guests, /import with
{error, reason, tier, used, limit, upgrade_url} body.
- Frontend: useBilling composable, /dashboard/billing page (current plan,
usage bars, tier cards), global UpgradeModal triggered by useApi's
402 interceptor.
- Customer portal kept for self-service cancel/payment-method changes.
Block G — Backups & DR (application side)
- Every migration has a tested .down.sql.
- TestMigrationRoundtrip applies all ups → all downs → all ups against a
fresh container; catches asymmetric down migrations.
- cmd/restore-verify: 28-check post-restore invariant tool (schema
presence, no orphans across 10 FK relationships, email uniqueness,
single-active subscription, row-count snapshot).
- docs/RUNBOOK_RESTORE.md: 9-step restore procedure with RTO/RPO
targets, drill instructions, rollback path.
Block H — Privacy compliance (application side)
- Migration 0006: deleted_at + terms_accepted_at + privacy_policy_accepted_at
on users. Partial index on email for live-only uniqueness.
- GET /me/data-export — synchronous JSON dump (user, events, guests,
tokens, rsvps, access_logs, notifications).
- DELETE /me — soft-delete with PII scrub + refresh-token revocation;
re-signup with same email works.
- POST /me/accept-terms — idempotent consent recording.
- Frontend /privacy + /terms placeholder pages with substantive (pending
legal review) copy; footer links; signup terms checkbox; TermsGateModal
for accounts created before the rollout; export + delete buttons on
/dashboard/billing.
Tests
- All migrations verified up/down/up.
- Integration suite: TestE2EHappyPath, TestAuthFlow, TestCrossTenantIsolation,
TestRateLimitSignup, TestLoginLockout, TestUnsubscribeFlow,
TestSESBounceWebhook, TestTwilioStatusWebhook, TestCsvImportFlow,
TestCsvImportAtomicRollback, TestBulkIssueInvitations, TestBulkIssueExplicitSubset,
TestTokenIssuePublishesInvitation, TestTokenIssueWithoutGuestEmailSkipsInvitation,
TestGuestUpdate, TestGuestDelete, TestTokenRotate, TestSMTPSenderAgainstMailpit,
TestFreeTierEventLimit, TestFreeTierGuestLimit, TestBusinessTierBypassesLimits,
TestDataExport, TestDeleteMe, TestAcceptTerms, TestMigrationRoundtrip.
Full suite runs in ~120s against real Postgres + NATS + Redis + Mailpit.
- Unit suite green across internal/auth, internal/csvimport,
internal/notification, internal/ratelimit, internal/domain.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
@@ -42,16 +40,15 @@ type activityItem struct {
|
||||
// for an event, sorted newest first. Frontends use this on dashboard mount
|
||||
// to backfill the live monitor with history.
|
||||
func (h *activityHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := h.events.Get(r.Context(), eventID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,557 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/ratelimit"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
const refreshCookieName = "gg_refresh"
|
||||
|
||||
type authHandler struct {
|
||||
logger *slog.Logger
|
||||
users *storage.UserRepo
|
||||
verifications *storage.EmailVerificationRepo
|
||||
resets *storage.PasswordResetRepo
|
||||
refreshes *storage.RefreshTokenRepo
|
||||
hasher *auth.PasswordHasher
|
||||
signer *auth.JWTSigner
|
||||
emails auth.EmailSender
|
||||
lockout *auth.LockoutTracker
|
||||
limiter *ratelimit.Limiter
|
||||
|
||||
publicBaseURL string
|
||||
emailVerificationTTL time.Duration
|
||||
passwordResetTTL time.Duration
|
||||
refreshTTL time.Duration
|
||||
cookieDomain string
|
||||
cookieSecure bool
|
||||
}
|
||||
|
||||
type authHandlerDeps struct {
|
||||
Logger *slog.Logger
|
||||
Users *storage.UserRepo
|
||||
Verifications *storage.EmailVerificationRepo
|
||||
Resets *storage.PasswordResetRepo
|
||||
Refreshes *storage.RefreshTokenRepo
|
||||
Hasher *auth.PasswordHasher
|
||||
Signer *auth.JWTSigner
|
||||
Emails auth.EmailSender
|
||||
Lockout *auth.LockoutTracker
|
||||
Limiter *ratelimit.Limiter
|
||||
|
||||
PublicBaseURL string
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
RefreshTTL time.Duration
|
||||
CookieDomain string
|
||||
CookieSecure bool
|
||||
}
|
||||
|
||||
func newAuthHandler(d authHandlerDeps) *authHandler {
|
||||
return &authHandler{
|
||||
logger: d.Logger,
|
||||
users: d.Users,
|
||||
verifications: d.Verifications,
|
||||
resets: d.Resets,
|
||||
refreshes: d.Refreshes,
|
||||
hasher: d.Hasher,
|
||||
signer: d.Signer,
|
||||
emails: d.Emails,
|
||||
lockout: d.Lockout,
|
||||
limiter: d.Limiter,
|
||||
publicBaseURL: strings.TrimRight(d.PublicBaseURL, "/"),
|
||||
emailVerificationTTL: d.EmailVerificationTTL,
|
||||
passwordResetTTL: d.PasswordResetTTL,
|
||||
refreshTTL: d.RefreshTTL,
|
||||
cookieDomain: d.CookieDomain,
|
||||
cookieSecure: d.CookieSecure,
|
||||
}
|
||||
}
|
||||
|
||||
// --- request/response types ---
|
||||
|
||||
type signupRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
AcceptTerms bool `json:"accept_terms"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type verifyEmailRequest struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type forgotPasswordRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type resetPasswordRequest struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type authSuccess struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
User *domain.User `json:"user"`
|
||||
}
|
||||
|
||||
// --- handlers ---
|
||||
|
||||
// POST /auth/signup
|
||||
func (h *authHandler) signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req signupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "email is invalid")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
hash, err := h.hasher.Hash(req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("hash password", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.Create(r.Context(), storage.CreateUserParams{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
PasswordHash: hash,
|
||||
AcceptTerms: req.AcceptTerms,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEmailTaken) {
|
||||
// Don't leak which addresses are registered. Still return 201 and
|
||||
// trigger a "if-you-already-have-an-account" email asynchronously
|
||||
// (skipped for the stub). On real auth this should send a "you
|
||||
// tried to sign up again, here's a reset link" email.
|
||||
h.logger.Info("signup attempted with existing email", "email", req.Email)
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("create user", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.sendVerificationEmail(r.Context(), u); err != nil {
|
||||
h.logger.Error("send verification email", "err", err, "user_id", u.ID)
|
||||
// Don't fail the signup — user can request a resend.
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
|
||||
}
|
||||
|
||||
// POST /auth/login
|
||||
func (h *authHandler) login(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Email == "" || req.Password == "" {
|
||||
writeError(w, http.StatusBadRequest, "email and password required")
|
||||
return
|
||||
}
|
||||
|
||||
// Per-(IP + email) sliding-window — 10 per 5 minutes per the plan.
|
||||
if !h.checkRate(w, r, "login", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
|
||||
10, 5*time.Minute) {
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.GetByEmail(r.Context(), req.Email)
|
||||
if err != nil || u.PasswordHash == "" {
|
||||
_, _ = h.lockout.RecordFailure(r.Context(), req.Email, nil)
|
||||
writeError(w, http.StatusUnauthorized, "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
// If the account is already locked, reject before doing a bcrypt compare.
|
||||
locked, _ := h.lockout.IsLocked(r.Context(), u.ID)
|
||||
if locked {
|
||||
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.hasher.Verify(u.PasswordHash, req.Password); err != nil {
|
||||
locked, _ := h.lockout.RecordFailure(r.Context(), req.Email, &u.ID)
|
||||
if locked {
|
||||
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid email or password")
|
||||
return
|
||||
}
|
||||
if !u.EmailVerified {
|
||||
writeError(w, http.StatusForbidden, "email not verified")
|
||||
return
|
||||
}
|
||||
|
||||
h.lockout.ClearOnSuccess(r.Context(), req.Email)
|
||||
if err := h.issueSession(w, r, u); err != nil {
|
||||
h.logger.Error("issue session", "err", err, "user_id", u.ID)
|
||||
writeError(w, http.StatusInternalServerError, "failed to start session")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// checkRate consults the limiter (when one is configured) and writes a 429
|
||||
// response if the budget is exhausted. Returns false if the caller should
|
||||
// stop handling the request.
|
||||
func (h *authHandler) checkRate(w http.ResponseWriter, r *http.Request, name, key string, limit int, window time.Duration) bool {
|
||||
if h.limiter == nil || key == "" {
|
||||
return true
|
||||
}
|
||||
res, err := h.limiter.Allow(r.Context(), name, key, limit, window)
|
||||
if err != nil {
|
||||
h.logger.Warn("ratelimit error (failing open)", "rule", name, "err", err)
|
||||
return true
|
||||
}
|
||||
if !res.Allowed {
|
||||
retry := int(res.RetryAfter.Round(time.Second).Seconds())
|
||||
if retry < 1 {
|
||||
retry = 1
|
||||
}
|
||||
w.Header().Set("Retry-After", strconv.Itoa(retry))
|
||||
writeJSON(w, http.StatusTooManyRequests, map[string]any{
|
||||
"error": "rate limit exceeded",
|
||||
"retry_after": retry,
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// POST /auth/refresh
|
||||
func (h *authHandler) refresh(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(refreshCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing refresh token")
|
||||
return
|
||||
}
|
||||
oldHash := auth.HashOpaque(cookie.Value)
|
||||
|
||||
existing, err := h.refreshes.Get(r.Context(), oldHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrAuthTokenNotFound) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "invalid refresh token")
|
||||
return
|
||||
}
|
||||
h.logger.Error("lookup refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
if existing.RevokedAt != nil {
|
||||
// Replay of a revoked token. Revoke the family.
|
||||
_ = h.refreshes.RevokeAllForUser(r.Context(), existing.UserID)
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token reused")
|
||||
return
|
||||
}
|
||||
if time.Now().After(existing.ExpiresAt) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token expired")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.GetByID(r.Context(), existing.UserID)
|
||||
if err != nil {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
newRaw, newHash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
h.logger.Error("mint refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
exp := time.Now().Add(h.refreshTTL)
|
||||
if err := h.refreshes.Rotate(r.Context(), oldHash, storage.CreateRefreshTokenParams{
|
||||
Hash: newHash,
|
||||
UserID: u.ID,
|
||||
ExpiresAt: exp,
|
||||
UserAgent: r.UserAgent(),
|
||||
IPAddress: clientIP(r),
|
||||
}); err != nil {
|
||||
if errors.Is(err, domain.ErrRefreshTokenRevoked) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token reused")
|
||||
return
|
||||
}
|
||||
h.logger.Error("rotate refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
|
||||
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
|
||||
if err != nil {
|
||||
h.logger.Error("sign access", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
h.setRefreshCookie(w, newRaw, exp)
|
||||
writeJSON(w, http.StatusOK, authSuccess{
|
||||
AccessToken: access,
|
||||
ExpiresAt: accessExp,
|
||||
User: u,
|
||||
})
|
||||
}
|
||||
|
||||
// POST /auth/logout
|
||||
func (h *authHandler) logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(refreshCookieName)
|
||||
if err == nil && cookie.Value != "" {
|
||||
_ = h.refreshes.Revoke(r.Context(), auth.HashOpaque(cookie.Value))
|
||||
}
|
||||
h.clearRefreshCookie(w)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /auth/verify-email
|
||||
func (h *authHandler) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
||||
var req verifyEmailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token required")
|
||||
return
|
||||
}
|
||||
uid, err := h.verifications.Consume(r.Context(), auth.HashOpaque(req.Token))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrAuthTokenNotFound):
|
||||
writeError(w, http.StatusBadRequest, "invalid token")
|
||||
case errors.Is(err, domain.ErrAuthTokenConsumed):
|
||||
writeError(w, http.StatusBadRequest, "token already used")
|
||||
case errors.Is(err, domain.ErrAuthTokenExpired):
|
||||
writeError(w, http.StatusBadRequest, "token expired")
|
||||
default:
|
||||
h.logger.Error("consume verification", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "verification failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.users.MarkEmailVerified(r.Context(), uid); err != nil {
|
||||
h.logger.Error("mark verified", "err", err, "user_id", uid)
|
||||
writeError(w, http.StatusInternalServerError, "verification failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "verified"})
|
||||
}
|
||||
|
||||
// POST /auth/forgot-password
|
||||
func (h *authHandler) forgotPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req forgotPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if !h.checkRate(w, r, "forgot_password", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
|
||||
3, time.Hour) {
|
||||
return
|
||||
}
|
||||
// Always respond 202 to avoid leaking whether the email exists.
|
||||
defer func() { writeJSON(w, http.StatusAccepted, map[string]string{"status": "if_known_email_sent"}) }()
|
||||
|
||||
u, err := h.users.GetByEmail(r.Context(), req.Email)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
h.logger.Error("mint reset", "err", err)
|
||||
return
|
||||
}
|
||||
exp := time.Now().Add(h.passwordResetTTL)
|
||||
if err := h.resets.Create(r.Context(), u.ID, hash, exp); err != nil {
|
||||
h.logger.Error("persist reset", "err", err)
|
||||
return
|
||||
}
|
||||
link := h.publicBaseURL + "/reset-password/" + url.PathEscape(raw)
|
||||
if err := h.emails.SendPasswordReset(r.Context(), u.Email, u.Name, link); err != nil {
|
||||
h.logger.Error("send reset email", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// POST /auth/reset-password
|
||||
func (h *authHandler) resetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req resetPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token and new_password required")
|
||||
return
|
||||
}
|
||||
newHash, err := h.hasher.Hash(req.NewPassword)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("hash password", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
return
|
||||
}
|
||||
uid, err := h.resets.Consume(r.Context(), auth.HashOpaque(req.Token))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrAuthTokenNotFound):
|
||||
writeError(w, http.StatusBadRequest, "invalid token")
|
||||
case errors.Is(err, domain.ErrAuthTokenConsumed):
|
||||
writeError(w, http.StatusBadRequest, "token already used")
|
||||
case errors.Is(err, domain.ErrAuthTokenExpired):
|
||||
writeError(w, http.StatusBadRequest, "token expired")
|
||||
default:
|
||||
h.logger.Error("consume reset", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.users.UpdatePasswordHash(r.Context(), uid, newHash); err != nil {
|
||||
h.logger.Error("update password", "err", err, "user_id", uid)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
return
|
||||
}
|
||||
// Invalidate all existing sessions.
|
||||
_ = h.refreshes.RevokeAllForUser(r.Context(), uid)
|
||||
// Resetting the password is the canonical "unlock" path for the
|
||||
// account lockout that triggers after repeated bad-credential attempts.
|
||||
if u, err := h.users.GetByID(r.Context(), uid); err == nil {
|
||||
_ = h.lockout.ClearForUser(r.Context(), uid, u.Email)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "password_reset"})
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func (h *authHandler) sendVerificationEmail(ctx context.Context, u *domain.User) error {
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifications.Create(ctx, u.ID, hash, time.Now().Add(h.emailVerificationTTL)); err != nil {
|
||||
return err
|
||||
}
|
||||
link := h.publicBaseURL + "/verify-email?token=" + url.QueryEscape(raw)
|
||||
return h.emails.SendVerification(ctx, u.Email, u.Name, link)
|
||||
}
|
||||
|
||||
func (h *authHandler) issueSession(w http.ResponseWriter, r *http.Request, u *domain.User) error {
|
||||
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
refreshExp := time.Now().Add(h.refreshTTL)
|
||||
if err := h.refreshes.Create(r.Context(), storage.CreateRefreshTokenParams{
|
||||
Hash: hash,
|
||||
UserID: u.ID,
|
||||
ExpiresAt: refreshExp,
|
||||
UserAgent: r.UserAgent(),
|
||||
IPAddress: clientIP(r),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
h.setRefreshCookie(w, raw, refreshExp)
|
||||
writeJSON(w, http.StatusOK, authSuccess{
|
||||
AccessToken: access,
|
||||
ExpiresAt: accessExp,
|
||||
User: u,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *authHandler) setRefreshCookie(w http.ResponseWriter, value string, expires time.Time) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: refreshCookieName,
|
||||
Value: value,
|
||||
Path: "/auth",
|
||||
Domain: h.cookieDomain,
|
||||
Expires: expires,
|
||||
MaxAge: int(time.Until(expires).Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: h.cookieSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *authHandler) clearRefreshCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: refreshCookieName,
|
||||
Value: "",
|
||||
Path: "/auth",
|
||||
Domain: h.cookieDomain,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: h.cookieSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// --- requireAuth middleware ---
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const userIDCtxKey ctxKey = iota
|
||||
|
||||
func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(userIDCtxKey).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func requireAuth(signer *auth.JWTSigner) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(h, "Bearer ") {
|
||||
writeError(w, http.StatusUnauthorized, "missing bearer token")
|
||||
return
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(h, "Bearer "))
|
||||
claims, err := signer.Parse(raw)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrExpiredJWT) {
|
||||
writeError(w, http.StatusUnauthorized, "token expired")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid token")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userIDCtxKey, claims.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// hostFromContext returns the authed user's id, or writes 401 and returns
|
||||
// false. Used by host-facing handlers as the first line in the function.
|
||||
func hostFromContext(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthenticated")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
return uid, true
|
||||
}
|
||||
|
||||
// requireEventOwner fetches the event and confirms the authed user owns it.
|
||||
// On mismatch (or missing event) it returns 404 — never 403 — so a cross-
|
||||
// tenant probe cannot tell the difference between "event doesn't exist" and
|
||||
// "exists but belongs to someone else".
|
||||
func requireEventOwner(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
events *storage.EventRepo,
|
||||
eventID, hostID uuid.UUID,
|
||||
) (*domain.Event, bool) {
|
||||
ev, err := events.GetForHost(r.Context(), eventID, hostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return nil, false
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
return nil, false
|
||||
}
|
||||
return ev, true
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type billingHandler struct {
|
||||
logger *slog.Logger
|
||||
stripe *billing.Client
|
||||
users *storage.UserRepo
|
||||
subscriptions *storage.SubscriptionRepo
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
type checkoutSessionRequest struct {
|
||||
Tier string `json:"tier"`
|
||||
}
|
||||
|
||||
type checkoutSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// POST /billing/checkout-session — returns the Stripe Checkout URL the
|
||||
// frontend redirects the host to. Mints a Stripe customer on first use
|
||||
// and persists it so repeat calls reuse the same customer.
|
||||
func (h *billingHandler) checkoutSession(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.stripeEnabled(w) {
|
||||
return
|
||||
}
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req checkoutSessionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
tier := billing.Tier(strings.ToLower(req.Tier))
|
||||
if tier != billing.TierPro && tier != billing.TierBusiness {
|
||||
writeError(w, http.StatusBadRequest, "tier must be 'pro' or 'business'")
|
||||
return
|
||||
}
|
||||
price, err := h.stripe.PriceFor(tier)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "this tier is not configured yet — contact support")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.users.GetByID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
|
||||
existingCustomerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load billing record")
|
||||
return
|
||||
}
|
||||
customerID, err := h.stripe.CreateOrGetCustomer(hostID.String(), user.Email, user.Name, existingCustomerID)
|
||||
if err != nil {
|
||||
h.logger.Error("stripe customer", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe customer error")
|
||||
return
|
||||
}
|
||||
if existingCustomerID == "" {
|
||||
// First time — write a placeholder row so the customer id sticks.
|
||||
if _, err := h.subscriptions.Upsert(r.Context(), storage.UpsertParams{
|
||||
UserID: hostID,
|
||||
StripeCustomerID: customerID,
|
||||
}); err != nil {
|
||||
h.logger.Error("upsert sub placeholder", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
base := strings.TrimRight(h.publicBaseURL, "/")
|
||||
url, err := h.stripe.CreateCheckoutSession(billing.CheckoutSessionParams{
|
||||
CustomerID: customerID,
|
||||
PriceID: price,
|
||||
SuccessURL: base + "/dashboard?billing=success",
|
||||
CancelURL: base + "/dashboard?billing=cancelled",
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("stripe checkout session", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe checkout error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, checkoutSessionResponse{URL: url})
|
||||
}
|
||||
|
||||
type portalSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// POST /billing/portal — returns the customer portal URL so the user
|
||||
// can manage their payment method, view invoices, or cancel. 404 when
|
||||
// the user has no Stripe customer yet (they're still on free).
|
||||
func (h *billingHandler) portalSession(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.stripeEnabled(w) {
|
||||
return
|
||||
}
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
customerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load billing record")
|
||||
return
|
||||
}
|
||||
if customerID == "" {
|
||||
writeError(w, http.StatusNotFound, "no billing account yet — subscribe first")
|
||||
return
|
||||
}
|
||||
url, err := h.stripe.CreatePortalSession(customerID, strings.TrimRight(h.publicBaseURL, "/")+"/dashboard")
|
||||
if err != nil {
|
||||
h.logger.Error("stripe portal", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe portal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, portalSessionResponse{URL: url})
|
||||
}
|
||||
|
||||
type subscriptionStatusResponse struct {
|
||||
Tier string `json:"tier"`
|
||||
Status string `json:"status"`
|
||||
CurrentPeriodEnd string `json:"current_period_end,omitempty"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
Limits struct {
|
||||
EventsPerMonth int `json:"events_per_month"`
|
||||
GuestsPerEvent int `json:"guests_per_event"`
|
||||
} `json:"limits"`
|
||||
Usage struct {
|
||||
EventsThisMonth int `json:"events_this_month"`
|
||||
} `json:"usage"`
|
||||
PortalAvailable bool `json:"portal_available"`
|
||||
}
|
||||
|
||||
// GET /billing/status — returns the host's current tier + limits +
|
||||
// usage. The frontend uses this to render the billing page and the
|
||||
// 402-modal copy ("you used X of Y events this month").
|
||||
func (h *billingHandler) status(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tier := billing.TierFree
|
||||
status := "active"
|
||||
var periodEnd string
|
||||
cancelAtPeriodEnd := false
|
||||
portalAvailable := false
|
||||
|
||||
if h.subscriptions != nil {
|
||||
sub, err := h.subscriptions.GetActiveByUser(r.Context(), hostID)
|
||||
switch {
|
||||
case err == nil:
|
||||
tier = billing.Tier(sub.Tier)
|
||||
status = sub.Status
|
||||
cancelAtPeriodEnd = sub.CancelAtPeriodEnd
|
||||
if sub.CurrentPeriodEnd != nil {
|
||||
periodEnd = sub.CurrentPeriodEnd.Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
portalAvailable = sub.StripeCustomerID != ""
|
||||
case errors.Is(err, storage.ErrSubscriptionNotFound):
|
||||
// Free tier — leave defaults.
|
||||
default:
|
||||
writeError(w, http.StatusInternalServerError, "failed to load subscription")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limits := billing.LimitsFor(tier)
|
||||
events, _ := h.subscriptions.CountEventsInCurrentMonth(r.Context(), hostID)
|
||||
|
||||
resp := subscriptionStatusResponse{
|
||||
Tier: string(tier),
|
||||
Status: status,
|
||||
CurrentPeriodEnd: periodEnd,
|
||||
CancelAtPeriodEnd: cancelAtPeriodEnd,
|
||||
PortalAvailable: portalAvailable,
|
||||
}
|
||||
resp.Limits.EventsPerMonth = limits.EventsPerMonth
|
||||
resp.Limits.GuestsPerEvent = limits.GuestsPerEvent
|
||||
resp.Usage.EventsThisMonth = events
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// stripeEnabled returns true if the billing client is configured, else
|
||||
// writes 503 and returns false. The /billing/status path skips this so
|
||||
// the frontend can render a "free tier" page in dev environments.
|
||||
func (h *billingHandler) stripeEnabled(w http.ResponseWriter) bool {
|
||||
if h.stripe == nil || !h.stripe.Enabled() {
|
||||
writeError(w, http.StatusServiceUnavailable, "billing is not configured on this instance")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// tierEnforcer wraps the SubscriptionRepo with policy decisions. Lives
|
||||
// here (not in storage) because the policy is HTTP-shaped: we map the
|
||||
// outcome to 402 + an upgrade URL.
|
||||
type tierEnforcer struct {
|
||||
subs *storage.SubscriptionRepo
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
func newTierEnforcer(subs *storage.SubscriptionRepo, publicBaseURL string) *tierEnforcer {
|
||||
return &tierEnforcer{subs: subs, publicBaseURL: publicBaseURL}
|
||||
}
|
||||
|
||||
// currentTier returns the host's effective tier. ErrSubscriptionNotFound
|
||||
// means "no granting subscription on file" → free. Other DB errors
|
||||
// bubble up.
|
||||
func (e *tierEnforcer) currentTier(ctx context.Context, hostID uuid.UUID) (billing.Tier, error) {
|
||||
if e == nil || e.subs == nil {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
sub, err := e.subs.GetActiveByUser(ctx, hostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, storage.ErrSubscriptionNotFound) {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if !billing.StatusGrantsAccess(sub.Status) {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
tier := billing.Tier(sub.Tier)
|
||||
if !tier.Valid() {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
return tier, nil
|
||||
}
|
||||
|
||||
// allowEventCreate verifies the host's monthly event budget. Returns
|
||||
// true when the request may proceed. On denial it writes a 402 with the
|
||||
// upgrade hint and returns false.
|
||||
func (e *tierEnforcer) allowEventCreate(w http.ResponseWriter, r *http.Request, hostID uuid.UUID) bool {
|
||||
if e == nil || e.subs == nil {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).EventsPerMonth
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountEventsInCurrentMonth(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count events")
|
||||
return false
|
||||
}
|
||||
if used >= limit {
|
||||
e.writePaymentRequired(w, "events_per_month", tier, used, limit,
|
||||
"You've reached your monthly event limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// allowGuestCreate verifies the per-event guest cap. Same shape as
|
||||
// allowEventCreate.
|
||||
func (e *tierEnforcer) allowGuestCreate(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID) bool {
|
||||
if e == nil || e.subs == nil {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).GuestsPerEvent
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count guests")
|
||||
return false
|
||||
}
|
||||
if used >= limit {
|
||||
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
|
||||
"This event has reached the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// allowGuestImport is the CSV-import variant: check the cap against
|
||||
// existing + incoming row count up-front, before we start the
|
||||
// transaction. Dedup may shrink the actual insert count later — that's
|
||||
// OK, we just stay on the safe side.
|
||||
func (e *tierEnforcer) allowGuestImport(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID, incoming int) bool {
|
||||
if e == nil || e.subs == nil || incoming == 0 {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).GuestsPerEvent
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count guests")
|
||||
return false
|
||||
}
|
||||
if used+incoming > limit {
|
||||
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
|
||||
"This import would exceed the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type paymentRequiredBody struct {
|
||||
Error string `json:"error"`
|
||||
Reason string `json:"reason"`
|
||||
Tier string `json:"tier"`
|
||||
Used int `json:"used"`
|
||||
Limit int `json:"limit"`
|
||||
UpgradeURL string `json:"upgrade_url"`
|
||||
}
|
||||
|
||||
func (e *tierEnforcer) writePaymentRequired(w http.ResponseWriter, reason string, tier billing.Tier, used, limit int, msg string) {
|
||||
body := paymentRequiredBody{
|
||||
Error: msg,
|
||||
Reason: reason,
|
||||
Tier: string(tier),
|
||||
Used: used,
|
||||
Limit: limit,
|
||||
UpgradeURL: strings.TrimRight(e.publicBaseURL, "/") + "/dashboard/billing",
|
||||
}
|
||||
writeJSON(w, http.StatusPaymentRequired, body)
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/csvimport"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
// 1 MB cap on uploads. With ~200 bytes per row that's ~5,000 guests —
|
||||
// matches the row cap in csvimport.DefaultMaxRows.
|
||||
csvMaxBytes = 1 << 20
|
||||
)
|
||||
|
||||
type csvImportHandler struct {
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type importResponse struct {
|
||||
Added int `json:"added"`
|
||||
Skipped int `json:"skipped"`
|
||||
SkippedEmails []string `json:"skipped_emails,omitempty"`
|
||||
Errors []csvimport.RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"`
|
||||
}
|
||||
|
||||
type previewResponse struct {
|
||||
Rows []csvimport.Row `json:"rows"`
|
||||
Errors []csvimport.RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/import/preview — parse + validate but don't write.
|
||||
// Used by the frontend to show a "is this what you meant?" table before commit.
|
||||
func (h *csvImportHandler) preview(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
body, ok := readCSVUpload(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
res, err := csvimport.Parse(body, csvimport.Options{})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, previewResponse{
|
||||
Rows: res.Rows,
|
||||
Errors: res.Errors,
|
||||
TotalCount: res.TotalCount,
|
||||
})
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/import — parse, validate, and commit valid rows
|
||||
// in a single transaction. Rows with row-level errors are reported back
|
||||
// but don't prevent the rest from importing.
|
||||
func (h *csvImportHandler) commit(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
body, ok := readCSVUpload(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
parsed, err := csvimport.Parse(body, csvimport.Options{})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Plan enforcement: prevent an import that, even if perfectly
|
||||
// dedup-free, would exceed the per-event guest cap. We check upfront
|
||||
// (current count + parsed-row count) and reject with 402 if it'd
|
||||
// overflow. False positives on dedup-heavy CSVs are acceptable —
|
||||
// host can dedupe and re-upload.
|
||||
if !h.enforcer.allowGuestImport(w, r, hostID, eventID, len(parsed.Rows)) {
|
||||
return
|
||||
}
|
||||
|
||||
rows := make([]storage.BulkImportRow, 0, len(parsed.Rows))
|
||||
for _, r := range parsed.Rows {
|
||||
rows = append(rows, storage.BulkImportRow{
|
||||
Name: r.Name,
|
||||
Email: r.Email,
|
||||
Phone: r.Phone,
|
||||
PlusOnes: r.PlusOnes,
|
||||
})
|
||||
}
|
||||
|
||||
res, err := h.guests.BulkImportGuests(r.Context(), eventID, rows)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "import failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, importResponse{
|
||||
Added: res.Added,
|
||||
Skipped: res.Skipped,
|
||||
SkippedEmails: res.SkippedEmails,
|
||||
Errors: parsed.Errors,
|
||||
TotalCount: parsed.TotalCount,
|
||||
})
|
||||
}
|
||||
|
||||
// GET /events/{id}/guests/import/template — download a sample CSV. Auth is
|
||||
// applied at the route level; ownership is verified so an attacker can't
|
||||
// probe the existence of an event by hitting this endpoint.
|
||||
func (h *csvImportHandler) template(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-import-template.csv"`)
|
||||
_, _ = w.Write([]byte(csvimport.TemplateCSV))
|
||||
}
|
||||
|
||||
// readCSVUpload returns the multipart file body (capped at csvMaxBytes) or
|
||||
// writes an error and returns (nil, false). Accepted shapes:
|
||||
//
|
||||
// - multipart/form-data with a "file" field (the drag-drop UI uses this)
|
||||
// - any other Content-Type — the raw body is treated as the CSV (curl-friendly)
|
||||
func readCSVUpload(w http.ResponseWriter, r *http.Request) (io.ReadCloser, bool) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, csvMaxBytes)
|
||||
if err := r.ParseMultipartForm(csvMaxBytes); err == nil && r.MultipartForm != nil {
|
||||
files := r.MultipartForm.File["file"]
|
||||
if len(files) == 0 {
|
||||
writeError(w, http.StatusBadRequest, `form field "file" is required`)
|
||||
return nil, false
|
||||
}
|
||||
f, err := files[0].Open()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "cannot read uploaded file")
|
||||
return nil, false
|
||||
}
|
||||
return f, true
|
||||
} else if err != nil && !errors.Is(err, http.ErrNotMultipart) {
|
||||
writeError(w, http.StatusBadRequest, "invalid multipart body")
|
||||
return nil, false
|
||||
}
|
||||
// Fall through: raw body as CSV.
|
||||
return r.Body, true
|
||||
}
|
||||
+39
-28
@@ -15,11 +15,11 @@ import (
|
||||
)
|
||||
|
||||
type eventHandler struct {
|
||||
repo *storage.EventRepo
|
||||
repo *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type createEventRequest struct {
|
||||
HostID string `json:"host_id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
EventDate time.Time `json:"event_date"`
|
||||
@@ -32,6 +32,14 @@ type createEventRequest struct {
|
||||
var slugRe = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
|
||||
|
||||
func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !h.enforcer.allowEventCreate(w, r, hostID) {
|
||||
return
|
||||
}
|
||||
|
||||
var req createEventRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
@@ -51,12 +59,6 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hostID, err := uuid.Parse(req.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
|
||||
return
|
||||
}
|
||||
|
||||
status := domain.EventStatus(req.Status)
|
||||
if status == "" {
|
||||
status = domain.EventStatusDraft
|
||||
@@ -89,37 +91,31 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *eventHandler) get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ev, err := h.repo.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
ev, ok := requireEventOwner(w, r, h.repo, id, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ev)
|
||||
}
|
||||
|
||||
func (h *eventHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
limit := atoiOr(q.Get("limit"), 50)
|
||||
offset := atoiOr(q.Get("offset"), 0)
|
||||
|
||||
var hostID uuid.UUID
|
||||
if v := q.Get("host_id"); v != "" {
|
||||
parsed, err := uuid.Parse(v)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
|
||||
return
|
||||
}
|
||||
hostID = parsed
|
||||
}
|
||||
|
||||
events, err := h.repo.List(r.Context(), hostID, limit, offset)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to list events")
|
||||
@@ -146,6 +142,10 @@ type updateEventRequest struct {
|
||||
}
|
||||
|
||||
func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
@@ -180,7 +180,7 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
params.Status = &s
|
||||
}
|
||||
|
||||
ev, err := h.repo.Update(r.Context(), id, params)
|
||||
ev, err := h.repo.Update(r.Context(), id, hostID, params)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrEventNotFound):
|
||||
@@ -196,11 +196,15 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.repo.Delete(r.Context(), id); err != nil {
|
||||
if err := h.repo.Delete(r.Context(), id, hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
@@ -212,7 +216,14 @@ func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func parseIDParam(w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, bool) {
|
||||
raw := r.PathValue(name)
|
||||
return parseRawUUID(w, name, r.PathValue(name))
|
||||
}
|
||||
|
||||
func parseRawUUID(w http.ResponseWriter, name, raw string) (uuid.UUID, bool) {
|
||||
if raw == "" {
|
||||
writeError(w, http.StatusBadRequest, name+" is required")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, name+" must be a valid uuid")
|
||||
|
||||
+107
-9
@@ -10,8 +10,9 @@ import (
|
||||
)
|
||||
|
||||
type guestHandler struct {
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type createGuestRequest struct {
|
||||
@@ -24,16 +25,18 @@ type createGuestRequest struct {
|
||||
}
|
||||
|
||||
func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := h.events.Get(r.Context(), eventID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
if !h.enforcer.allowGuestCreate(w, r, hostID, eventID) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -67,11 +70,106 @@ func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, g)
|
||||
}
|
||||
|
||||
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
type updateGuestRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Email *string `json:"email"`
|
||||
Phone *string `json:"phone"`
|
||||
PlusOnes *int `json:"plus_ones"`
|
||||
}
|
||||
|
||||
// PATCH /events/{id}/guests/{guest_id} — patch a guest's contact info.
|
||||
// Fields omitted from the body are left untouched. Empty strings for
|
||||
// email/phone clear those columns to NULL.
|
||||
func (h *guestHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req updateGuestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Name != nil && *req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name cannot be empty")
|
||||
return
|
||||
}
|
||||
if req.PlusOnes != nil && *req.PlusOnes < 0 {
|
||||
writeError(w, http.StatusBadRequest, "plus_ones must be >= 0")
|
||||
return
|
||||
}
|
||||
|
||||
g, err := h.guests.Update(r.Context(), eventID, guestID, storage.UpdateGuestParams{
|
||||
Name: req.Name,
|
||||
Email: req.Email,
|
||||
Phone: req.Phone,
|
||||
PlusOnes: req.PlusOnes,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to update guest")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, g)
|
||||
}
|
||||
|
||||
// DELETE /events/{id}/guests/{guest_id} — remove a guest from an event.
|
||||
// Cascade-deletes their token, rsvp, access logs, notifications.
|
||||
func (h *guestHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
if err := h.guests.Delete(r.Context(), eventID, guestID); err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete guest")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
limit := atoiOr(q.Get("limit"), 100)
|
||||
offset := atoiOr(q.Get("offset"), 0)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type meHandler struct {
|
||||
users *storage.UserRepo
|
||||
}
|
||||
|
||||
// GET /me — returns the authenticated user. Used by the frontend to bootstrap
|
||||
// after a page reload (with a fresh access token from /auth/refresh).
|
||||
func (h *meHandler) get(w http.ResponseWriter, r *http.Request) {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthenticated")
|
||||
return
|
||||
}
|
||||
u, err := h.users.GetByID(r.Context(), uid)
|
||||
if err != nil {
|
||||
if err == domain.ErrUserNotFound {
|
||||
writeError(w, http.StatusUnauthorized, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, u)
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// privacyHandler holds the GDPR-style "your data, your choice" endpoints:
|
||||
// data export, account deletion, and terms-acceptance recording.
|
||||
type privacyHandler struct {
|
||||
logger *slog.Logger
|
||||
users *storage.UserRepo
|
||||
events *storage.EventRepo
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
rsvps *storage.RSVPRepo
|
||||
access *storage.AccessLogRepo
|
||||
notifs *storage.DB // raw pool access for the export queries
|
||||
refresh *storage.RefreshTokenRepo
|
||||
}
|
||||
|
||||
// DataExport is the shape of the JSON the host downloads from
|
||||
// GET /me/data-export. We don't paginate or stream — for the scale
|
||||
// GuestGuard hosts have, a single response is reasonable. If a host
|
||||
// ever has 100k+ access logs we'll switch to async + email-a-link.
|
||||
type DataExport struct {
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Format string `json:"format"`
|
||||
User *domain.User `json:"user"`
|
||||
Events []*domain.Event `json:"events"`
|
||||
Guests []*domain.Guest `json:"guests"`
|
||||
Tokens []exportedToken `json:"tokens"`
|
||||
RSVPs []exportedRSVP `json:"rsvps"`
|
||||
AccessLogs []exportedAccess `json:"access_logs"`
|
||||
Notifs []exportedNotif `json:"notifications"`
|
||||
}
|
||||
|
||||
type exportedToken struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type exportedRSVP struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
Response string `json:"response"`
|
||||
PlusOnes int `json:"plus_ones"`
|
||||
SubmittedAt time.Time `json:"submitted_at"`
|
||||
}
|
||||
type exportedAccess struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
RiskScore *int `json:"risk_score,omitempty"`
|
||||
Flagged bool `json:"flagged"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type exportedNotif struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
Channel string `json:"channel"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// GET /me/data-export — returns every record the system holds about the
|
||||
// authenticated user. The Content-Disposition header makes browsers
|
||||
// offer a download rather than rendering inline.
|
||||
func (h *privacyHandler) dataExport(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.users.GetByID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
|
||||
export := DataExport{
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Format: "guestguard.v1",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Events the user hosts.
|
||||
events, err := h.events.List(r.Context(), hostID, 1000, 0)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load events")
|
||||
return
|
||||
}
|
||||
export.Events = events
|
||||
|
||||
// For each event, pull guests + tokens + rsvps + access_logs + notifications.
|
||||
for _, ev := range events {
|
||||
guests, err := h.guests.ListByEvent(r.Context(), ev.ID, 5000, 0)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guests")
|
||||
return
|
||||
}
|
||||
export.Guests = append(export.Guests, guests...)
|
||||
|
||||
for _, g := range guests {
|
||||
// Token (at most one per guest, but query as a list for symmetry).
|
||||
if err := h.appendTokens(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: tokens", "err", err)
|
||||
}
|
||||
if err := h.appendRSVPs(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: rsvps", "err", err)
|
||||
}
|
||||
if err := h.appendAccess(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: access", "err", err)
|
||||
}
|
||||
if err := h.appendNotifs(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: notifs", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-data-export.json"`)
|
||||
writeJSON(w, http.StatusOK, export)
|
||||
}
|
||||
|
||||
func (h *privacyHandler) appendTokens(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, expires_at, status, created_at
|
||||
FROM tokens WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var t exportedToken
|
||||
if err := rows.Scan(&t.ID, &t.GuestID, &t.ExpiresAt, &t.Status, &t.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Tokens = append(out.Tokens, t)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendRSVPs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, response::text, plus_ones, submitted_at
|
||||
FROM rsvps WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var r exportedRSVP
|
||||
if err := rows.Scan(&r.ID, &r.GuestID, &r.Response, &r.PlusOnes, &r.SubmittedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.RSVPs = append(out.RSVPs, r)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendAccess(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, risk_score, flagged, created_at
|
||||
FROM access_logs WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var a exportedAccess
|
||||
var rs *int
|
||||
if err := rows.Scan(&a.ID, &a.GuestID, &rs, &a.Flagged, &a.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
a.RiskScore = rs
|
||||
out.AccessLogs = append(out.AccessLogs, a)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendNotifs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, channel::text, type::text, status::text, created_at
|
||||
FROM notifications WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var n exportedNotif
|
||||
if err := rows.Scan(&n.ID, &n.GuestID, &n.Channel, &n.Type, &n.Status, &n.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Notifs = append(out.Notifs, n)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// DELETE /me — soft-deletes the host's account. All sessions are
|
||||
// revoked immediately. A hard delete happens via a separate cron 30
|
||||
// days later (TBD ops work). The user is logged out from all devices
|
||||
// as a side effect of revoking the refresh tokens.
|
||||
func (h *privacyHandler) deleteMe(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.users.SoftDelete(r.Context(), hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
writeError(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete account")
|
||||
return
|
||||
}
|
||||
// Best-effort: revoke refresh tokens so other sessions log out too.
|
||||
// Failure here is logged but doesn't roll back the soft-delete — the
|
||||
// access tokens (JWT) will still expire on their own ~15 minute TTL.
|
||||
if err := h.refresh.RevokeAllForUser(r.Context(), hostID); err != nil {
|
||||
h.logger.Warn("delete-me: revoke refresh tokens", "err", err, "user_id", hostID)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /me/accept-terms — records that the authenticated user accepts
|
||||
// the current ToS + privacy policy. Idempotent. Used by both the
|
||||
// onboarding gate (existing accounts created before T&C were enforced)
|
||||
// and any future "we updated our terms" re-acceptance flow.
|
||||
func (h *privacyHandler) acceptTerms(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.users.AcceptTerms(r.Context(), hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
writeError(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to record acceptance")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "accepted"})
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ipKey is the rate-limit key for endpoints scoped by source IP only
|
||||
// (e.g. POST /auth/signup). XFF/X-Real-IP are honoured because in the
|
||||
// homelab the API sits behind Traefik.
|
||||
func ipKey(r *http.Request) string {
|
||||
return clientIP(r)
|
||||
}
|
||||
|
||||
// pathKey returns a path-parameter as the rate-limit key — used for the
|
||||
// token-scoped endpoints so an attacker brute-forcing a single token is
|
||||
// limited regardless of the IPs they rotate through.
|
||||
func pathKey(name string) KeyFunc {
|
||||
return func(r *http.Request) string {
|
||||
return r.PathValue(name)
|
||||
}
|
||||
}
|
||||
|
||||
// userIDKey extracts the authenticated user id from the request context.
|
||||
// Returns "" when the route isn't behind requireAuth, in which case the
|
||||
// middleware bypasses (fail-open) — the route's own auth layer handles
|
||||
// rejection.
|
||||
func userIDKey(r *http.Request) string {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return uid.String()
|
||||
}
|
||||
|
||||
// KeyFunc mirrors ratelimit.KeyFunc so call sites don't have to import the
|
||||
// inner package.
|
||||
type KeyFunc = func(r *http.Request) string
|
||||
+264
-43
@@ -5,63 +5,167 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
"github.com/alchemistkay/guestguard/internal/ratelimit"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
logger *slog.Logger
|
||||
db *storage.DB
|
||||
hub *Hub
|
||||
users *userHandler
|
||||
events *eventHandler
|
||||
guests *guestHandler
|
||||
tokens *tokenHandler
|
||||
rsvps *rsvpHandler
|
||||
activity *activityHandler
|
||||
ws *wsHandler
|
||||
health *healthHandler
|
||||
logger *slog.Logger
|
||||
db *storage.DB
|
||||
hub *Hub
|
||||
authH *authHandler
|
||||
me *meHandler
|
||||
events *eventHandler
|
||||
guests *guestHandler
|
||||
tokens *tokenHandler
|
||||
rsvps *rsvpHandler
|
||||
activity *activityHandler
|
||||
ws *wsHandler
|
||||
wsTicket *wsTicketHandler
|
||||
health *healthHandler
|
||||
signer *auth.JWTSigner
|
||||
limiter *ratelimit.Limiter
|
||||
unsub *unsubscribeHandler
|
||||
webhooks *webhookHandler
|
||||
csv *csvImportHandler
|
||||
billing *billingHandler
|
||||
stripeWH *stripeWebhookHandler
|
||||
privacy *privacyHandler
|
||||
}
|
||||
|
||||
type ServerDeps struct {
|
||||
Logger *slog.Logger
|
||||
DB *storage.DB
|
||||
Hub *Hub
|
||||
AccessPublisher accessPublisher
|
||||
RSVPPublisher rsvpPublisher
|
||||
FraudScorer fraudScorer
|
||||
TokenTTL time.Duration
|
||||
AccessPublisher accessPublisher
|
||||
RSVPPublisher rsvpPublisher
|
||||
InvitationPublisher invitationPublisher
|
||||
FraudScorer fraudScorer
|
||||
TokenTTL time.Duration
|
||||
|
||||
// Auth
|
||||
JWTSecret string
|
||||
JWTIssuer string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
PublicBaseURL string
|
||||
RefreshCookieDomain string
|
||||
RefreshCookieSecure bool
|
||||
EmailSender auth.EmailSender
|
||||
WSTicketTTL time.Duration
|
||||
|
||||
// Rate limiting / abuse controls
|
||||
Redis *redis.Client
|
||||
LoginLockoutMax int // failed attempts before account lockout (default 5)
|
||||
LoginFailWindow time.Duration // counter TTL (default 15 min)
|
||||
|
||||
// Notifications / unsubscribe
|
||||
NotificationRepo *notification.Repo
|
||||
SuppressionRepo *notification.SuppressionRepo
|
||||
UnsubscribeSigner *notification.UnsubscribeSigner
|
||||
|
||||
// Billing (Block F). Nil StripeClient leaves billing disabled — the
|
||||
// system still boots and runs, all users sit on the free tier with
|
||||
// its limits enforced; /billing/* returns 503.
|
||||
StripeClient *billing.Client
|
||||
}
|
||||
|
||||
func NewServer(deps ServerDeps) *Server {
|
||||
func NewServer(deps ServerDeps) (*Server, error) {
|
||||
eventRepo := storage.NewEventRepo(deps.DB)
|
||||
guestRepo := storage.NewGuestRepo(deps.DB)
|
||||
tokenRepo := storage.NewTokenRepo(deps.DB)
|
||||
rsvpRepo := storage.NewRSVPRepo(deps.DB)
|
||||
accessRepo := storage.NewAccessLogRepo(deps.DB)
|
||||
userRepo := storage.NewUserRepo(deps.DB)
|
||||
verifRepo := storage.NewEmailVerificationRepo(deps.DB)
|
||||
resetRepo := storage.NewPasswordResetRepo(deps.DB)
|
||||
refreshRepo := storage.NewRefreshTokenRepo(deps.DB)
|
||||
subRepo := storage.NewSubscriptionRepo(deps.DB)
|
||||
enforcer := newTierEnforcer(subRepo, deps.PublicBaseURL)
|
||||
|
||||
signer, err := auth.NewJWTSigner(deps.JWTSecret, deps.AccessTokenTTL, deps.JWTIssuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hasher := auth.NewPasswordHasher()
|
||||
|
||||
emails := deps.EmailSender
|
||||
if emails == nil {
|
||||
emails = auth.LogEmailSender{Logger: deps.Logger}
|
||||
}
|
||||
|
||||
hub := deps.Hub
|
||||
if hub == nil {
|
||||
hub = NewHub(deps.Logger)
|
||||
}
|
||||
|
||||
wsTicketTTL := deps.WSTicketTTL
|
||||
if wsTicketTTL <= 0 {
|
||||
wsTicketTTL = 60 * time.Second
|
||||
}
|
||||
wsTickets := newWSTicketStore(wsTicketTTL)
|
||||
|
||||
var limiter *ratelimit.Limiter
|
||||
var lockout *auth.LockoutTracker
|
||||
if deps.Redis != nil {
|
||||
limiter = ratelimit.New(deps.Redis, "gg:rl")
|
||||
lockoutMax := deps.LoginLockoutMax
|
||||
if lockoutMax <= 0 {
|
||||
lockoutMax = 5
|
||||
}
|
||||
failWindow := deps.LoginFailWindow
|
||||
if failWindow <= 0 {
|
||||
failWindow = 15 * time.Minute
|
||||
}
|
||||
lockout = auth.NewLockoutTracker(deps.Redis, lockoutMax, failWindow)
|
||||
}
|
||||
|
||||
authH := newAuthHandler(authHandlerDeps{
|
||||
Logger: deps.Logger,
|
||||
Users: userRepo,
|
||||
Verifications: verifRepo,
|
||||
Resets: resetRepo,
|
||||
Refreshes: refreshRepo,
|
||||
Hasher: hasher,
|
||||
Signer: signer,
|
||||
Emails: emails,
|
||||
Lockout: lockout,
|
||||
Limiter: limiter,
|
||||
PublicBaseURL: deps.PublicBaseURL,
|
||||
EmailVerificationTTL: deps.EmailVerificationTTL,
|
||||
PasswordResetTTL: deps.PasswordResetTTL,
|
||||
RefreshTTL: deps.RefreshTokenTTL,
|
||||
CookieDomain: deps.RefreshCookieDomain,
|
||||
CookieSecure: deps.RefreshCookieSecure,
|
||||
})
|
||||
|
||||
return &Server{
|
||||
logger: deps.Logger,
|
||||
db: deps.DB,
|
||||
hub: hub,
|
||||
users: &userHandler{repo: userRepo},
|
||||
events: &eventHandler{repo: eventRepo},
|
||||
guests: &guestHandler{guests: guestRepo, events: eventRepo},
|
||||
authH: authH,
|
||||
me: &meHandler{users: userRepo},
|
||||
events: &eventHandler{repo: eventRepo, enforcer: enforcer},
|
||||
guests: &guestHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
|
||||
tokens: &tokenHandler{
|
||||
logger: deps.Logger,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
events: eventRepo,
|
||||
accessLogs: accessRepo,
|
||||
gen: auth.NewGenerator(),
|
||||
ttl: deps.TokenTTL,
|
||||
pub: deps.AccessPublisher,
|
||||
logger: deps.Logger,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
events: eventRepo,
|
||||
users: userRepo,
|
||||
accessLogs: accessRepo,
|
||||
gen: auth.NewGenerator(),
|
||||
ttl: deps.TokenTTL,
|
||||
pub: deps.AccessPublisher,
|
||||
invitations: deps.InvitationPublisher,
|
||||
publicBaseURL: deps.PublicBaseURL,
|
||||
},
|
||||
rsvps: &rsvpHandler{
|
||||
logger: deps.Logger,
|
||||
@@ -78,9 +182,46 @@ func NewServer(deps ServerDeps) *Server {
|
||||
rsvps: rsvpRepo,
|
||||
accessLogs: accessRepo,
|
||||
},
|
||||
ws: &wsHandler{logger: deps.Logger, hub: hub},
|
||||
health: &healthHandler{pool: deps.DB.Pool},
|
||||
}
|
||||
ws: &wsHandler{logger: deps.Logger, hub: hub, tickets: wsTickets},
|
||||
wsTicket: &wsTicketHandler{tickets: wsTickets, events: eventRepo},
|
||||
health: &healthHandler{pool: deps.DB.Pool},
|
||||
signer: signer,
|
||||
limiter: limiter,
|
||||
unsub: &unsubscribeHandler{
|
||||
logger: deps.Logger,
|
||||
signer: deps.UnsubscribeSigner,
|
||||
suppress: deps.SuppressionRepo,
|
||||
},
|
||||
webhooks: &webhookHandler{
|
||||
logger: deps.Logger,
|
||||
notifs: deps.NotificationRepo,
|
||||
suppress: deps.SuppressionRepo,
|
||||
},
|
||||
csv: &csvImportHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
|
||||
billing: &billingHandler{
|
||||
logger: deps.Logger,
|
||||
stripe: deps.StripeClient,
|
||||
users: userRepo,
|
||||
subscriptions: subRepo,
|
||||
publicBaseURL: deps.PublicBaseURL,
|
||||
},
|
||||
stripeWH: &stripeWebhookHandler{
|
||||
logger: deps.Logger,
|
||||
stripe: deps.StripeClient,
|
||||
subs: subRepo,
|
||||
},
|
||||
privacy: &privacyHandler{
|
||||
logger: deps.Logger,
|
||||
users: userRepo,
|
||||
events: eventRepo,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
rsvps: rsvpRepo,
|
||||
access: accessRepo,
|
||||
notifs: deps.DB,
|
||||
refresh: refreshRepo,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Hub() *Hub { return s.hub }
|
||||
@@ -91,25 +232,104 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.HandleFunc("GET /health", s.health.live)
|
||||
mux.HandleFunc("GET /health/ready", s.health.ready)
|
||||
|
||||
mux.HandleFunc("POST /users", s.users.upsert)
|
||||
// Per-route rate limiters (no-op when Redis isn't wired).
|
||||
authed := requireAuth(s.signer)
|
||||
rl := func(name string, limit int, window time.Duration, keyFn KeyFunc, h http.Handler) http.Handler {
|
||||
if s.limiter == nil {
|
||||
return h
|
||||
}
|
||||
return s.limiter.Middleware(
|
||||
ratelimit.Rule{Name: name, Limit: limit, Window: window},
|
||||
keyFn,
|
||||
s.logger,
|
||||
)(h)
|
||||
}
|
||||
|
||||
mux.HandleFunc("POST /events", s.events.create)
|
||||
mux.HandleFunc("GET /events", s.events.list)
|
||||
mux.HandleFunc("GET /events/{id}", s.events.get)
|
||||
mux.HandleFunc("PATCH /events/{id}", s.events.update)
|
||||
mux.HandleFunc("DELETE /events/{id}", s.events.delete)
|
||||
// Anonymous auth endpoints — POST /auth/login + /auth/forgot-password
|
||||
// rate-limit inside the handler (key includes the email body field).
|
||||
mux.Handle("POST /auth/signup",
|
||||
rl("auth_signup", 5, time.Hour, ipKey, http.HandlerFunc(s.authH.signup)))
|
||||
mux.HandleFunc("POST /auth/login", s.authH.login)
|
||||
mux.HandleFunc("POST /auth/refresh", s.authH.refresh)
|
||||
mux.HandleFunc("POST /auth/logout", s.authH.logout)
|
||||
mux.HandleFunc("POST /auth/verify-email", s.authH.verifyEmail)
|
||||
mux.HandleFunc("POST /auth/forgot-password", s.authH.forgotPassword)
|
||||
mux.HandleFunc("POST /auth/reset-password", s.authH.resetPassword)
|
||||
|
||||
mux.HandleFunc("POST /events/{id}/guests", s.guests.create)
|
||||
mux.HandleFunc("GET /events/{id}/guests", s.guests.list)
|
||||
mux.Handle("GET /me", authed(http.HandlerFunc(s.me.get)))
|
||||
mux.Handle("POST /auth/ws-ticket", authed(http.HandlerFunc(s.wsTicket.issue)))
|
||||
|
||||
mux.HandleFunc("GET /events/{id}/activity", s.activity.list)
|
||||
// Privacy / GDPR-style endpoints — host can export their data,
|
||||
// delete their account, and record terms acceptance from the
|
||||
// onboarding gate.
|
||||
mux.Handle("GET /me/data-export", authed(http.HandlerFunc(s.privacy.dataExport)))
|
||||
mux.Handle("DELETE /me", authed(http.HandlerFunc(s.privacy.deleteMe)))
|
||||
mux.Handle("POST /me/accept-terms", authed(http.HandlerFunc(s.privacy.acceptTerms)))
|
||||
|
||||
mux.HandleFunc("POST /events/{id}/guests/{guest_id}/tokens", s.tokens.issue)
|
||||
mux.HandleFunc("GET /access/{token}", s.tokens.access)
|
||||
mux.HandleFunc("POST /rsvp/{token}", s.rsvps.submit)
|
||||
// Host-facing event/guest/token writes are limited by user_id.
|
||||
mux.Handle("POST /events",
|
||||
authed(rl("events_create", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.events.create))))
|
||||
mux.Handle("GET /events", authed(http.HandlerFunc(s.events.list)))
|
||||
mux.Handle("GET /events/{id}", authed(http.HandlerFunc(s.events.get)))
|
||||
mux.Handle("PATCH /events/{id}", authed(http.HandlerFunc(s.events.update)))
|
||||
mux.Handle("DELETE /events/{id}", authed(http.HandlerFunc(s.events.delete)))
|
||||
|
||||
mux.Handle("POST /events/{id}/guests",
|
||||
authed(rl("guests_create", 1000, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.create))))
|
||||
mux.Handle("GET /events/{id}/guests", authed(http.HandlerFunc(s.guests.list)))
|
||||
mux.Handle("PATCH /events/{id}/guests/{guest_id}",
|
||||
authed(rl("guests_update", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.update))))
|
||||
mux.Handle("DELETE /events/{id}/guests/{guest_id}",
|
||||
authed(rl("guests_delete", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.delete))))
|
||||
|
||||
// CSV import (Block E). Preview is cheap (no DB writes), so we keep
|
||||
// its budget separate from commit's daily-row-add limit.
|
||||
mux.Handle("POST /events/{id}/guests/import/preview",
|
||||
authed(rl("guests_import_preview", 30, time.Hour, userIDKey, http.HandlerFunc(s.csv.preview))))
|
||||
mux.Handle("POST /events/{id}/guests/import",
|
||||
authed(rl("guests_import_commit", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.csv.commit))))
|
||||
mux.Handle("GET /events/{id}/guests/import/template", authed(http.HandlerFunc(s.csv.template)))
|
||||
|
||||
mux.Handle("GET /events/{id}/activity", authed(http.HandlerFunc(s.activity.list)))
|
||||
|
||||
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens",
|
||||
authed(rl("tokens_issue", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.issue))))
|
||||
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens/rotate",
|
||||
authed(rl("tokens_rotate", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.rotate))))
|
||||
mux.Handle("POST /events/{id}/guests/invitations/bulk",
|
||||
authed(rl("tokens_bulk", 10, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.bulkIssue))))
|
||||
|
||||
// Guest-facing endpoints — rate-limited by the access token in the URL
|
||||
// path so an attacker hammering a single invitation is slowed regardless
|
||||
// of their source IP.
|
||||
mux.Handle("GET /access/{token}",
|
||||
rl("access", 60, time.Hour, pathKey("token"), http.HandlerFunc(s.tokens.access)))
|
||||
mux.Handle("POST /rsvp/{token}",
|
||||
rl("rsvp", 10, time.Hour, pathKey("token"), http.HandlerFunc(s.rsvps.submit)))
|
||||
|
||||
// WebSocket endpoint authenticates via single-use ticket on the query
|
||||
// string (see POST /auth/ws-ticket).
|
||||
mux.HandleFunc("GET /ws/events/{id}", s.ws.handle)
|
||||
|
||||
// Unsubscribe (signed token, no auth required — links live in emails).
|
||||
mux.HandleFunc("GET /unsubscribe/{token}", s.unsub.preview)
|
||||
mux.HandleFunc("POST /unsubscribe/{token}", s.unsub.confirm)
|
||||
|
||||
// Provider webhooks. Signature verification is enforced in the handler
|
||||
// once GG_TWILIO_AUTH_TOKEN / GG_SES_WEBHOOK_SECRET are set.
|
||||
mux.HandleFunc("POST /webhooks/twilio/status", s.webhooks.twilio)
|
||||
mux.HandleFunc("POST /webhooks/ses/notifications", s.webhooks.ses)
|
||||
|
||||
// Billing (Block F). /billing/status is safe for everyone — returns
|
||||
// free tier defaults when Stripe is unconfigured or the user has no
|
||||
// subscription, so the frontend's plan page always has something to
|
||||
// render. The action endpoints (checkout, portal) return 503 in dev
|
||||
// without Stripe credentials.
|
||||
mux.Handle("GET /billing/status", authed(http.HandlerFunc(s.billing.status)))
|
||||
mux.Handle("POST /billing/checkout-session", authed(http.HandlerFunc(s.billing.checkoutSession)))
|
||||
mux.Handle("POST /billing/portal", authed(http.HandlerFunc(s.billing.portalSession)))
|
||||
mux.HandleFunc("POST /webhooks/stripe", s.stripeWH.handle)
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusNotFound, "not found")
|
||||
})
|
||||
@@ -128,9 +348,10 @@ func corsMiddleware(next http.Handler) http.Handler {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Device-Fingerprint")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Device-Fingerprint")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/stripe/stripe-go/v82"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// stripeWebhookHandler accepts and verifies Stripe events, then
|
||||
// projects subscription lifecycle changes onto the subscriptions table.
|
||||
// We track only what middleware needs to decide access — tier + status +
|
||||
// period bounds. Invoice events (payment failed / succeeded) are logged
|
||||
// for observability; dunning automation lands in Block F3.
|
||||
type stripeWebhookHandler struct {
|
||||
logger *slog.Logger
|
||||
stripe *billing.Client
|
||||
subs *storage.SubscriptionRepo
|
||||
}
|
||||
|
||||
// POST /webhooks/stripe — signature-verified Stripe event sink.
|
||||
func (h *stripeWebhookHandler) handle(w http.ResponseWriter, r *http.Request) {
|
||||
if h.stripe == nil || !h.stripe.Enabled() {
|
||||
// Not configured on this instance — reject so a misrouted event
|
||||
// isn't silently swallowed. Stripe will retry which is harmless.
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "read body")
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
event, err := h.stripe.VerifyWebhook(body, r.Header.Get("Stripe-Signature"))
|
||||
if err != nil {
|
||||
h.logger.Warn("stripe webhook signature failed", "err", err)
|
||||
writeError(w, http.StatusBadRequest, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "customer.subscription.created", "customer.subscription.updated":
|
||||
h.applySubscription(r, event)
|
||||
case "customer.subscription.deleted":
|
||||
h.applySubscriptionDeleted(r, event)
|
||||
case "invoice.payment_succeeded":
|
||||
// Clear past_due if Stripe says payment caught up. Most flows are
|
||||
// already covered by the subscription.updated event Stripe also
|
||||
// fires — this is belt-and-braces.
|
||||
h.applySubscription(r, event)
|
||||
case "invoice.payment_failed":
|
||||
h.logger.Warn("stripe invoice payment failed", "event_id", event.ID)
|
||||
// Subscription.status will flip to past_due via the
|
||||
// subscription.updated event Stripe fires alongside.
|
||||
default:
|
||||
h.logger.Debug("stripe event ignored", "type", event.Type)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// applySubscription patches the subscriptions row keyed by Stripe
|
||||
// customer id. Best-effort — failures here are logged but don't NACK
|
||||
// the webhook (Stripe would retry forever and the row would never
|
||||
// converge).
|
||||
func (h *stripeWebhookHandler) applySubscription(r *http.Request, event stripe.Event) {
|
||||
var sub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
// Some invoice events carry an Invoice payload — try to extract
|
||||
// the subscription id from there and short-circuit on status.
|
||||
h.logger.Debug("stripe webhook: not a subscription payload", "type", event.Type, "err", err)
|
||||
return
|
||||
}
|
||||
if sub.Customer == nil || sub.Customer.ID == "" {
|
||||
h.logger.Warn("stripe webhook: subscription has no customer", "subscription", sub.ID)
|
||||
return
|
||||
}
|
||||
|
||||
tier := tierFromSubscription(&sub)
|
||||
status := string(sub.Status)
|
||||
cancelAtPeriodEnd := sub.CancelAtPeriodEnd
|
||||
|
||||
// As of API 2024-10-28, current_period_end lives on the subscription
|
||||
// item, not the subscription. We pick the earliest item's end — for
|
||||
// single-item subscriptions (our case) that's the canonical one.
|
||||
var periodEnd *time.Time
|
||||
for _, item := range sub.Items.Data {
|
||||
if item.CurrentPeriodEnd > 0 {
|
||||
t := time.Unix(item.CurrentPeriodEnd, 0).UTC()
|
||||
if periodEnd == nil || t.Before(*periodEnd) {
|
||||
periodEnd = &t
|
||||
}
|
||||
}
|
||||
}
|
||||
subID := sub.ID
|
||||
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
|
||||
StripeSubscriptionID: &subID,
|
||||
Tier: stringPtr(string(tier)),
|
||||
Status: &status,
|
||||
CurrentPeriodEnd: periodEnd,
|
||||
CancelAtPeriodEnd: &cancelAtPeriodEnd,
|
||||
}); err != nil {
|
||||
h.logger.Error("stripe webhook: update subscription failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *stripeWebhookHandler) applySubscriptionDeleted(r *http.Request, event stripe.Event) {
|
||||
var sub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
h.logger.Warn("stripe webhook: bad deleted payload", "err", err)
|
||||
return
|
||||
}
|
||||
if sub.Customer == nil {
|
||||
return
|
||||
}
|
||||
status := "canceled"
|
||||
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
|
||||
Status: &status,
|
||||
}); err != nil {
|
||||
h.logger.Error("stripe webhook: mark canceled failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tierFromSubscription inspects the Stripe price metadata to figure out
|
||||
// which GuestGuard tier this subscription corresponds to. We read a
|
||||
// price-level metadata key `gg_tier` (set in the Stripe dashboard when
|
||||
// you create the Price). Fallback: free.
|
||||
func tierFromSubscription(sub *stripe.Subscription) billing.Tier {
|
||||
if sub == nil || len(sub.Items.Data) == 0 {
|
||||
return billing.TierFree
|
||||
}
|
||||
for _, item := range sub.Items.Data {
|
||||
if item.Price == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := item.Price.Metadata["gg_tier"]; ok {
|
||||
t := billing.Tier(v)
|
||||
if t.Valid() {
|
||||
return t
|
||||
}
|
||||
}
|
||||
// Heuristic fallback for tests / unconfigured prices: look at the
|
||||
// recurring interval and amount tier.
|
||||
if item.Price.Recurring != nil && item.Price.UnitAmount >= 19900 {
|
||||
return billing.TierBusiness
|
||||
}
|
||||
if item.Price.Recurring != nil && item.Price.UnitAmount >= 4900 {
|
||||
return billing.TierPro
|
||||
}
|
||||
}
|
||||
return billing.TierFree
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string { return &s }
|
||||
+329
-14
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -20,25 +21,38 @@ type accessPublisher interface {
|
||||
PublishAccessAttempted(ctx context.Context, evt natspub.AccessAttempted) error
|
||||
}
|
||||
|
||||
type invitationPublisher interface {
|
||||
PublishInvitationSend(ctx context.Context, evt natspub.InvitationSend) error
|
||||
}
|
||||
|
||||
type tokenHandler struct {
|
||||
logger *slog.Logger
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
events *storage.EventRepo
|
||||
accessLogs *storage.AccessLogRepo
|
||||
gen *auth.Generator
|
||||
ttl time.Duration
|
||||
pub accessPublisher
|
||||
logger *slog.Logger
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
events *storage.EventRepo
|
||||
users *storage.UserRepo
|
||||
accessLogs *storage.AccessLogRepo
|
||||
gen *auth.Generator
|
||||
ttl time.Duration
|
||||
pub accessPublisher
|
||||
invitations invitationPublisher
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
type issueTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
Meta *domain.Token `json:"meta"`
|
||||
Token string `json:"token"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
Meta *domain.Token `json:"meta"`
|
||||
InvitationQueued bool `json:"invitation_queued"`
|
||||
InvitationLink string `json:"invitation_link"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/{guest_id}/tokens — issue a token for the guest.
|
||||
func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
@@ -47,6 +61,10 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
guest, err := h.guests.Get(r.Context(), guestID)
|
||||
if err != nil {
|
||||
@@ -78,10 +96,307 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
link := h.invitationLink(raw)
|
||||
invitationQueued := h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
|
||||
|
||||
writeJSON(w, http.StatusCreated, issueTokenResponse{
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
InvitationQueued: invitationQueued,
|
||||
InvitationLink: link,
|
||||
})
|
||||
}
|
||||
|
||||
// queueInvitation publishes an invitation.send event so the notifier can
|
||||
// dispatch a branded email. Best-effort: if any step fails we log and
|
||||
// return false rather than failing the whole token-issue request — the
|
||||
// host still has the raw URL in the response and can re-trigger sending.
|
||||
func (h *tokenHandler) queueInvitation(
|
||||
ctx context.Context,
|
||||
event *domain.Event,
|
||||
guest *domain.Guest,
|
||||
tk *domain.Token,
|
||||
hostID uuid.UUID,
|
||||
rawToken string,
|
||||
) bool {
|
||||
if h.invitations == nil {
|
||||
return false
|
||||
}
|
||||
if guest.Email == nil || *guest.Email == "" {
|
||||
// Phone-only / nameless guests get no email — host shares the link
|
||||
// manually. Show that on the UI so it's not a silent surprise.
|
||||
return false
|
||||
}
|
||||
hostName := ""
|
||||
if h.users != nil {
|
||||
if host, err := h.users.GetByID(ctx, hostID); err == nil && host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
}
|
||||
evt := natspub.InvitationSend{
|
||||
EventID: event.ID,
|
||||
GuestID: guest.ID,
|
||||
TokenID: tk.ID,
|
||||
GuestName: guest.Name,
|
||||
GuestEmail: *guest.Email,
|
||||
HostName: hostName,
|
||||
EventName: event.Name,
|
||||
Venue: event.Venue,
|
||||
EventDate: event.EventDate,
|
||||
Link: h.invitationLink(rawToken),
|
||||
IssuedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := h.invitations.PublishInvitationSend(ctx, evt); err != nil {
|
||||
h.logger.Warn("publish invitation.send (continuing)", "err", err, "guest_id", guest.ID)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// invitationLink renders the public RSVP URL the guest clicks from their
|
||||
// inbox. publicBaseURL is the externally-reachable host (set via
|
||||
// GG_PUBLIC_BASE_URL); access via /rsvp/<token> is intentional — the
|
||||
// frontend page rsvp/[token].vue catches the raw token.
|
||||
func (h *tokenHandler) invitationLink(raw string) string {
|
||||
base := h.publicBaseURL
|
||||
if base == "" {
|
||||
base = "http://localhost:3000"
|
||||
}
|
||||
return base + "/rsvp/" + raw
|
||||
}
|
||||
|
||||
// bulkIssueRequest is the optional JSON body for the bulk-invite call.
|
||||
// An empty body (or missing GuestIDs) means "every guest on the event
|
||||
// who doesn't already have a token".
|
||||
type bulkIssueRequest struct {
|
||||
GuestIDs []string `json:"guest_ids"`
|
||||
}
|
||||
|
||||
type bulkIssueItemError struct {
|
||||
GuestID string `json:"guest_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// bulkIssueToken is one minted invitation. The raw token is returned so
|
||||
// the host's UI can offer a "copy link" affordance after a bulk send
|
||||
// (especially for guests with no email on file) without making another
|
||||
// round-trip. Same data the per-guest issue endpoint already exposes,
|
||||
// scoped to the host who owns the event.
|
||||
type bulkIssueToken struct {
|
||||
GuestID string `json:"guest_id"`
|
||||
Token string `json:"token"`
|
||||
InvitationQueued bool `json:"invitation_queued"`
|
||||
InvitationLink string `json:"invitation_link"`
|
||||
}
|
||||
|
||||
type bulkIssueResponse struct {
|
||||
Issued int `json:"issued"`
|
||||
Queued int `json:"queued"`
|
||||
SkippedExisting int `json:"skipped_existing"`
|
||||
SkippedNoEmail int `json:"skipped_no_email"`
|
||||
Tokens []bulkIssueToken `json:"tokens,omitempty"`
|
||||
Errors []bulkIssueItemError `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/invitations/bulk — generate tokens for every
|
||||
// eligible guest (or the explicit subset) on the event and queue an
|
||||
// invitation email for those with an address. Best-effort: any per-guest
|
||||
// error is reported in the response and doesn't abort the rest.
|
||||
func (h *tokenHandler) bulkIssue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req bulkIssueRequest
|
||||
if r.ContentLength > 0 {
|
||||
// Body is optional; only decode when something was actually sent.
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var onlyIDs []uuid.UUID
|
||||
if len(req.GuestIDs) > 0 {
|
||||
onlyIDs = make([]uuid.UUID, 0, len(req.GuestIDs))
|
||||
for _, raw := range req.GuestIDs {
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid guest id: "+raw)
|
||||
return
|
||||
}
|
||||
onlyIDs = append(onlyIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
guests, err := h.guests.ListGuestsForInvitation(r.Context(), eventID, onlyIDs)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guests")
|
||||
return
|
||||
}
|
||||
|
||||
hostName := ""
|
||||
if h.users != nil {
|
||||
if host, err := h.users.GetByID(r.Context(), hostID); err == nil && host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
}
|
||||
|
||||
resp := bulkIssueResponse{}
|
||||
for _, g := range guests {
|
||||
if g.HasToken {
|
||||
resp.SkippedExisting++
|
||||
continue
|
||||
}
|
||||
raw, hash, err := h.gen.Generate()
|
||||
if err != nil {
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "mint token failed"})
|
||||
continue
|
||||
}
|
||||
tk, err := h.tokens.Create(r.Context(), storage.CreateTokenParams{
|
||||
GuestID: g.ID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(h.ttl),
|
||||
})
|
||||
if err != nil {
|
||||
// Likely a race against the unique constraint (someone else
|
||||
// issued in parallel) — surface but don't fail the batch.
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: err.Error()})
|
||||
continue
|
||||
}
|
||||
resp.Issued++
|
||||
link := h.invitationLink(raw)
|
||||
tokenInfo := bulkIssueToken{
|
||||
GuestID: g.ID.String(),
|
||||
Token: raw,
|
||||
InvitationLink: link,
|
||||
}
|
||||
|
||||
if g.Email == "" {
|
||||
resp.SkippedNoEmail++
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
evt := natspub.InvitationSend{
|
||||
EventID: event.ID,
|
||||
GuestID: g.ID,
|
||||
TokenID: tk.ID,
|
||||
GuestName: g.Name,
|
||||
GuestEmail: g.Email,
|
||||
HostName: hostName,
|
||||
EventName: event.Name,
|
||||
Venue: event.Venue,
|
||||
EventDate: event.EventDate,
|
||||
Link: h.invitationLink(raw),
|
||||
IssuedAt: time.Now().UTC(),
|
||||
}
|
||||
if h.invitations == nil {
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
if err := h.invitations.PublishInvitationSend(r.Context(), evt); err != nil {
|
||||
h.logger.Warn("publish invitation.send (bulk, continuing)", "err", err, "guest_id", g.ID)
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "publish failed"})
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
resp.Queued++
|
||||
tokenInfo.InvitationQueued = true
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type rotateTokenRequest struct {
|
||||
// SendEmail asks the notifier to re-deliver the invitation. False
|
||||
// means "just give me a fresh link" — typical for phone-only guests
|
||||
// where the host shares the new URL via SMS.
|
||||
SendEmail bool `json:"send_email"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/{guest_id}/tokens/rotate — invalidate the
|
||||
// guest's existing invitation link and mint a fresh one. Optionally
|
||||
// re-publishes invitation.send so the notifier re-delivers via email.
|
||||
// The old URL stops working immediately.
|
||||
func (h *tokenHandler) rotate(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
guest, err := h.guests.Get(r.Context(), guestID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guest")
|
||||
return
|
||||
}
|
||||
if guest.EventID != eventID {
|
||||
writeError(w, http.StatusNotFound, "guest not found in event")
|
||||
return
|
||||
}
|
||||
|
||||
var req rotateTokenRequest
|
||||
if r.ContentLength > 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
raw, hash, err := h.gen.Generate()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to generate token")
|
||||
return
|
||||
}
|
||||
tk, err := h.tokens.RotateForGuest(r.Context(), storage.CreateTokenParams{
|
||||
GuestID: guestID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(h.ttl),
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("rotate token", "err", err, "guest_id", guestID)
|
||||
writeError(w, http.StatusInternalServerError, "failed to rotate token")
|
||||
return
|
||||
}
|
||||
|
||||
link := h.invitationLink(raw)
|
||||
invitationQueued := false
|
||||
if req.SendEmail {
|
||||
invitationQueued = h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, issueTokenResponse{
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
InvitationQueued: invitationQueued,
|
||||
InvitationLink: link,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
)
|
||||
|
||||
type unsubscribeHandler struct {
|
||||
logger *slog.Logger
|
||||
signer *notification.UnsubscribeSigner
|
||||
suppress *notification.SuppressionRepo
|
||||
}
|
||||
|
||||
// GET /unsubscribe/{token} — surface the email address that the token
|
||||
// belongs to so the frontend can show a confirmation page. Honoured even
|
||||
// before the user clicks "Confirm" so they see what's being unsubscribed.
|
||||
func (h *unsubscribeHandler) preview(w http.ResponseWriter, r *http.Request) {
|
||||
if h.signer == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
|
||||
return
|
||||
}
|
||||
email, err := h.signer.Verify(r.PathValue("token"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"email": email})
|
||||
}
|
||||
|
||||
// POST /unsubscribe/{token} — add the email to the suppression list.
|
||||
// Idempotent: clicking the link twice keeps the existing entry.
|
||||
func (h *unsubscribeHandler) confirm(w http.ResponseWriter, r *http.Request) {
|
||||
if h.signer == nil || h.suppress == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
|
||||
return
|
||||
}
|
||||
email, err := h.signer.Verify(r.PathValue("token"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
|
||||
return
|
||||
}
|
||||
if err := h.suppress.Add(r.Context(), email, "user clicked unsubscribe", notification.SuppressionUser); err != nil {
|
||||
h.logger.Error("add suppression", "err", err, "email", email)
|
||||
writeError(w, http.StatusInternalServerError, "failed to unsubscribe")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "unsubscribed", "email": email})
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type userHandler struct {
|
||||
repo *storage.UserRepo
|
||||
}
|
||||
|
||||
type upsertUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// POST /users — idempotent: returns the existing user if the email already
|
||||
// exists, creates one otherwise. This keeps the demo flow simple without
|
||||
// requiring real auth.
|
||||
func (h *userHandler) upsert(w http.ResponseWriter, r *http.Request) {
|
||||
var req upsertUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "email is invalid")
|
||||
return
|
||||
}
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.repo.Create(r.Context(), req.Email, req.Name)
|
||||
if err == nil {
|
||||
writeJSON(w, http.StatusCreated, u)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, domain.ErrEmailTaken) {
|
||||
existing, getErr := h.repo.GetByEmail(r.Context(), req.Email)
|
||||
if getErr != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, existing)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
)
|
||||
|
||||
// webhookHandler accepts provider status notifications and reflects them
|
||||
// onto the notifications table + suppression list.
|
||||
//
|
||||
// Signature verification is intentionally a TODO until the user provisions
|
||||
// real Twilio + SES creds — verifying against test fixtures alone would
|
||||
// give a false sense of security. The endpoint is therefore *not* exposed
|
||||
// publicly until the deployment is ready.
|
||||
type webhookHandler struct {
|
||||
logger *slog.Logger
|
||||
notifs *notification.Repo
|
||||
suppress *notification.SuppressionRepo
|
||||
}
|
||||
|
||||
// POST /webhooks/twilio/status — Twilio status callback (form-encoded).
|
||||
// Fields we care about: MessageSid, MessageStatus (sent|delivered|
|
||||
// undelivered|failed), ErrorCode, To.
|
||||
func (h *webhookHandler) twilio(w http.ResponseWriter, r *http.Request) {
|
||||
if h.notifs == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// TODO(blockD2): verify X-Twilio-Signature with GG_TWILIO_AUTH_TOKEN.
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid form")
|
||||
return
|
||||
}
|
||||
sid := r.PostForm.Get("MessageSid")
|
||||
status := r.PostForm.Get("MessageStatus")
|
||||
if sid == "" || status == "" {
|
||||
writeError(w, http.StatusBadRequest, "missing MessageSid / MessageStatus")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
switch status {
|
||||
case "delivered":
|
||||
_ = h.notifs.MarkDelivered(ctx, sid)
|
||||
case "undelivered", "failed":
|
||||
_ = h.notifs.MarkBounce(ctx, sid, "permanent")
|
||||
}
|
||||
h.logger.Info("twilio status callback", "sid", sid, "status", status)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /webhooks/ses/notifications — SNS-delivered SES notification (JSON).
|
||||
// Handles the two shapes SES uses: bounce + complaint events. Each event
|
||||
// carries the messageId we stored in provider_message_id and an array of
|
||||
// affected recipients.
|
||||
func (h *webhookHandler) ses(w http.ResponseWriter, r *http.Request) {
|
||||
if h.notifs == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// TODO(blockD2): verify SNS signature using the cert URL field.
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "read body")
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"Type"` // "Notification" | "SubscriptionConfirmation"
|
||||
Message string `json:"Message"` // stringified JSON for Notification
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json envelope")
|
||||
return
|
||||
}
|
||||
if envelope.Type == "SubscriptionConfirmation" {
|
||||
// Confirmed by visiting SubscribeURL — manual op-side step.
|
||||
h.logger.Info("ses subscription confirmation received (manual confirm required)")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if envelope.Message == "" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
var inner struct {
|
||||
NotificationType string `json:"notificationType"` // "Bounce" | "Complaint" | "Delivery"
|
||||
Mail struct {
|
||||
MessageID string `json:"messageId"`
|
||||
} `json:"mail"`
|
||||
Bounce struct {
|
||||
BounceType string `json:"bounceType"` // "Permanent" | "Transient"
|
||||
BouncedRecipients []struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
} `json:"bouncedRecipients"`
|
||||
} `json:"bounce"`
|
||||
Complaint struct {
|
||||
ComplainedRecipients []struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
} `json:"complainedRecipients"`
|
||||
} `json:"complaint"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(envelope.Message), &inner); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid inner json")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
switch inner.NotificationType {
|
||||
case "Bounce":
|
||||
bt := "transient"
|
||||
if inner.Bounce.BounceType == "Permanent" {
|
||||
bt = "permanent"
|
||||
}
|
||||
_ = h.notifs.MarkBounce(ctx, inner.Mail.MessageID, bt)
|
||||
if h.suppress != nil && bt == "permanent" {
|
||||
for _, rcp := range inner.Bounce.BouncedRecipients {
|
||||
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses permanent bounce", notification.SuppressionBounce)
|
||||
}
|
||||
}
|
||||
case "Complaint":
|
||||
_ = h.notifs.MarkComplaint(ctx, inner.Mail.MessageID)
|
||||
if h.suppress != nil {
|
||||
for _, rcp := range inner.Complaint.ComplainedRecipients {
|
||||
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses complaint", notification.SuppressionComplaint)
|
||||
}
|
||||
}
|
||||
case "Delivery":
|
||||
_ = h.notifs.MarkDelivered(ctx, inner.Mail.MessageID)
|
||||
}
|
||||
h.logger.Info("ses notification", "type", inner.NotificationType, "message_id", inner.Mail.MessageID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Compile-time check that ctx is unused in package — silences linter on
|
||||
// some Go versions when the file would otherwise import context only for
|
||||
// the handler signatures.
|
||||
var _ = context.Background
|
||||
@@ -0,0 +1,51 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type wsTicketHandler struct {
|
||||
tickets *wsTicketStore
|
||||
events *storage.EventRepo
|
||||
}
|
||||
|
||||
type wsTicketResponse struct {
|
||||
Ticket string `json:"ticket"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// POST /auth/ws-ticket — requireAuth-protected; body { "event_id": "<uuid>" }.
|
||||
// Returns a single-use ticket valid for ~60 seconds. The frontend appends it
|
||||
// as `?ticket=…` on the WebSocket URL.
|
||||
func (h *wsTicketHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
EventID string `json:"event_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
eventID, ok := parseRawUUID(w, "event_id", req.EventID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tok, exp, err := h.tickets.Mint(hostID, eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to mint ticket")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, wsTicketResponse{Ticket: tok, ExpiresAt: exp})
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// wsTicketStore mints short-lived single-use tickets that authorise a
|
||||
// WebSocket handshake. The plan calls this option 3 in Block B: cookies
|
||||
// don't reach the WS handshake on cross-origin setups and a JWT in the URL
|
||||
// would leak to logs; a one-shot ticket sidesteps both.
|
||||
//
|
||||
// Block B keeps this in-process. When the API runs more than one replica
|
||||
// this needs to move to Redis (Block C territory).
|
||||
type wsTicketStore struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]wsTicketEntry
|
||||
ttl time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type wsTicketEntry struct {
|
||||
userID uuid.UUID
|
||||
eventID uuid.UUID
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newWSTicketStore(ttl time.Duration) *wsTicketStore {
|
||||
return &wsTicketStore{
|
||||
entries: make(map[string]wsTicketEntry),
|
||||
ttl: ttl,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Mint returns a fresh URL-safe ticket bound to userID + eventID.
|
||||
func (s *wsTicketStore) Mint(userID, eventID uuid.UUID) (string, time.Time, error) {
|
||||
buf := make([]byte, 24)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
tok := base64.RawURLEncoding.EncodeToString(buf)
|
||||
exp := s.now().Add(s.ttl)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked()
|
||||
s.entries[tok] = wsTicketEntry{userID: userID, eventID: eventID, expiresAt: exp}
|
||||
return tok, exp, nil
|
||||
}
|
||||
|
||||
// Consume removes the ticket and returns the bound (userID, eventID) if it
|
||||
// was valid. A ticket is single-use — replaying it fails.
|
||||
func (s *wsTicketStore) Consume(token string) (uuid.UUID, uuid.UUID, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
entry, ok := s.entries[token]
|
||||
if !ok {
|
||||
return uuid.Nil, uuid.Nil, false
|
||||
}
|
||||
delete(s.entries, token)
|
||||
if s.now().After(entry.expiresAt) {
|
||||
return uuid.Nil, uuid.Nil, false
|
||||
}
|
||||
return entry.userID, entry.eventID, true
|
||||
}
|
||||
|
||||
// sweepLocked drops expired entries opportunistically. Cheap because we
|
||||
// usually only hold dozens of tickets at a time.
|
||||
func (s *wsTicketStore) sweepLocked() {
|
||||
now := s.now()
|
||||
for k, v := range s.entries {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(s.entries, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
+23
-3
@@ -89,16 +89,36 @@ func (h *Hub) remove(eventID uuid.UUID, s *subscriber) {
|
||||
}
|
||||
|
||||
type wsHandler struct {
|
||||
logger *slog.Logger
|
||||
hub *Hub
|
||||
logger *slog.Logger
|
||||
hub *Hub
|
||||
tickets *wsTicketStore
|
||||
}
|
||||
|
||||
// GET /ws/events/{id} — dashboard live feed for one event.
|
||||
// GET /ws/events/{id}?ticket=... — dashboard live feed for one event.
|
||||
//
|
||||
// The handshake is authorised by a single-use ticket minted via
|
||||
// POST /auth/ws-ticket (option 3 from the Block B plan). The ticket binds
|
||||
// the connecting user to a specific event_id; we reject if either is
|
||||
// missing or doesn't match the URL path.
|
||||
func (h *wsHandler) handle(w http.ResponseWriter, r *http.Request) {
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
rawTicket := r.URL.Query().Get("ticket")
|
||||
if rawTicket == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing ticket")
|
||||
return
|
||||
}
|
||||
_, ticketEventID, valid := h.tickets.Consume(rawTicket)
|
||||
if !valid {
|
||||
writeError(w, http.StatusUnauthorized, "invalid or expired ticket")
|
||||
return
|
||||
}
|
||||
if ticketEventID != eventID {
|
||||
writeError(w, http.StatusForbidden, "ticket does not match event")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
// In dev the frontend runs on a different origin (localhost:3000 → localhost:8080).
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// EmailSender delivers transactional auth emails (verification, reset).
|
||||
// Block A ships LogSender so dev environments work without Twilio/SES.
|
||||
// Block D replaces this with a real SES-backed sender.
|
||||
type EmailSender interface {
|
||||
SendVerification(ctx context.Context, to, name, link string) error
|
||||
SendPasswordReset(ctx context.Context, to, name, link string) error
|
||||
}
|
||||
|
||||
type LogEmailSender struct {
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
func (l LogEmailSender) SendVerification(_ context.Context, to, name, link string) error {
|
||||
l.Logger.Info("auth email (stub): verification",
|
||||
"to", to, "name", name, "link", link,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l LogEmailSender) SendPasswordReset(_ context.Context, to, name, link string) error {
|
||||
l.Logger.Info("auth email (stub): password reset",
|
||||
"to", to, "name", name, "link", link,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidJWT = errors.New("invalid token")
|
||||
ErrExpiredJWT = errors.New("token expired")
|
||||
)
|
||||
|
||||
type AccessClaims struct {
|
||||
UserID uuid.UUID `json:"sub_uuid"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTSigner struct {
|
||||
secret []byte
|
||||
ttl time.Duration
|
||||
issuer string
|
||||
parser *jwt.Parser
|
||||
}
|
||||
|
||||
func NewJWTSigner(secret string, ttl time.Duration, issuer string) (*JWTSigner, error) {
|
||||
if len(secret) < 32 {
|
||||
return nil, fmt.Errorf("jwt secret must be at least 32 bytes")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return nil, fmt.Errorf("jwt ttl must be positive")
|
||||
}
|
||||
return &JWTSigner{
|
||||
secret: []byte(secret),
|
||||
ttl: ttl,
|
||||
issuer: issuer,
|
||||
parser: jwt.NewParser(
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithIssuer(issuer),
|
||||
jwt.WithExpirationRequired(),
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *JWTSigner) Issue(userID uuid.UUID, now time.Time) (string, time.Time, error) {
|
||||
exp := now.Add(s.ttl)
|
||||
claims := AccessClaims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: s.issuer,
|
||||
Subject: userID.String(),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Second)),
|
||||
ExpiresAt: jwt.NewNumericDate(exp),
|
||||
ID: uuid.NewString(),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := tok.SignedString(s.secret)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return signed, exp, nil
|
||||
}
|
||||
|
||||
func (s *JWTSigner) Parse(raw string) (*AccessClaims, error) {
|
||||
claims := &AccessClaims{}
|
||||
tok, err := s.parser.ParseWithClaims(raw, claims, func(t *jwt.Token) (any, error) {
|
||||
return s.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrExpiredJWT
|
||||
}
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
if !tok.Valid {
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
if claims.UserID == uuid.Nil {
|
||||
// Fallback for tokens that only carry Subject (defensive — we always
|
||||
// set UserID on issue).
|
||||
parsed, perr := uuid.Parse(claims.Subject)
|
||||
if perr != nil {
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
claims.UserID = parsed
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *JWTSigner) TTL() time.Duration { return s.ttl }
|
||||
@@ -0,0 +1,82 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const testSecret = "test-secret-must-be-at-least-32-bytes-long-xx"
|
||||
|
||||
func TestJWTRoundTrip(t *testing.T) {
|
||||
s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test")
|
||||
if err != nil {
|
||||
t.Fatalf("signer: %v", err)
|
||||
}
|
||||
uid := uuid.New()
|
||||
tok, exp, err := s.Issue(uid, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("issue: %v", err)
|
||||
}
|
||||
if tok == "" {
|
||||
t.Fatal("empty token")
|
||||
}
|
||||
if time.Until(exp) <= 0 {
|
||||
t.Fatalf("expiry in past: %v", exp)
|
||||
}
|
||||
claims, err := s.Parse(tok)
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if claims.UserID != uid {
|
||||
t.Fatalf("user mismatch: got %s want %s", claims.UserID, uid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTExpired(t *testing.T) {
|
||||
s, err := NewJWTSigner(testSecret, 1*time.Second, "guestguard-test")
|
||||
if err != nil {
|
||||
t.Fatalf("signer: %v", err)
|
||||
}
|
||||
tok, _, err := s.Issue(uuid.New(), time.Now().Add(-1*time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("issue: %v", err)
|
||||
}
|
||||
if _, err := s.Parse(tok); !errors.Is(err, ErrExpiredJWT) {
|
||||
t.Fatalf("expected ErrExpiredJWT, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTTamper(t *testing.T) {
|
||||
s, err := NewJWTSigner(testSecret, 5*time.Minute, "guestguard-test")
|
||||
if err != nil {
|
||||
t.Fatalf("signer: %v", err)
|
||||
}
|
||||
tok, _, _ := s.Issue(uuid.New(), time.Now())
|
||||
// Flip a character in the signature segment.
|
||||
tampered := tok[:len(tok)-1] + "a"
|
||||
if tampered == tok {
|
||||
tampered = tok[:len(tok)-1] + "b"
|
||||
}
|
||||
if _, err := s.Parse(tampered); !errors.Is(err, ErrInvalidJWT) {
|
||||
t.Fatalf("expected ErrInvalidJWT, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTSecretTooShort(t *testing.T) {
|
||||
if _, err := NewJWTSigner("short", time.Minute, "x"); err == nil {
|
||||
t.Fatal("expected error for short secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpaqueTokenHashStable(t *testing.T) {
|
||||
raw, hash, err := NewOpaqueToken()
|
||||
if err != nil {
|
||||
t.Fatalf("mint: %v", err)
|
||||
}
|
||||
if got := HashOpaque(raw); got != hash {
|
||||
t.Fatalf("hash mismatch: got %s want %s", got, hash)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// LockoutTracker counts consecutive failed logins per email and trips a
|
||||
// lockout flag after a threshold. The lock is keyed by user_id (once we
|
||||
// know it from the email) so that resetting the password — which we do via
|
||||
// /auth/reset-password — can clear it cleanly.
|
||||
//
|
||||
// Why two keys? The failure counter must work even when the email maps to
|
||||
// no user (otherwise an attacker probing addresses just gets unlimited
|
||||
// tries). The lock flag only exists once we've matched an actual account.
|
||||
type LockoutTracker struct {
|
||||
client *redis.Client
|
||||
threshold int
|
||||
window time.Duration // how long failures linger before counters reset
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewLockoutTracker(client *redis.Client, threshold int, window time.Duration) *LockoutTracker {
|
||||
return &LockoutTracker{
|
||||
client: client,
|
||||
threshold: threshold,
|
||||
window: window,
|
||||
prefix: "auth",
|
||||
}
|
||||
}
|
||||
|
||||
func (t *LockoutTracker) failKey(email string) string {
|
||||
h := sha256.Sum256([]byte(strings.ToLower(strings.TrimSpace(email))))
|
||||
return fmt.Sprintf("%s:login_fail:%s", t.prefix, hex.EncodeToString(h[:]))
|
||||
}
|
||||
|
||||
func (t *LockoutTracker) lockKey(uid uuid.UUID) string {
|
||||
return fmt.Sprintf("%s:locked:%s", t.prefix, uid.String())
|
||||
}
|
||||
|
||||
// IsLocked reports whether the given user's account is currently locked.
|
||||
func (t *LockoutTracker) IsLocked(ctx context.Context, uid uuid.UUID) (bool, error) {
|
||||
if t == nil || t.client == nil {
|
||||
return false, nil
|
||||
}
|
||||
v, err := t.client.Exists(ctx, t.lockKey(uid)).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return v > 0, nil
|
||||
}
|
||||
|
||||
// RecordFailure increments the failure counter for the email and, if it
|
||||
// crosses the threshold, sets the lock flag for the given user id.
|
||||
// Returns (locked, error). A nil userID is fine — the counter still ticks
|
||||
// up so probing nonexistent accounts is also rate-limited.
|
||||
func (t *LockoutTracker) RecordFailure(ctx context.Context, email string, userID *uuid.UUID) (bool, error) {
|
||||
if t == nil || t.client == nil {
|
||||
return false, nil
|
||||
}
|
||||
key := t.failKey(email)
|
||||
n, err := t.client.Incr(ctx, key).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if n == 1 {
|
||||
_ = t.client.Expire(ctx, key, t.window).Err()
|
||||
}
|
||||
if int(n) >= t.threshold && userID != nil {
|
||||
// Keep the lock until password reset clears it. 7-day fallback TTL
|
||||
// so a permanently abandoned account doesn't pile up forever.
|
||||
if err := t.client.Set(ctx, t.lockKey(*userID), "1", 7*24*time.Hour).Err(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ClearForUser drops both the lock flag and any in-flight failure counter
|
||||
// for the user's email. Called from /auth/reset-password.
|
||||
func (t *LockoutTracker) ClearForUser(ctx context.Context, uid uuid.UUID, email string) error {
|
||||
if t == nil || t.client == nil {
|
||||
return nil
|
||||
}
|
||||
pipe := t.client.Pipeline()
|
||||
pipe.Del(ctx, t.lockKey(uid))
|
||||
pipe.Del(ctx, t.failKey(email))
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearOnSuccess drops only the failure counter — used after a successful
|
||||
// login to forgive prior typos.
|
||||
func (t *LockoutTracker) ClearOnSuccess(ctx context.Context, email string) {
|
||||
if t == nil || t.client == nil {
|
||||
return
|
||||
}
|
||||
_ = t.client.Del(ctx, t.failKey(email)).Err()
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const bcryptCost = 12
|
||||
|
||||
var (
|
||||
ErrPasswordMismatch = errors.New("password mismatch")
|
||||
ErrPasswordTooShort = errors.New("password must be at least 8 characters")
|
||||
ErrPasswordTooLong = errors.New("password must be at most 72 characters")
|
||||
)
|
||||
|
||||
type PasswordHasher struct {
|
||||
cost int
|
||||
}
|
||||
|
||||
func NewPasswordHasher() *PasswordHasher {
|
||||
return &PasswordHasher{cost: bcryptCost}
|
||||
}
|
||||
|
||||
func (h *PasswordHasher) Hash(raw string) (string, error) {
|
||||
if err := ValidatePassword(raw); err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := bcrypt.GenerateFromPassword([]byte(raw), h.cost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (h *PasswordHasher) Verify(hash, raw string) error {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(raw)); err != nil {
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return ErrPasswordMismatch
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePassword enforces a minimum length and the bcrypt-imposed maximum.
|
||||
// bcrypt silently truncates inputs over 72 bytes, which would let a user set
|
||||
// a 100-character password and successfully log in with the first 72; reject
|
||||
// at the boundary instead.
|
||||
func ValidatePassword(raw string) error {
|
||||
if len(raw) < 8 {
|
||||
return ErrPasswordTooShort
|
||||
}
|
||||
if len(raw) > 72 {
|
||||
return ErrPasswordTooLong
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordHasherRoundTrip(t *testing.T) {
|
||||
h := NewPasswordHasher()
|
||||
hash, err := h.Hash("correct-horse-battery-staple")
|
||||
if err != nil {
|
||||
t.Fatalf("hash: %v", err)
|
||||
}
|
||||
if hash == "" {
|
||||
t.Fatal("empty hash")
|
||||
}
|
||||
if err := h.Verify(hash, "correct-horse-battery-staple"); err != nil {
|
||||
t.Fatalf("verify correct: %v", err)
|
||||
}
|
||||
if err := h.Verify(hash, "wrong"); !errors.Is(err, ErrPasswordMismatch) {
|
||||
t.Fatalf("verify wrong: got %v want ErrPasswordMismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pw string
|
||||
want error
|
||||
}{
|
||||
{"too short", "1234567", ErrPasswordTooShort},
|
||||
{"min length ok", "12345678", nil},
|
||||
{"too long", strings.Repeat("a", 73), ErrPasswordTooLong},
|
||||
{"max length ok", strings.Repeat("a", 72), nil},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := ValidatePassword(tc.pw); !errors.Is(got, tc.want) {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// NewOpaqueToken returns a 32-byte URL-safe random token plus its SHA-256 hex
|
||||
// digest. The raw value is shown once (in a link); only the digest is stored.
|
||||
func NewOpaqueToken() (raw, hash string, err error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
raw = base64.RawURLEncoding.EncodeToString(buf)
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
hash = hex.EncodeToString(sum[:])
|
||||
return raw, hash, nil
|
||||
}
|
||||
|
||||
func HashOpaque(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/stripe/stripe-go/v82"
|
||||
"github.com/stripe/stripe-go/v82/billingportal/session"
|
||||
csession "github.com/stripe/stripe-go/v82/checkout/session"
|
||||
"github.com/stripe/stripe-go/v82/customer"
|
||||
"github.com/stripe/stripe-go/v82/webhook"
|
||||
)
|
||||
|
||||
// Client wraps the Stripe SDK with the subset of calls the API needs.
|
||||
// Concentrates env-var reads + price-ID lookups in one place so handlers
|
||||
// stay focused on HTTP, not Stripe plumbing.
|
||||
type Client struct {
|
||||
secretKey string
|
||||
webhookSecret string
|
||||
prices map[Tier]string
|
||||
}
|
||||
|
||||
// Config is the env-derived configuration the Client needs.
|
||||
type Config struct {
|
||||
SecretKey string
|
||||
WebhookSecret string
|
||||
PriceProMonthly string
|
||||
PriceBusiness string
|
||||
}
|
||||
|
||||
// NewClient validates required fields and returns a configured client.
|
||||
// Returns (nil, nil) when SecretKey is empty — callers treat that as
|
||||
// "billing disabled" and degrade gracefully (free tier for everyone, no
|
||||
// /billing/* endpoints exposed).
|
||||
func NewClient(cfg Config) (*Client, error) {
|
||||
if cfg.SecretKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
stripe.Key = cfg.SecretKey
|
||||
|
||||
c := &Client{
|
||||
secretKey: cfg.SecretKey,
|
||||
webhookSecret: cfg.WebhookSecret,
|
||||
prices: map[Tier]string{
|
||||
TierPro: cfg.PriceProMonthly,
|
||||
TierBusiness: cfg.PriceBusiness,
|
||||
},
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Enabled reports whether the client was constructed with a Stripe key.
|
||||
func (c *Client) Enabled() bool { return c != nil && c.secretKey != "" }
|
||||
|
||||
// PriceFor returns the Stripe Price ID for a tier or an error if it
|
||||
// hasn't been configured — checkout will fail loudly rather than
|
||||
// silently send the customer to an empty checkout page.
|
||||
func (c *Client) PriceFor(tier Tier) (string, error) {
|
||||
id, ok := c.prices[tier]
|
||||
if !ok || id == "" {
|
||||
return "", fmt.Errorf("billing: no Stripe price configured for tier %q", tier)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// CreateOrGetCustomer returns the Stripe customer id for the given
|
||||
// (user_id, email). If `existingID` is non-empty we trust it; otherwise
|
||||
// we create a new Stripe customer and let the caller persist the id.
|
||||
func (c *Client) CreateOrGetCustomer(userID, email, name, existingID string) (string, error) {
|
||||
if !c.Enabled() {
|
||||
return "", errors.New("billing: disabled")
|
||||
}
|
||||
if existingID != "" {
|
||||
return existingID, nil
|
||||
}
|
||||
params := &stripe.CustomerParams{
|
||||
Email: stripe.String(email),
|
||||
Name: stripe.String(name),
|
||||
Metadata: map[string]string{"gg_user_id": userID},
|
||||
}
|
||||
cust, err := customer.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stripe customer create: %w", err)
|
||||
}
|
||||
return cust.ID, nil
|
||||
}
|
||||
|
||||
// CheckoutSessionParams collects the inputs CreateCheckoutSession needs.
|
||||
// Keeping them in a struct so future fields (coupon codes, trial periods,
|
||||
// referral metadata) drop in without breaking callers.
|
||||
type CheckoutSessionParams struct {
|
||||
CustomerID string
|
||||
PriceID string
|
||||
SuccessURL string
|
||||
CancelURL string
|
||||
}
|
||||
|
||||
// CreateCheckoutSession returns the URL the frontend redirects the user
|
||||
// to. Subscription mode — recurring billing for Pro/Business.
|
||||
func (c *Client) CreateCheckoutSession(p CheckoutSessionParams) (string, error) {
|
||||
if !c.Enabled() {
|
||||
return "", errors.New("billing: disabled")
|
||||
}
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
Customer: stripe.String(p.CustomerID),
|
||||
SuccessURL: stripe.String(p.SuccessURL),
|
||||
CancelURL: stripe.String(p.CancelURL),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{{
|
||||
Price: stripe.String(p.PriceID),
|
||||
Quantity: stripe.Int64(1),
|
||||
}},
|
||||
AllowPromotionCodes: stripe.Bool(true),
|
||||
}
|
||||
sess, err := csession.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stripe checkout session: %w", err)
|
||||
}
|
||||
return sess.URL, nil
|
||||
}
|
||||
|
||||
// CreatePortalSession returns a URL to the Stripe-hosted customer
|
||||
// portal so the user can manage payment methods, cancel, view invoices.
|
||||
func (c *Client) CreatePortalSession(customerID, returnURL string) (string, error) {
|
||||
if !c.Enabled() {
|
||||
return "", errors.New("billing: disabled")
|
||||
}
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(customerID),
|
||||
ReturnURL: stripe.String(returnURL),
|
||||
}
|
||||
sess, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stripe portal session: %w", err)
|
||||
}
|
||||
return sess.URL, nil
|
||||
}
|
||||
|
||||
// VerifyWebhook validates the Stripe signature header and returns the
|
||||
// parsed event. Refuses to verify when no webhook secret is configured —
|
||||
// no shared secret means anyone can POST forged events, so the route
|
||||
// should reject everything in that case (the caller does that check).
|
||||
//
|
||||
// We pass IgnoreAPIVersionMismatch because Stripe accounts can be on a
|
||||
// newer API version than the SDK we're built against. Event payloads
|
||||
// are designed to be forward-compatible — the SDK warns about the skew
|
||||
// but the deserialised event is still safe to use. Strict matching
|
||||
// would mean we'd have to upgrade the SDK in lockstep with whatever
|
||||
// Stripe rolls out, defeating the point of having an SDK.
|
||||
func (c *Client) VerifyWebhook(body []byte, sigHeader string) (stripe.Event, error) {
|
||||
if c.webhookSecret == "" {
|
||||
return stripe.Event{}, errors.New("billing: no webhook secret configured")
|
||||
}
|
||||
return webhook.ConstructEventWithOptions(body, sigHeader, c.webhookSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
// Package billing models GuestGuard's subscription tiers and Stripe
|
||||
// integration. Plan limits live here so handler + middleware layers don't
|
||||
// hard-code numbers — change a value, restart the API, and the cap moves.
|
||||
package billing
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Tier is the user's current subscription tier. Stored as text in the
|
||||
// subscriptions table.
|
||||
type Tier string
|
||||
|
||||
const (
|
||||
TierFree Tier = "free"
|
||||
TierPro Tier = "pro"
|
||||
TierBusiness Tier = "business"
|
||||
)
|
||||
|
||||
func (t Tier) Valid() bool {
|
||||
switch t {
|
||||
case TierFree, TierPro, TierBusiness:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Limits enforces what a tier may do. -1 means unlimited.
|
||||
type Limits struct {
|
||||
EventsPerMonth int
|
||||
GuestsPerEvent int
|
||||
}
|
||||
|
||||
// TierLimits is the canonical plan-limits table. Matches docs/TIER1_PLAN.md
|
||||
// Block F pricing (placeholder until market validation).
|
||||
var TierLimits = map[Tier]Limits{
|
||||
TierFree: {EventsPerMonth: 1, GuestsPerEvent: 50},
|
||||
TierPro: {EventsPerMonth: 10, GuestsPerEvent: 1000},
|
||||
TierBusiness: {EventsPerMonth: -1, GuestsPerEvent: 5000},
|
||||
}
|
||||
|
||||
// LimitsFor returns the limits for a tier, defaulting to Free for unknown
|
||||
// strings — defensive, so a typo or future-tier in the DB never grants
|
||||
// unlimited access.
|
||||
func LimitsFor(t Tier) Limits {
|
||||
if l, ok := TierLimits[t]; ok {
|
||||
return l
|
||||
}
|
||||
return TierLimits[TierFree]
|
||||
}
|
||||
|
||||
// StatusGrantsAccess returns true if a subscription with this status
|
||||
// should be treated as paid. Mirrors the unique-index predicate in
|
||||
// 0005_billing.up.sql.
|
||||
func StatusGrantsAccess(status string) bool {
|
||||
switch status {
|
||||
case "active", "past_due", "trialing":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LimitError describes a denied-by-policy outcome. The handler layer
|
||||
// turns this into a 402 with a JSON body the frontend uses to render
|
||||
// the upgrade modal.
|
||||
type LimitError struct {
|
||||
Reason string
|
||||
Tier Tier
|
||||
Limit int
|
||||
Used int
|
||||
}
|
||||
|
||||
func (e *LimitError) Error() string {
|
||||
return fmt.Sprintf("billing: %s (tier=%s used=%d limit=%d)", e.Reason, e.Tier, e.Used, e.Limit)
|
||||
}
|
||||
@@ -12,11 +12,56 @@ type Config struct {
|
||||
HTTPAddr string
|
||||
DatabaseURL string
|
||||
NATSURL string
|
||||
RedisAddr string
|
||||
FraudGRPCAddr string
|
||||
FraudGRPCTimeout time.Duration
|
||||
ShutdownTimeout time.Duration
|
||||
TokenSecret string
|
||||
TokenTTL time.Duration
|
||||
|
||||
// Auth
|
||||
JWTSecret string
|
||||
JWTIssuer string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
PublicBaseURL string
|
||||
RefreshCookieDomain string
|
||||
RefreshCookieSecure bool
|
||||
|
||||
// Notifications (Block D). Empty values leave the log-stub adapter in
|
||||
// place so the service still boots without AWS / Twilio creds.
|
||||
SESRegion string
|
||||
SESFromEmail string
|
||||
SESFromName string
|
||||
SESConfigurationSet string
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFromEmail string
|
||||
SMTPFromName string
|
||||
SMTPTLS string // "starttls" | "implicit" | "none" (default starttls)
|
||||
|
||||
ResendAPIKey string
|
||||
ResendFromEmail string
|
||||
ResendFromName string
|
||||
|
||||
TwilioAccountSID string
|
||||
TwilioAuthToken string
|
||||
TwilioFromNumber string
|
||||
|
||||
// Billing (Block F). Empty StripeSecretKey leaves billing disabled —
|
||||
// all users get free-tier limits, /billing/* returns 503. Lets the
|
||||
// service boot in dev without Stripe creds.
|
||||
StripeSecretKey string
|
||||
StripeWebhookSecret string
|
||||
StripePricePro string // Stripe Price ID for Pro monthly
|
||||
StripePriceBusiness string // Stripe Price ID for Business monthly
|
||||
|
||||
UnsubscribeSecret string // HMAC key for signing unsubscribe links
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -25,11 +70,50 @@ func Load() (*Config, error) {
|
||||
HTTPAddr: getenv("GG_HTTP_ADDR", ":8080"),
|
||||
DatabaseURL: getenv("GG_DATABASE_URL", "postgres://guestguard:guestguard@localhost:5432/guestguard?sslmode=disable"),
|
||||
NATSURL: getenv("GG_NATS_URL", "nats://localhost:4222"),
|
||||
RedisAddr: getenv("GG_REDIS_ADDR", "localhost:6379"),
|
||||
FraudGRPCAddr: getenv("GG_FRAUD_GRPC_ADDR", "fraud-engine:9091"),
|
||||
FraudGRPCTimeout: getenvDuration("GG_FRAUD_GRPC_TIMEOUT", 250*time.Millisecond),
|
||||
ShutdownTimeout: getenvDuration("GG_SHUTDOWN_TIMEOUT", 15*time.Second),
|
||||
TokenSecret: os.Getenv("GG_TOKEN_SECRET"),
|
||||
TokenTTL: getenvDuration("GG_TOKEN_TTL", 30*24*time.Hour),
|
||||
|
||||
JWTSecret: os.Getenv("GG_JWT_SECRET"),
|
||||
JWTIssuer: getenv("GG_JWT_ISSUER", "guestguard"),
|
||||
AccessTokenTTL: getenvDuration("GG_ACCESS_TOKEN_TTL", 15*time.Minute),
|
||||
RefreshTokenTTL: getenvDuration("GG_REFRESH_TOKEN_TTL", 30*24*time.Hour),
|
||||
EmailVerificationTTL: getenvDuration("GG_EMAIL_VERIFICATION_TTL", 24*time.Hour),
|
||||
PasswordResetTTL: getenvDuration("GG_PASSWORD_RESET_TTL", 1*time.Hour),
|
||||
PublicBaseURL: getenv("GG_PUBLIC_BASE_URL", "http://localhost:3000"),
|
||||
RefreshCookieDomain: os.Getenv("GG_REFRESH_COOKIE_DOMAIN"),
|
||||
RefreshCookieSecure: getenvBool("GG_REFRESH_COOKIE_SECURE", false),
|
||||
|
||||
SESRegion: getenv("GG_SES_REGION", "us-east-1"),
|
||||
SESFromEmail: os.Getenv("GG_SES_FROM_EMAIL"),
|
||||
SESFromName: getenv("GG_SES_FROM_NAME", "GuestGuard"),
|
||||
SESConfigurationSet: os.Getenv("GG_SES_CONFIGURATION_SET"),
|
||||
|
||||
SMTPHost: os.Getenv("GG_SMTP_HOST"),
|
||||
SMTPPort: getenvInt("GG_SMTP_PORT", 587),
|
||||
SMTPUsername: os.Getenv("GG_SMTP_USERNAME"),
|
||||
SMTPPassword: os.Getenv("GG_SMTP_PASSWORD"),
|
||||
SMTPFromEmail: os.Getenv("GG_SMTP_FROM_EMAIL"),
|
||||
SMTPFromName: getenv("GG_SMTP_FROM_NAME", "GuestGuard"),
|
||||
SMTPTLS: getenv("GG_SMTP_TLS", "starttls"),
|
||||
|
||||
ResendAPIKey: os.Getenv("GG_RESEND_API_KEY"),
|
||||
ResendFromEmail: os.Getenv("GG_RESEND_FROM_EMAIL"),
|
||||
ResendFromName: getenv("GG_RESEND_FROM_NAME", "GuestGuard"),
|
||||
|
||||
TwilioAccountSID: os.Getenv("GG_TWILIO_ACCOUNT_SID"),
|
||||
TwilioAuthToken: os.Getenv("GG_TWILIO_AUTH_TOKEN"),
|
||||
TwilioFromNumber: os.Getenv("GG_TWILIO_FROM_NUMBER"),
|
||||
|
||||
StripeSecretKey: os.Getenv("GG_STRIPE_SECRET_KEY"),
|
||||
StripeWebhookSecret: os.Getenv("GG_STRIPE_WEBHOOK_SECRET"),
|
||||
StripePricePro: os.Getenv("GG_STRIPE_PRICE_PRO"),
|
||||
StripePriceBusiness: os.Getenv("GG_STRIPE_PRICE_BUSINESS"),
|
||||
|
||||
UnsubscribeSecret: os.Getenv("GG_UNSUBSCRIBE_SECRET"),
|
||||
}
|
||||
|
||||
if cfg.Env == "production" && cfg.TokenSecret == "" {
|
||||
@@ -39,9 +123,52 @@ func Load() (*Config, error) {
|
||||
cfg.TokenSecret = "dev-only-insecure-secret-change-me"
|
||||
}
|
||||
|
||||
if cfg.Env == "production" && cfg.JWTSecret == "" {
|
||||
return nil, fmt.Errorf("GG_JWT_SECRET is required in production")
|
||||
}
|
||||
if cfg.JWTSecret == "" {
|
||||
cfg.JWTSecret = "dev-only-insecure-jwt-secret-change-me-32+bytes"
|
||||
}
|
||||
if len(cfg.JWTSecret) < 32 {
|
||||
return nil, fmt.Errorf("GG_JWT_SECRET must be at least 32 bytes")
|
||||
}
|
||||
|
||||
if cfg.UnsubscribeSecret == "" {
|
||||
// Same dev fallback shape as the other secrets — production refuses
|
||||
// to boot without it.
|
||||
if cfg.Env == "production" {
|
||||
return nil, fmt.Errorf("GG_UNSUBSCRIBE_SECRET is required in production")
|
||||
}
|
||||
cfg.UnsubscribeSecret = "dev-only-insecure-unsubscribe-secret-change-me"
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getenvInt(key string, fallback int) int {
|
||||
v, ok := os.LookupEnv(key)
|
||||
if !ok || v == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getenvBool(key string, fallback bool) bool {
|
||||
v, ok := os.LookupEnv(key)
|
||||
if !ok || v == "" {
|
||||
return fallback
|
||||
}
|
||||
b, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func getenv(key, fallback string) string {
|
||||
if v, ok := os.LookupEnv(key); ok && v != "" {
|
||||
return v
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
// Package csvimport parses a guest-list CSV into structured rows, with
|
||||
// tolerant header detection (Excel, Numbers, Google Sheets variants) and
|
||||
// per-row validation. Streaming-friendly so a 5,000-row import doesn't
|
||||
// load the entire file into a slice before we know if column 1 is junk.
|
||||
package csvimport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// Row is a single validated guest. Empty Email / Phone are allowed (a
|
||||
// phone-only or name-only guest is valid per the plan).
|
||||
type Row struct {
|
||||
Name string
|
||||
Email string
|
||||
Phone string
|
||||
PlusOnes int
|
||||
}
|
||||
|
||||
// RowError flags one row with the human-readable reason it can't be
|
||||
// imported. The line number is 1-based and matches the source CSV
|
||||
// (header counts as line 1, first data row is line 2) so the frontend
|
||||
// can highlight the offending row.
|
||||
type RowError struct {
|
||||
Row int `json:"row"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// Result is the outcome of one parse pass.
|
||||
type Result struct {
|
||||
Rows []Row `json:"rows,omitempty"`
|
||||
Errors []RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"` // total data rows seen (excluding header)
|
||||
}
|
||||
|
||||
// Options tune limits + behaviour.
|
||||
type Options struct {
|
||||
MaxRows int // hard cap; rows beyond MaxRows return an error instead of being silently dropped
|
||||
}
|
||||
|
||||
const DefaultMaxRows = 5000
|
||||
|
||||
// Strict E.164: optional leading +, then a non-zero leading digit (country
|
||||
// codes never start with 0), followed by 6–14 more digits — total 7–15
|
||||
// significant digits. Spaces / dashes / parens are tolerated by stripping
|
||||
// before validation, but local-format numbers like "0244…" or "07700…"
|
||||
// are rejected here so the host fixes them at upload time rather than at
|
||||
// WhatsApp-send time.
|
||||
var phoneRe = regexp.MustCompile(`^\+?[1-9][0-9]{6,14}$`)
|
||||
|
||||
// Parse reads a CSV from r and returns the parsed result. Encoding is
|
||||
// auto-detected: UTF-8 with or without BOM, plus UTF-16 LE/BE BOMs
|
||||
// (commonly produced by Mac Numbers exports).
|
||||
func Parse(r io.Reader, opt Options) (*Result, error) {
|
||||
max := opt.MaxRows
|
||||
if max <= 0 {
|
||||
max = DefaultMaxRows
|
||||
}
|
||||
|
||||
rd, err := decodingReader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
csvr := csv.NewReader(rd)
|
||||
csvr.FieldsPerRecord = -1 // tolerate ragged rows; we re-validate column count ourselves
|
||||
csvr.TrimLeadingSpace = true
|
||||
|
||||
header, err := csvr.Read()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil, errors.New("csv is empty")
|
||||
}
|
||||
return nil, fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
cols, err := detectColumns(header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &Result{Rows: make([]Row, 0, 64)}
|
||||
lineNo := 1 // header was line 1
|
||||
for {
|
||||
rec, err := csvr.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
lineNo++
|
||||
if err != nil {
|
||||
out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: fmt.Sprintf("malformed csv: %v", err)})
|
||||
continue
|
||||
}
|
||||
out.TotalCount++
|
||||
if out.TotalCount > max {
|
||||
return nil, fmt.Errorf("import exceeds maximum of %d rows", max)
|
||||
}
|
||||
|
||||
// Skip fully-empty rows silently — these appear at the end of
|
||||
// Excel exports a lot.
|
||||
if rowEmpty(rec) {
|
||||
out.TotalCount-- // don't count it
|
||||
continue
|
||||
}
|
||||
|
||||
row, rerr := buildRow(rec, cols)
|
||||
if rerr != "" {
|
||||
out.Errors = append(out.Errors, RowError{Row: lineNo, Reason: rerr})
|
||||
continue
|
||||
}
|
||||
out.Rows = append(out.Rows, row)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func rowEmpty(rec []string) bool {
|
||||
for _, v := range rec {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// decodingReader strips a UTF-8 BOM and decodes UTF-16 LE/BE when their
|
||||
// BOM is present, returning a UTF-8 reader. Other byte orders fall through
|
||||
// as raw UTF-8.
|
||||
func decodingReader(r io.Reader) (*bufio.Reader, error) {
|
||||
br := bufio.NewReader(r)
|
||||
bom, err := br.Peek(3)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case len(bom) >= 3 && bom[0] == 0xEF && bom[1] == 0xBB && bom[2] == 0xBF:
|
||||
_, _ = br.Discard(3)
|
||||
return br, nil
|
||||
case len(bom) >= 2 && bom[0] == 0xFF && bom[1] == 0xFE:
|
||||
_, _ = br.Discard(2)
|
||||
dec := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()
|
||||
return bufio.NewReader(transform.NewReader(br, dec)), nil
|
||||
case len(bom) >= 2 && bom[0] == 0xFE && bom[1] == 0xFF:
|
||||
_, _ = br.Discard(2)
|
||||
dec := unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder()
|
||||
return bufio.NewReader(transform.NewReader(br, dec)), nil
|
||||
}
|
||||
return br, nil
|
||||
}
|
||||
|
||||
// columnSet records which column index each known field lives in. -1 means
|
||||
// the column was not supplied; only Name is mandatory.
|
||||
type columnSet struct {
|
||||
name, email, phone, plusOnes int
|
||||
}
|
||||
|
||||
func detectColumns(header []string) (columnSet, error) {
|
||||
cs := columnSet{name: -1, email: -1, phone: -1, plusOnes: -1}
|
||||
for i, raw := range header {
|
||||
key := normaliseHeader(raw)
|
||||
switch key {
|
||||
case "name", "guestname", "fullname":
|
||||
cs.name = i
|
||||
case "email", "emailaddress", "e-mail":
|
||||
cs.email = i
|
||||
case "phone", "telephone", "mobile", "phonenumber":
|
||||
cs.phone = i
|
||||
case "plusones", "plus1", "plus-one", "plus-ones", "+1", "guests", "additionalguests":
|
||||
cs.plusOnes = i
|
||||
}
|
||||
}
|
||||
if cs.name < 0 {
|
||||
return cs, fmt.Errorf("required column 'name' not found in header: %v", header)
|
||||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func normaliseHeader(s string) string {
|
||||
s = strings.ToLower(strings.TrimSpace(s))
|
||||
// Drop spaces + underscores. Keep `+`, `-` so "+1" / "plus-one" still
|
||||
// match exactly.
|
||||
return strings.NewReplacer(" ", "", "_", "").Replace(s)
|
||||
}
|
||||
|
||||
func buildRow(rec []string, cs columnSet) (Row, string) {
|
||||
get := func(i int) string {
|
||||
if i < 0 || i >= len(rec) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(rec[i])
|
||||
}
|
||||
row := Row{
|
||||
Name: get(cs.name),
|
||||
Email: strings.ToLower(get(cs.email)),
|
||||
Phone: get(cs.phone),
|
||||
}
|
||||
if row.Name == "" {
|
||||
return row, "name is required"
|
||||
}
|
||||
|
||||
if row.Email != "" {
|
||||
if _, err := mail.ParseAddress(row.Email); err != nil {
|
||||
return row, "invalid email"
|
||||
}
|
||||
}
|
||||
if row.Phone != "" {
|
||||
stripped := stripPhone(row.Phone)
|
||||
if !phoneRe.MatchString(stripped) {
|
||||
return row, "phone must be in international format with country code (e.g. +447700900123) — local numbers starting with 0 won't work for SMS or WhatsApp"
|
||||
}
|
||||
// Normalise: ensure stored form always starts with "+".
|
||||
if !strings.HasPrefix(stripped, "+") {
|
||||
stripped = "+" + stripped
|
||||
}
|
||||
row.Phone = stripped
|
||||
}
|
||||
if raw := get(cs.plusOnes); raw != "" {
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 0 {
|
||||
return row, "plus_ones must be a non-negative integer"
|
||||
}
|
||||
row.PlusOnes = n
|
||||
}
|
||||
return row, ""
|
||||
}
|
||||
|
||||
var phoneStripper = strings.NewReplacer(" ", "", "-", "", "(", "", ")", "", " ", "")
|
||||
|
||||
func stripPhone(s string) string {
|
||||
return phoneStripper.Replace(s)
|
||||
}
|
||||
|
||||
// TemplateCSV is the sample file served at /events/{id}/guests/import/template.
|
||||
// Phone numbers MUST include the country code (e.g. +44 for UK, +233 for
|
||||
// Ghana). Local-format numbers like "0244..." or "07700..." will be
|
||||
// rejected at upload — the sample below shows the expected shape.
|
||||
const TemplateCSV = "name,email,phone,plus_ones\n" +
|
||||
"Alex Doe,alex@example.com,+447700900123,1\n" +
|
||||
"Sam Patel,sam@example.com,,0\n" +
|
||||
"Jordan Lee,,+15551234567,2\n" +
|
||||
"Mira Patel,mira@example.com,+233244123456,0\n"
|
||||
@@ -0,0 +1,157 @@
|
||||
package csvimport
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHappyPath(t *testing.T) {
|
||||
in := `name,email,phone,plus_ones
|
||||
Alex Doe,alex@example.com,+447700900123,1
|
||||
Sam Patel,SAM@example.com,,0
|
||||
Jordan Lee,,+1 (555) 123-4567,2
|
||||
`
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if got, want := len(r.Rows), 3; got != want {
|
||||
t.Fatalf("rows: got %d want %d (errors=%+v)", got, want, r.Errors)
|
||||
}
|
||||
if r.Rows[1].Email != "sam@example.com" {
|
||||
t.Errorf("email not lowercased: %q", r.Rows[1].Email)
|
||||
}
|
||||
if r.Rows[2].Phone != "+15551234567" {
|
||||
t.Errorf("phone not stripped: %q", r.Rows[2].Phone)
|
||||
}
|
||||
if r.Rows[0].Phone != "+447700900123" {
|
||||
t.Errorf("phone should keep leading +: %q", r.Rows[0].Phone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePhoneNormalisedToPlus(t *testing.T) {
|
||||
// E.164 without explicit "+" is accepted and normalised to include one.
|
||||
in := "name,phone\nAlex,447700900123\n"
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 || r.Rows[0].Phone != "+447700900123" {
|
||||
t.Fatalf("expected normalised phone, got: rows=%+v errors=%+v", r.Rows, r.Errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePhoneRejectsLocalFormat(t *testing.T) {
|
||||
// Local UK / GH style numbers (leading 0, no country code) must be
|
||||
// rejected — they break SMS routing and WhatsApp click-to-chat.
|
||||
in := `name,phone
|
||||
UK Local,07700900123
|
||||
GH Local,0244123456
|
||||
With Plus And Zero,+0244123456
|
||||
`
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Errors) != 3 {
|
||||
t.Fatalf("expected 3 errors for local-format phones, got %d: %+v", len(r.Errors), r.Errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHeaderVariants(t *testing.T) {
|
||||
cases := []string{
|
||||
"Name,Email,Phone,Plus Ones\nMira,m@x.com,,1\n",
|
||||
"Guest Name,E-Mail,Telephone,+1\nMira,m@x.com,,1\n",
|
||||
"full_name,email_address,mobile,plusones\nMira,m@x.com,,1\n",
|
||||
}
|
||||
for i, in := range cases {
|
||||
t.Run(string(rune('a'+i)), func(t *testing.T) {
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 {
|
||||
t.Fatalf("rows=%d errors=%+v", len(r.Rows), r.Errors)
|
||||
}
|
||||
if r.Rows[0].PlusOnes != 1 {
|
||||
t.Errorf("plusones not detected: %+v", r.Rows[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRowValidation(t *testing.T) {
|
||||
in := `name,email,phone,plus_ones
|
||||
,a@x.com,,1
|
||||
Valid Guest,not-an-email,,0
|
||||
Phone Person,,abc,0
|
||||
Negative,,,-1
|
||||
Email Only,e@x.com,,
|
||||
`
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Errors) != 4 {
|
||||
t.Fatalf("expected 4 errors, got %d: %+v", len(r.Errors), r.Errors)
|
||||
}
|
||||
// "Email Only" row is valid: name present, email parses, phone+plus blank.
|
||||
if len(r.Rows) != 1 || r.Rows[0].Name != "Email Only" {
|
||||
t.Fatalf("expected 1 valid row, got %+v", r.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUTF8BOM(t *testing.T) {
|
||||
in := "\xEF\xBB\xBFname,email\nMira,m@x.com\n"
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" {
|
||||
t.Fatalf("BOM not stripped: %+v", r.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUTF16LE(t *testing.T) {
|
||||
// "name,email\nMira,m@x.com\n" in UTF-16 LE with BOM.
|
||||
in := []byte{0xFF, 0xFE}
|
||||
for _, r := range "name,email\nMira,m@x.com\n" {
|
||||
in = append(in, byte(r), byte(r>>8))
|
||||
}
|
||||
r, err := Parse(strings.NewReader(string(in)), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(r.Rows) != 1 || r.Rows[0].Name != "Mira" {
|
||||
t.Fatalf("UTF-16 LE not decoded: %+v", r.Rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEmptyTrailingRows(t *testing.T) {
|
||||
in := "name,email\nMira,m@x.com\n,,\n\n,\n"
|
||||
r, err := Parse(strings.NewReader(in), Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if r.TotalCount != 1 {
|
||||
t.Fatalf("trailing blanks counted: TotalCount=%d", r.TotalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMissingNameHeader(t *testing.T) {
|
||||
in := "email,phone\na@x.com,\n"
|
||||
if _, err := Parse(strings.NewReader(in), Options{}); err == nil {
|
||||
t.Fatal("expected error for missing name column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMaxRows(t *testing.T) {
|
||||
var b strings.Builder
|
||||
b.WriteString("name\n")
|
||||
for i := 0; i < 11; i++ {
|
||||
b.WriteString("X\n")
|
||||
}
|
||||
if _, err := Parse(strings.NewReader(b.String()), Options{MaxRows: 10}); err == nil {
|
||||
t.Fatal("expected error when exceeding MaxRows")
|
||||
}
|
||||
}
|
||||
+26
-7
@@ -8,14 +8,33 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PasswordHash string `json:"-"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
|
||||
DeletedAt *time.Time `json:"-"`
|
||||
TermsAcceptedAt *time.Time `json:"terms_accepted_at,omitempty"`
|
||||
PrivacyPolicyAcceptedAt *time.Time `json:"privacy_policy_accepted_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TermsAccepted reports whether the user has accepted both the terms
|
||||
// of service and the privacy policy. Both must be present for the user
|
||||
// to use the dashboard once enforcement is enabled.
|
||||
func (u *User) TermsAccepted() bool {
|
||||
return u != nil && u.TermsAcceptedAt != nil && u.PrivacyPolicyAcceptedAt != nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrEmailTaken = errors.New("email already in use")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrEmailTaken = errors.New("email already in use")
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
ErrAuthTokenNotFound = errors.New("auth token not found")
|
||||
ErrAuthTokenConsumed = errors.New("auth token already used")
|
||||
ErrAuthTokenExpired = errors.New("auth token expired")
|
||||
ErrRefreshTokenRevoked = errors.New("refresh token revoked")
|
||||
ErrAccountLocked = errors.New("account locked due to too many failed login attempts")
|
||||
)
|
||||
|
||||
@@ -92,6 +92,15 @@ func (c *Client) PublishRSVPConfirmed(ctx context.Context, evt RSVPConfirmed) er
|
||||
return c.publishJSON(ctx, SubjectRSVPConfirmed, evt, evt.RSVPID)
|
||||
}
|
||||
|
||||
func (c *Client) PublishInvitationSend(ctx context.Context, evt InvitationSend) error {
|
||||
if evt.IssuedAt.IsZero() {
|
||||
evt.IssuedAt = time.Now().UTC()
|
||||
}
|
||||
// Dedup by token id — re-issuing a token (currently disallowed by the
|
||||
// unique constraint, but defensive) won't double-send the email.
|
||||
return c.publishJSON(ctx, SubjectInvitationSend, evt, evt.TokenID)
|
||||
}
|
||||
|
||||
func (c *Client) publishJSON(ctx context.Context, subject string, payload any, dedupeKey uuid.UUID) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -38,3 +38,20 @@ type RSVPConfirmed struct {
|
||||
RiskScore *int `json:"risk_score,omitempty"`
|
||||
SubmittedAt time.Time `json:"submitted_at"`
|
||||
}
|
||||
|
||||
// InvitationSend asks the notifier to dispatch a guest invitation email.
|
||||
// Carries everything the email template needs so the worker doesn't have
|
||||
// to re-fetch event/guest details from Postgres on every send.
|
||||
type InvitationSend struct {
|
||||
EventID uuid.UUID `json:"event_id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
GuestName string `json:"guest_name"`
|
||||
GuestEmail string `json:"guest_email"`
|
||||
HostName string `json:"host_name"`
|
||||
EventName string `json:"event_name"`
|
||||
Venue string `json:"venue,omitempty"`
|
||||
EventDate time.Time `json:"event_date"`
|
||||
Link string `json:"link"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package natspub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
)
|
||||
|
||||
type InvitationSendHandler func(ctx context.Context, evt InvitationSend) error
|
||||
|
||||
type InvitationSendSubscriber struct {
|
||||
logger *slog.Logger
|
||||
consumer jetstream.Consumer
|
||||
handler InvitationSendHandler
|
||||
}
|
||||
|
||||
func NewInvitationSendSubscriber(
|
||||
ctx context.Context,
|
||||
c *Client,
|
||||
durable string,
|
||||
handler InvitationSendHandler,
|
||||
logger *slog.Logger,
|
||||
) (*InvitationSendSubscriber, error) {
|
||||
cons, err := c.js.CreateOrUpdateConsumer(ctx, StreamName, jetstream.ConsumerConfig{
|
||||
Durable: durable,
|
||||
Name: durable,
|
||||
FilterSubject: SubjectInvitationSend,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
MaxDeliver: 5,
|
||||
AckWait: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create consumer %s: %w", durable, err)
|
||||
}
|
||||
return &InvitationSendSubscriber{logger: logger, consumer: cons, handler: handler}, nil
|
||||
}
|
||||
|
||||
func (s *InvitationSendSubscriber) Start(ctx context.Context) (jetstream.ConsumeContext, error) {
|
||||
cc, err := s.consumer.Consume(func(msg jetstream.Msg) {
|
||||
var evt InvitationSend
|
||||
if err := json.Unmarshal(msg.Data(), &evt); err != nil {
|
||||
s.logger.Error("decode invitation.send", "err", err)
|
||||
_ = msg.Term()
|
||||
return
|
||||
}
|
||||
hctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.handler(hctx, evt); err != nil {
|
||||
s.logger.Error("handle invitation.send", "err", err)
|
||||
_ = msg.NakWithDelay(5 * time.Second)
|
||||
return
|
||||
}
|
||||
_ = msg.Ack()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("consume: %w", err)
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sesv2"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sesv2/types"
|
||||
)
|
||||
|
||||
// SESConfig is the surface area for picking an SES sender. ConfigurationSet
|
||||
// is optional but recommended in production — it's where bounce + complaint
|
||||
// SNS topics get wired so the webhook handler has something to consume.
|
||||
type SESConfig struct {
|
||||
Region string
|
||||
FromEmail string
|
||||
FromName string
|
||||
ConfigurationSet string
|
||||
PublicBaseURL string // for unsubscribe links in templates
|
||||
}
|
||||
|
||||
// SESEmailSender sends transactional emails (verification + reset for the
|
||||
// auth flows, plus invitation/confirmation/reminder for guests) via Amazon
|
||||
// SESv2. The same client serves both audiences so callers don't end up
|
||||
// with two SES configurations to maintain.
|
||||
type SESEmailSender struct {
|
||||
client *sesv2.Client
|
||||
tpls *Templates
|
||||
from string
|
||||
configSet *string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewSESEmailSender returns a configured SES sender, or an error if the
|
||||
// AWS SDK can't bootstrap. The caller typically constructs this once at
|
||||
// startup and reuses it for the lifetime of the process.
|
||||
func NewSESEmailSender(ctx context.Context, cfg SESConfig, tpls *Templates) (*SESEmailSender, error) {
|
||||
if cfg.FromEmail == "" {
|
||||
return nil, fmt.Errorf("ses: FromEmail required")
|
||||
}
|
||||
if cfg.Region == "" {
|
||||
cfg.Region = "us-east-1"
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(cfg.Region))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ses: load aws config: %w", err)
|
||||
}
|
||||
client := sesv2.NewFromConfig(awsCfg)
|
||||
|
||||
from := cfg.FromEmail
|
||||
if cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
|
||||
}
|
||||
var cs *string
|
||||
if cfg.ConfigurationSet != "" {
|
||||
cs = aws.String(cfg.ConfigurationSet)
|
||||
}
|
||||
|
||||
return &SESEmailSender{
|
||||
client: client,
|
||||
tpls: tpls,
|
||||
from: from,
|
||||
configSet: cs,
|
||||
baseURL: strings.TrimRight(cfg.PublicBaseURL, "/"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- auth.EmailSender implementation ---
|
||||
|
||||
// SendVerification renders the verification template and posts it to SES.
|
||||
func (s *SESEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Verify your GuestGuard email",
|
||||
TmplVerification, map[string]any{
|
||||
"Name": name,
|
||||
"Link": link,
|
||||
})
|
||||
}
|
||||
|
||||
// SendPasswordReset renders the reset template and posts it to SES.
|
||||
func (s *SESEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Reset your GuestGuard password",
|
||||
TmplPasswordReset, map[string]any{
|
||||
"Name": name,
|
||||
"Link": link,
|
||||
"ExpiryHumane": "1 hour",
|
||||
})
|
||||
}
|
||||
|
||||
// SendGuest is used by the notifier worker for invitation / confirmation /
|
||||
// reminder emails — anything addressed at a guest.
|
||||
func (s *SESEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error) {
|
||||
return s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func (s *SESEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error {
|
||||
_, err := s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SESEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
html, text, err := s.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, err := s.client.SendEmail(ctx, &sesv2.SendEmailInput{
|
||||
FromEmailAddress: aws.String(s.from),
|
||||
Destination: &types.Destination{ToAddresses: []string{to}},
|
||||
ConfigurationSetName: s.configSet,
|
||||
Content: &types.EmailContent{
|
||||
Simple: &types.Message{
|
||||
Subject: &types.Content{Data: aws.String(subject), Charset: aws.String("UTF-8")},
|
||||
Body: &types.Body{
|
||||
Html: &types.Content{Data: aws.String(html), Charset: aws.String("UTF-8")},
|
||||
Text: &types.Content{Data: aws.String(text), Charset: aws.String("UTF-8")},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ses: send: %w", err)
|
||||
}
|
||||
if out.MessageId == nil {
|
||||
return "", nil
|
||||
}
|
||||
return *out.MessageId, nil
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// EmailBackend names the chosen email delivery channel for telemetry +
|
||||
// startup logging. Mostly a debugging aid — code paths don't branch on
|
||||
// this value.
|
||||
type EmailBackend string
|
||||
|
||||
const (
|
||||
BackendResend EmailBackend = "resend"
|
||||
BackendSMTP EmailBackend = "smtp"
|
||||
BackendSES EmailBackend = "ses"
|
||||
BackendLog EmailBackend = "log"
|
||||
)
|
||||
|
||||
// EmailSenderConfig collects every email-related env var so the picker
|
||||
// has a single, ordered place to decide which backend wins. Priority is
|
||||
// Resend > SMTP > SES > Log — the first one with non-empty creds is used.
|
||||
type EmailSenderConfig struct {
|
||||
Resend ResendConfig
|
||||
SMTP SMTPConfig
|
||||
SES SESConfig
|
||||
}
|
||||
|
||||
// CombinedEmailSender satisfies both the auth.EmailSender interface (for
|
||||
// verification + reset emails) and GuestEmailDispatcher (for invitation,
|
||||
// confirmation, reminder). One concrete value handles both audiences so
|
||||
// callers don't end up with two configurations.
|
||||
type CombinedEmailSender interface {
|
||||
SendVerification(ctx context.Context, to, name, link string) error
|
||||
SendPasswordReset(ctx context.Context, to, name, link string) error
|
||||
SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error)
|
||||
}
|
||||
|
||||
// PickEmailSender returns the configured email sender + which backend was
|
||||
// chosen. Falls back to a logger stub if nothing is configured, so the
|
||||
// service stays bootable in stripped-down dev environments.
|
||||
func PickEmailSender(ctx context.Context, cfg EmailSenderConfig, tpls *Templates, logger *slog.Logger) (CombinedEmailSender, EmailBackend, error) {
|
||||
switch {
|
||||
case cfg.Resend.APIKey != "":
|
||||
s, err := NewResendEmailSender(cfg.Resend, tpls)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return s, BackendResend, nil
|
||||
case cfg.SMTP.Host != "":
|
||||
s, err := NewSMTPEmailSender(cfg.SMTP, tpls)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return s, BackendSMTP, nil
|
||||
case cfg.SES.FromEmail != "":
|
||||
s, err := NewSESEmailSender(ctx, cfg.SES, tpls)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return s, BackendSES, nil
|
||||
}
|
||||
return &logCombinedSender{logger: logger, tpls: tpls}, BackendLog, nil
|
||||
}
|
||||
|
||||
// logCombinedSender is the dev fallback. Verification + reset emails come
|
||||
// through as structured log lines (preserving the Block A behaviour);
|
||||
// guest emails get rendered + dumped so engineers can eyeball the output.
|
||||
type logCombinedSender struct {
|
||||
logger *slog.Logger
|
||||
tpls *Templates
|
||||
}
|
||||
|
||||
func (l *logCombinedSender) SendVerification(_ context.Context, to, name, link string) error {
|
||||
l.logger.Info("auth email (stub): verification", "to", to, "name", name, "link", link)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logCombinedSender) SendPasswordReset(_ context.Context, to, name, link string) error {
|
||||
l.logger.Info("auth email (stub): password reset", "to", to, "name", name, "link", link)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logCombinedSender) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
_, text, err := l.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
l.logger.Info("guest email (stub)",
|
||||
"to", to, "subject", subject, "template", string(name), "text_body", text,
|
||||
)
|
||||
return "log:" + string(name), nil
|
||||
}
|
||||
@@ -64,12 +64,13 @@ func NewRepo(db *storage.DB) *Repo {
|
||||
}
|
||||
|
||||
type RecordParams struct {
|
||||
GuestID uuid.UUID
|
||||
Channel Channel
|
||||
Type Type
|
||||
Status Status
|
||||
ProviderID string
|
||||
Error string
|
||||
GuestID uuid.UUID
|
||||
Channel Channel
|
||||
Type Type
|
||||
Status Status
|
||||
ProviderID string // human-friendly id (e.g. "log:xyz")
|
||||
ProviderMessageID string // provider's message id (Twilio SID, SES MessageId)
|
||||
Error string
|
||||
}
|
||||
|
||||
func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
@@ -77,6 +78,10 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
if p.ProviderID != "" {
|
||||
providerID = &p.ProviderID
|
||||
}
|
||||
var providerMsgID *string
|
||||
if p.ProviderMessageID != "" {
|
||||
providerMsgID = &p.ProviderMessageID
|
||||
}
|
||||
var errStr *string
|
||||
if p.Error != "" {
|
||||
errStr = &p.Error
|
||||
@@ -90,14 +95,15 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
|
||||
const q = `
|
||||
INSERT INTO notifications (guest_id, channel, type, status, provider_id,
|
||||
attempts, last_attempt, delivered_at, error)
|
||||
VALUES ($1, $2, $3, $4, $5, 1, now(), $6, $7)
|
||||
provider_message_id, attempts, last_attempt,
|
||||
delivered_at, error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 1, now(), $7, $8)
|
||||
RETURNING id
|
||||
`
|
||||
var id uuid.UUID
|
||||
err := r.pool.QueryRow(ctx, q,
|
||||
p.GuestID, string(p.Channel), string(p.Type), string(p.Status),
|
||||
providerID, deliveredAt, errStr,
|
||||
providerID, providerMsgID, deliveredAt, errStr,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("record notification: %w", err)
|
||||
@@ -105,6 +111,35 @@ func (r *Repo) Record(ctx context.Context, p RecordParams) (uuid.UUID, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// MarkBounce records a bounce on the notification row identified by the
|
||||
// provider's message id. Called from webhook handlers.
|
||||
func (r *Repo) MarkBounce(ctx context.Context, providerMessageID, bounceType string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE notifications
|
||||
SET status = 'bounced', bounce_type = $2, error = COALESCE(error, '')
|
||||
WHERE provider_message_id = $1
|
||||
`, providerMessageID, bounceType)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkComplaint records a complaint (spam report) for the same row.
|
||||
func (r *Repo) MarkComplaint(ctx context.Context, providerMessageID string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE notifications SET complained = TRUE WHERE provider_message_id = $1
|
||||
`, providerMessageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkDelivered moves a row from 'sent' to 'delivered' when the provider's
|
||||
// delivery status webhook fires.
|
||||
func (r *Repo) MarkDelivered(ctx context.Context, providerMessageID string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE notifications SET status = 'delivered', delivered_at = now()
|
||||
WHERE provider_message_id = $1 AND status NOT IN ('bounced','failed')
|
||||
`, providerMessageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// LogSender pretends to send and just logs. Useful for Phase 3 demos and
|
||||
// tests; concrete providers (Twilio/SES) plug in later.
|
||||
type LogSender struct{}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ResendConfig configures the Resend HTTP sender. APIKey is required;
|
||||
// FromEmail is the sending address on a domain you've verified in the
|
||||
// Resend dashboard.
|
||||
type ResendConfig struct {
|
||||
APIKey string
|
||||
FromEmail string
|
||||
FromName string
|
||||
|
||||
// HTTPClient overrides the default client (mostly so tests can point
|
||||
// at httptest.Server). Leave nil for production.
|
||||
HTTPClient *http.Client
|
||||
BaseURL string // overrideable for tests; defaults to https://api.resend.com
|
||||
}
|
||||
|
||||
// ResendEmailSender posts emails through https://api.resend.com/emails.
|
||||
// Implements auth.EmailSender + GuestEmailDispatcher.
|
||||
type ResendEmailSender struct {
|
||||
cfg ResendConfig
|
||||
tpls *Templates
|
||||
from string
|
||||
client *http.Client
|
||||
url string
|
||||
}
|
||||
|
||||
func NewResendEmailSender(cfg ResendConfig, tpls *Templates) (*ResendEmailSender, error) {
|
||||
if cfg.APIKey == "" {
|
||||
return nil, errors.New("resend: APIKey required")
|
||||
}
|
||||
if cfg.FromEmail == "" {
|
||||
return nil, errors.New("resend: FromEmail required")
|
||||
}
|
||||
from := cfg.FromEmail
|
||||
if cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
|
||||
}
|
||||
cli := cfg.HTTPClient
|
||||
if cli == nil {
|
||||
cli = &http.Client{Timeout: 15 * time.Second}
|
||||
}
|
||||
base := cfg.BaseURL
|
||||
if base == "" {
|
||||
base = "https://api.resend.com"
|
||||
}
|
||||
return &ResendEmailSender{cfg: cfg, tpls: tpls, from: from, client: cli, url: base + "/emails"}, nil
|
||||
}
|
||||
|
||||
// --- auth.EmailSender ---
|
||||
|
||||
func (s *ResendEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
|
||||
_, err := s.sendTemplated(ctx, to, "Verify your GuestGuard email",
|
||||
TmplVerification, map[string]any{"Name": name, "Link": link})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ResendEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
|
||||
_, err := s.sendTemplated(ctx, to, "Reset your GuestGuard password",
|
||||
TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"})
|
||||
return err
|
||||
}
|
||||
|
||||
// --- GuestEmailDispatcher ---
|
||||
|
||||
func (s *ResendEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
return s.sendTemplated(ctx, to, subject, name, data)
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
type resendRequest struct {
|
||||
From string `json:"from"`
|
||||
To []string `json:"to"`
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type resendResponse struct {
|
||||
ID string `json:"id"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ResendEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
html, text, err := s.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(resendRequest{
|
||||
From: s.from,
|
||||
To: []string{to},
|
||||
Subject: subject,
|
||||
HTML: html,
|
||||
Text: text,
|
||||
})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.cfg.APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resend: do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
if resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("resend: status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
var parsed resendResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return "", fmt.Errorf("resend: parse: %w", err)
|
||||
}
|
||||
return parsed.ID, nil
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResendSenderHappyPath(t *testing.T) {
|
||||
tpls, err := NewTemplates()
|
||||
if err != nil {
|
||||
t.Fatalf("NewTemplates: %v", err)
|
||||
}
|
||||
|
||||
var gotPath, gotAuth, gotContentType string
|
||||
var gotBody map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotContentType = r.Header.Get("Content-Type")
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(b, &gotBody)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"id": "resend-abc-123"})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
s, err := NewResendEmailSender(ResendConfig{
|
||||
APIKey: "secret-key",
|
||||
FromEmail: "no-reply@example.test",
|
||||
FromName: "GuestGuard",
|
||||
BaseURL: srv.URL,
|
||||
}, tpls)
|
||||
if err != nil {
|
||||
t.Fatalf("NewResendEmailSender: %v", err)
|
||||
}
|
||||
|
||||
id, err := s.SendGuest(context.Background(), "to@example.test", "You're invited",
|
||||
TmplInvitation, map[string]any{
|
||||
"GuestName": "Mira", "HostName": "Kay", "EventName": "Beach Day",
|
||||
"Link": "https://example.test/rsvp/x",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendGuest: %v", err)
|
||||
}
|
||||
if id != "resend-abc-123" {
|
||||
t.Fatalf("provider id: got %q want %q", id, "resend-abc-123")
|
||||
}
|
||||
if gotPath != "/emails" {
|
||||
t.Errorf("path: got %q want /emails", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer secret-key" {
|
||||
t.Errorf("auth: got %q", gotAuth)
|
||||
}
|
||||
if gotContentType != "application/json" {
|
||||
t.Errorf("content-type: got %q", gotContentType)
|
||||
}
|
||||
if gotBody["from"] != "GuestGuard <no-reply@example.test>" {
|
||||
t.Errorf("from: got %v", gotBody["from"])
|
||||
}
|
||||
if !strings.Contains(gotBody["html"].(string), "Beach Day") {
|
||||
t.Errorf("html missing event name: %v", gotBody["html"])
|
||||
}
|
||||
if !strings.Contains(gotBody["text"].(string), "Mira") {
|
||||
t.Errorf("text missing guest name: %v", gotBody["text"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResendSenderErrorPropagates(t *testing.T) {
|
||||
tpls, _ := NewTemplates()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"message":"invalid api key"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
s, err := NewResendEmailSender(ResendConfig{
|
||||
APIKey: "bad", FromEmail: "x@y.test", BaseURL: srv.URL,
|
||||
}, tpls)
|
||||
if err != nil {
|
||||
t.Fatalf("NewResendEmailSender: %v", err)
|
||||
}
|
||||
if err := s.SendVerification(context.Background(), "z@y.test", "Z", "https://link"); err == nil {
|
||||
t.Fatal("expected error on 401")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Router dispatches an OutboundMessage to the channel-appropriate sender.
|
||||
// The notifier worker holds one Router and stays oblivious to whether the
|
||||
// concrete backend is SES, Twilio, or a logger stub.
|
||||
type Router struct {
|
||||
email Sender
|
||||
sms Sender
|
||||
}
|
||||
|
||||
func NewRouter(email, sms Sender) *Router {
|
||||
return &Router{email: email, sms: sms}
|
||||
}
|
||||
|
||||
func (r *Router) Send(ctx context.Context, msg OutboundMessage) (string, error) {
|
||||
switch msg.Channel {
|
||||
case ChannelEmail:
|
||||
if r.email == nil {
|
||||
return "", fmt.Errorf("no email sender configured")
|
||||
}
|
||||
return r.email.Send(ctx, msg)
|
||||
case ChannelSMS:
|
||||
if r.sms == nil {
|
||||
return "", fmt.Errorf("no sms sender configured")
|
||||
}
|
||||
return r.sms.Send(ctx, msg)
|
||||
}
|
||||
return "", fmt.Errorf("router: unknown channel %q", msg.Channel)
|
||||
}
|
||||
|
||||
// LogGuestEmailDispatcher is the dev-mode dispatcher that renders the
|
||||
// templated email and logs both bodies. Useful for local docker-compose
|
||||
// before SES is configured — gives engineers the rendered output without
|
||||
// needing a real inbox.
|
||||
type LogGuestEmailDispatcher struct {
|
||||
logger *slog.Logger
|
||||
tpls *Templates
|
||||
}
|
||||
|
||||
func NewLogGuestEmailDispatcher(logger *slog.Logger, tpls *Templates) *LogGuestEmailDispatcher {
|
||||
return &LogGuestEmailDispatcher{logger: logger, tpls: tpls}
|
||||
}
|
||||
|
||||
func (d *LogGuestEmailDispatcher) SendGuest(_ context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
_, text, err := d.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
id := "log:" + uuid.NewString()
|
||||
d.logger.Info("guest email (stub)",
|
||||
"to", to,
|
||||
"subject", subject,
|
||||
"template", string(name),
|
||||
"provider_message_id", id,
|
||||
"text_body", text,
|
||||
)
|
||||
return id, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GuestEmailDispatcher is the abstraction the notifier uses to send a
|
||||
// templated email to a single guest. Both SES (production) and Log (dev)
|
||||
// satisfy it.
|
||||
type GuestEmailDispatcher interface {
|
||||
SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (providerMessageID string, err error)
|
||||
}
|
||||
|
||||
// EmailSender is the notification.Sender for ChannelEmail. It maps the
|
||||
// generic OutboundMessage shape used by the notifier worker onto our
|
||||
// templated SES / Log path, and honours the suppression list (a guest who
|
||||
// unsubscribed must never receive another email).
|
||||
type EmailSender struct {
|
||||
dispatcher GuestEmailDispatcher
|
||||
suppression *SuppressionRepo
|
||||
}
|
||||
|
||||
func NewEmailSender(d GuestEmailDispatcher, s *SuppressionRepo) *EmailSender {
|
||||
return &EmailSender{dispatcher: d, suppression: s}
|
||||
}
|
||||
|
||||
func (e *EmailSender) Send(ctx context.Context, msg OutboundMessage) (string, error) {
|
||||
if msg.GuestID == uuid.Nil {
|
||||
return "", errors.New("missing guest id")
|
||||
}
|
||||
if msg.Channel != ChannelEmail {
|
||||
return "", fmt.Errorf("EmailSender does not handle channel %q", msg.Channel)
|
||||
}
|
||||
to, _ := msg.Metadata["to"].(string)
|
||||
if to == "" {
|
||||
return "", errors.New("email recipient missing from metadata.to")
|
||||
}
|
||||
|
||||
if e.suppression != nil {
|
||||
yep, err := e.suppression.IsSuppressed(ctx, to)
|
||||
if err != nil {
|
||||
// On lookup error, fail safe by NOT sending — better than
|
||||
// emailing an unsubscribed address.
|
||||
return "", fmt.Errorf("suppression lookup: %w", err)
|
||||
}
|
||||
if yep {
|
||||
return "suppressed:" + to, nil
|
||||
}
|
||||
}
|
||||
|
||||
tmpl := templateForType(msg.Type)
|
||||
if tmpl == "" {
|
||||
return "", fmt.Errorf("no template for type %q", msg.Type)
|
||||
}
|
||||
subject := msg.Subject
|
||||
if subject == "" {
|
||||
subject = defaultSubject(msg.Type, msg.Metadata)
|
||||
}
|
||||
data := map[string]any{}
|
||||
for k, v := range msg.Metadata {
|
||||
data[k] = v
|
||||
}
|
||||
return e.dispatcher.SendGuest(ctx, to, subject, tmpl, data)
|
||||
}
|
||||
|
||||
func templateForType(t Type) TemplateName {
|
||||
switch t {
|
||||
case TypeInvitation:
|
||||
return TmplInvitation
|
||||
case TypeConfirmation:
|
||||
return TmplConfirmation
|
||||
case TypeReminder:
|
||||
return TmplReminder
|
||||
case TypeVerification:
|
||||
return TmplVerification
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func defaultSubject(t Type, meta map[string]any) string {
|
||||
event, _ := meta["EventName"].(string)
|
||||
switch t {
|
||||
case TypeInvitation:
|
||||
if event != "" {
|
||||
return "You're invited — " + event
|
||||
}
|
||||
return "You're invited"
|
||||
case TypeConfirmation:
|
||||
if event != "" {
|
||||
return "RSVP confirmed — " + event
|
||||
}
|
||||
return "RSVP confirmed"
|
||||
case TypeReminder:
|
||||
if event != "" {
|
||||
return "Reminder: " + event
|
||||
}
|
||||
return "Reminder"
|
||||
}
|
||||
return "GuestGuard"
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/twilio/twilio-go"
|
||||
openapi "github.com/twilio/twilio-go/rest/api/v2010"
|
||||
)
|
||||
|
||||
// TwilioConfig configures the SMS sender. AccountSID + AuthToken pair
|
||||
// authenticate the REST client; FromNumber is the verified Twilio number.
|
||||
type TwilioConfig struct {
|
||||
AccountSID string
|
||||
AuthToken string
|
||||
FromNumber string
|
||||
MaxAttempts int
|
||||
}
|
||||
|
||||
// TwilioSender implements notification.Sender for ChannelSMS. Retries on
|
||||
// transient errors with exponential backoff (1s, 5s, 30s, 5m, 30m). Twilio
|
||||
// surfaces a numeric ErrorCode for permanent failures (e.g. 21610
|
||||
// unsubscribed, 21408 disabled region) — those return immediately.
|
||||
type TwilioSender struct {
|
||||
client *twilio.RestClient
|
||||
from string
|
||||
maxAttempt int
|
||||
}
|
||||
|
||||
func NewTwilioSender(cfg TwilioConfig) (*TwilioSender, error) {
|
||||
if cfg.AccountSID == "" || cfg.AuthToken == "" || cfg.FromNumber == "" {
|
||||
return nil, errors.New("twilio: AccountSID / AuthToken / FromNumber are required")
|
||||
}
|
||||
cli := twilio.NewRestClientWithParams(twilio.ClientParams{
|
||||
Username: cfg.AccountSID,
|
||||
Password: cfg.AuthToken,
|
||||
})
|
||||
max := cfg.MaxAttempts
|
||||
if max <= 0 {
|
||||
max = 5
|
||||
}
|
||||
return &TwilioSender{client: cli, from: cfg.FromNumber, maxAttempt: max}, nil
|
||||
}
|
||||
|
||||
func (t *TwilioSender) Send(ctx context.Context, msg OutboundMessage) (string, error) {
|
||||
if msg.GuestID == uuid.Nil {
|
||||
return "", errors.New("missing guest id")
|
||||
}
|
||||
if msg.Channel != ChannelSMS {
|
||||
return "", fmt.Errorf("TwilioSender does not handle channel %q", msg.Channel)
|
||||
}
|
||||
to, _ := msg.Metadata["phone"].(string)
|
||||
if to == "" {
|
||||
return "", errors.New("sms recipient missing from metadata.phone")
|
||||
}
|
||||
body := msg.Body
|
||||
if body == "" {
|
||||
body = msg.Subject
|
||||
}
|
||||
if body == "" {
|
||||
return "", errors.New("sms body is empty")
|
||||
}
|
||||
|
||||
params := &openapi.CreateMessageParams{}
|
||||
params.SetTo(to)
|
||||
params.SetFrom(t.from)
|
||||
params.SetBody(body)
|
||||
|
||||
backoff := []time.Duration{0, time.Second, 5 * time.Second, 30 * time.Second, 5 * time.Minute, 30 * time.Minute}
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < t.maxAttempt; attempt++ {
|
||||
if attempt < len(backoff) && backoff[attempt] > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(backoff[attempt]):
|
||||
}
|
||||
}
|
||||
resp, err := t.client.Api.CreateMessage(params)
|
||||
if err == nil && resp != nil && resp.Sid != nil {
|
||||
return *resp.Sid, nil
|
||||
}
|
||||
lastErr = err
|
||||
if !isTwilioRetryable(err) {
|
||||
return "", fmt.Errorf("twilio: send: %w", err)
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("twilio: send (after %d attempts): %w", t.maxAttempt, lastErr)
|
||||
}
|
||||
|
||||
// isTwilioRetryable returns true for transient failures (network, 5xx).
|
||||
// Twilio's permanent error codes (21xxx range) are not retried.
|
||||
func isTwilioRetryable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
// Cheap heuristic: permanent codes are in the 21xxx range; everything
|
||||
// else (timeouts, 503s, DNS hiccups) is fair game.
|
||||
if strings.Contains(msg, "21610") || strings.Contains(msg, "21408") || strings.Contains(msg, "21211") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,238 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SMTPConfig describes one SMTP relay. Username/Password are optional —
|
||||
// local relays like Mailpit accept anonymous SMTP. TLS modes:
|
||||
//
|
||||
// - "starttls" upgrade after EHLO (most relays, default)
|
||||
// - "implicit" TLS handshake before SMTP (port 465)
|
||||
// - "none" plain socket — only acceptable on a trusted LAN
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromEmail string
|
||||
FromName string
|
||||
TLS string
|
||||
}
|
||||
|
||||
// SMTPEmailSender implements auth.EmailSender (verification/reset) AND
|
||||
// GuestEmailDispatcher (invitation/confirmation/reminder) on top of any
|
||||
// SMTP relay. Used for Mailpit in dev; works against Gmail, Fastmail, etc.
|
||||
// in production if the user prefers plain SMTP over an HTTP API.
|
||||
type SMTPEmailSender struct {
|
||||
cfg SMTPConfig
|
||||
tpls *Templates
|
||||
from string
|
||||
}
|
||||
|
||||
func NewSMTPEmailSender(cfg SMTPConfig, tpls *Templates) (*SMTPEmailSender, error) {
|
||||
if cfg.Host == "" {
|
||||
return nil, errors.New("smtp: Host required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
cfg.Port = 587
|
||||
}
|
||||
if cfg.FromEmail == "" {
|
||||
return nil, errors.New("smtp: FromEmail required")
|
||||
}
|
||||
if cfg.TLS == "" {
|
||||
cfg.TLS = "starttls"
|
||||
}
|
||||
from := cfg.FromEmail
|
||||
if cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", cfg.FromName, cfg.FromEmail)
|
||||
}
|
||||
return &SMTPEmailSender{cfg: cfg, tpls: tpls, from: from}, nil
|
||||
}
|
||||
|
||||
// --- auth.EmailSender ---
|
||||
|
||||
func (s *SMTPEmailSender) SendVerification(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Verify your GuestGuard email",
|
||||
TmplVerification, map[string]any{"Name": name, "Link": link})
|
||||
}
|
||||
|
||||
func (s *SMTPEmailSender) SendPasswordReset(ctx context.Context, to, name, link string) error {
|
||||
return s.sendTemplated(ctx, to, "Reset your GuestGuard password",
|
||||
TmplPasswordReset, map[string]any{"Name": name, "Link": link, "ExpiryHumane": "1 hour"})
|
||||
}
|
||||
|
||||
// --- GuestEmailDispatcher ---
|
||||
|
||||
func (s *SMTPEmailSender) SendGuest(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
return s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func (s *SMTPEmailSender) sendTemplated(ctx context.Context, to, subject string, name TemplateName, data map[string]any) error {
|
||||
_, err := s.sendTemplatedReturnID(ctx, to, subject, name, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SMTPEmailSender) sendTemplatedReturnID(ctx context.Context, to, subject string, name TemplateName, data map[string]any) (string, error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
data["Subject"] = subject
|
||||
html, text, err := s.tpls.Render(name, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
msgID := generateMessageID(s.cfg.FromEmail)
|
||||
body := buildMIMEMessage(mimeMessage{
|
||||
MessageID: msgID,
|
||||
From: s.from,
|
||||
To: to,
|
||||
Subject: subject,
|
||||
Text: text,
|
||||
HTML: html,
|
||||
})
|
||||
if err := s.dial(ctx, []string{to}, body); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return msgID, nil
|
||||
}
|
||||
|
||||
func (s *SMTPEmailSender) dial(ctx context.Context, to []string, body []byte) error {
|
||||
addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(s.cfg.Port))
|
||||
d := &net.Dialer{Timeout: 10 * time.Second}
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
d.Deadline = deadline
|
||||
}
|
||||
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
)
|
||||
switch strings.ToLower(s.cfg.TLS) {
|
||||
case "implicit":
|
||||
conn, err = tls.DialWithDialer(d, "tcp", addr, &tls.Config{ServerName: s.cfg.Host})
|
||||
default:
|
||||
conn, err = d.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp: dial: %w", err)
|
||||
}
|
||||
|
||||
c, err := smtp.NewClient(conn, s.cfg.Host)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("smtp: new client: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if strings.ToLower(s.cfg.TLS) == "starttls" {
|
||||
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||
if err := c.StartTLS(&tls.Config{ServerName: s.cfg.Host}); err != nil {
|
||||
return fmt.Errorf("smtp: starttls: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.cfg.Username != "" {
|
||||
auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host)
|
||||
if err := c.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp: auth: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Mail(s.cfg.FromEmail); err != nil {
|
||||
return fmt.Errorf("smtp: MAIL FROM: %w", err)
|
||||
}
|
||||
for _, rcpt := range to {
|
||||
if err := c.Rcpt(rcpt); err != nil {
|
||||
return fmt.Errorf("smtp: RCPT TO %s: %w", rcpt, err)
|
||||
}
|
||||
}
|
||||
w, err := c.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp: DATA: %w", err)
|
||||
}
|
||||
if _, err := w.Write(body); err != nil {
|
||||
_ = w.Close()
|
||||
return fmt.Errorf("smtp: write body: %w", err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return fmt.Errorf("smtp: close body: %w", err)
|
||||
}
|
||||
return c.Quit()
|
||||
}
|
||||
|
||||
type mimeMessage struct {
|
||||
MessageID string
|
||||
From string
|
||||
To string
|
||||
Subject string
|
||||
Text string
|
||||
HTML string
|
||||
}
|
||||
|
||||
// buildMIMEMessage assembles an RFC 5322 message with a multipart/alternative
|
||||
// body so receiving clients pick HTML when they can and fall back to text
|
||||
// otherwise.
|
||||
func buildMIMEMessage(m mimeMessage) []byte {
|
||||
boundary := randomBoundary()
|
||||
var b strings.Builder
|
||||
b.WriteString("Message-ID: <" + m.MessageID + ">\r\n")
|
||||
b.WriteString("Date: " + time.Now().UTC().Format(time.RFC1123Z) + "\r\n")
|
||||
b.WriteString("From: " + m.From + "\r\n")
|
||||
b.WriteString("To: " + m.To + "\r\n")
|
||||
b.WriteString("Subject: " + m.Subject + "\r\n")
|
||||
b.WriteString("MIME-Version: 1.0\r\n")
|
||||
b.WriteString("Content-Type: multipart/alternative; boundary=" + boundary + "\r\n")
|
||||
b.WriteString("\r\n")
|
||||
|
||||
b.WriteString("--" + boundary + "\r\n")
|
||||
b.WriteString("Content-Type: text/plain; charset=UTF-8\r\n")
|
||||
b.WriteString("Content-Transfer-Encoding: 8bit\r\n")
|
||||
b.WriteString("\r\n")
|
||||
b.WriteString(m.Text)
|
||||
if !strings.HasSuffix(m.Text, "\n") {
|
||||
b.WriteString("\r\n")
|
||||
}
|
||||
|
||||
b.WriteString("--" + boundary + "\r\n")
|
||||
b.WriteString("Content-Type: text/html; charset=UTF-8\r\n")
|
||||
b.WriteString("Content-Transfer-Encoding: 8bit\r\n")
|
||||
b.WriteString("\r\n")
|
||||
b.WriteString(m.HTML)
|
||||
if !strings.HasSuffix(m.HTML, "\n") {
|
||||
b.WriteString("\r\n")
|
||||
}
|
||||
|
||||
b.WriteString("--" + boundary + "--\r\n")
|
||||
return []byte(b.String())
|
||||
}
|
||||
|
||||
func randomBoundary() string {
|
||||
var buf [16]byte
|
||||
_, _ = rand.Read(buf[:])
|
||||
return "gg=" + hex.EncodeToString(buf[:])
|
||||
}
|
||||
|
||||
func generateMessageID(from string) string {
|
||||
var buf [12]byte
|
||||
_, _ = rand.Read(buf[:])
|
||||
domain := "guestguard.local"
|
||||
if at := strings.LastIndexByte(from, '@'); at >= 0 && at+1 < len(from) {
|
||||
domain = from[at+1:]
|
||||
}
|
||||
return hex.EncodeToString(buf[:]) + "@" + domain
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildMIMEMessageStructure(t *testing.T) {
|
||||
body := buildMIMEMessage(mimeMessage{
|
||||
MessageID: "abc@example.test",
|
||||
From: "GuestGuard <no-reply@example.test>",
|
||||
To: "to@example.test",
|
||||
Subject: "Verify your GuestGuard email",
|
||||
Text: "Hi Mira, please verify.",
|
||||
HTML: "<p>Hi Mira, please verify.</p>",
|
||||
})
|
||||
s := string(body)
|
||||
checks := []string{
|
||||
"Message-ID: <abc@example.test>",
|
||||
"From: GuestGuard <no-reply@example.test>",
|
||||
"To: to@example.test",
|
||||
"Subject: Verify your GuestGuard email",
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: multipart/alternative; boundary=",
|
||||
"Content-Type: text/plain; charset=UTF-8",
|
||||
"Content-Type: text/html; charset=UTF-8",
|
||||
"Hi Mira, please verify.",
|
||||
"<p>Hi Mira, please verify.</p>",
|
||||
}
|
||||
for _, want := range checks {
|
||||
if !strings.Contains(s, want) {
|
||||
t.Errorf("MIME body missing %q\n---\n%s", want, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageIDIncludesDomain(t *testing.T) {
|
||||
id := generateMessageID("no-reply@example.test")
|
||||
if !strings.HasSuffix(id, "@example.test") {
|
||||
t.Fatalf("message id has wrong domain: %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMessageIDFallback(t *testing.T) {
|
||||
id := generateMessageID("not-an-email")
|
||||
if !strings.HasSuffix(id, "@guestguard.local") {
|
||||
t.Fatalf("expected fallback domain: %s", id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// SuppressionSource categorises why an address landed on the suppression
|
||||
// list. Bounces + complaints come from provider webhooks; "user" is set
|
||||
// when a guest clicks an unsubscribe link.
|
||||
type SuppressionSource string
|
||||
|
||||
const (
|
||||
SuppressionBounce SuppressionSource = "bounce"
|
||||
SuppressionComplaint SuppressionSource = "complaint"
|
||||
SuppressionManual SuppressionSource = "manual"
|
||||
SuppressionUser SuppressionSource = "user"
|
||||
)
|
||||
|
||||
// SuppressionRepo manages the unsubscribes table — a flat suppression list
|
||||
// of email addresses that must never receive another email regardless of
|
||||
// notification type.
|
||||
type SuppressionRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewSuppressionRepo(db *storage.DB) *SuppressionRepo {
|
||||
return &SuppressionRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
// IsSuppressed returns true if `email` is on the suppression list.
|
||||
// Empty / unparseable addresses are treated as not-suppressed; the caller
|
||||
// is expected to validate before sending.
|
||||
func (r *SuppressionRepo) IsSuppressed(ctx context.Context, email string) (bool, error) {
|
||||
email = normaliseEmail(email)
|
||||
if email == "" {
|
||||
return false, nil
|
||||
}
|
||||
var exists bool
|
||||
err := r.pool.QueryRow(ctx,
|
||||
`SELECT EXISTS (SELECT 1 FROM unsubscribes WHERE email = $1)`,
|
||||
email,
|
||||
).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// Add records `email` on the suppression list. Idempotent — repeated calls
|
||||
// keep the earliest entry's timestamp.
|
||||
func (r *SuppressionRepo) Add(ctx context.Context, email, reason string, src SuppressionSource) error {
|
||||
email = normaliseEmail(email)
|
||||
if email == "" {
|
||||
return errors.New("empty email")
|
||||
}
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO unsubscribes (email, reason, source)
|
||||
VALUES ($1, NULLIF($2, ''), $3)
|
||||
ON CONFLICT (email) DO NOTHING
|
||||
`, email, reason, string(src))
|
||||
return err
|
||||
}
|
||||
|
||||
// Get returns the suppression record for `email`, or pgx.ErrNoRows if not
|
||||
// found. Mostly used by tests / admin tooling.
|
||||
func (r *SuppressionRepo) Get(ctx context.Context, email string) (string, SuppressionSource, error) {
|
||||
var reason *string
|
||||
var source string
|
||||
err := r.pool.QueryRow(ctx,
|
||||
`SELECT reason, source FROM unsubscribes WHERE email = $1`,
|
||||
normaliseEmail(email),
|
||||
).Scan(&reason, &source)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", err
|
||||
}
|
||||
return "", "", err
|
||||
}
|
||||
r2 := ""
|
||||
if reason != nil {
|
||||
r2 = *reason
|
||||
}
|
||||
return r2, SuppressionSource(source), nil
|
||||
}
|
||||
|
||||
func normaliseEmail(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
htmltemplate "html/template"
|
||||
"io/fs"
|
||||
"path"
|
||||
"strings"
|
||||
texttemplate "text/template"
|
||||
)
|
||||
|
||||
//go:embed templates
|
||||
var templatesFS embed.FS
|
||||
|
||||
// TemplateName names one of the email templates. There must be a
|
||||
// matching `<name>.html` and `<name>.txt` under templates/.
|
||||
type TemplateName string
|
||||
|
||||
const (
|
||||
TmplVerification TemplateName = "verification"
|
||||
TmplPasswordReset TemplateName = "reset"
|
||||
TmplInvitation TemplateName = "invitation"
|
||||
TmplConfirmation TemplateName = "confirmation"
|
||||
TmplReminder TemplateName = "reminder"
|
||||
)
|
||||
|
||||
// Templates renders branded transactional emails for both HTML and
|
||||
// plaintext bodies. Loaded once at construction; thread-safe afterwards.
|
||||
//
|
||||
// Each page-level HTML template gets its own *html/template.Template
|
||||
// holding a copy of the `base.html` partial. This avoids html/template's
|
||||
// per-template contextual-escape pass interfering between pages that all
|
||||
// define a sibling named "body".
|
||||
type Templates struct {
|
||||
html map[string]*htmltemplate.Template // page-name (no ext) -> root
|
||||
text *texttemplate.Template
|
||||
}
|
||||
|
||||
func NewTemplates() (*Templates, error) {
|
||||
root, err := fs.Sub(templatesFS, "templates")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("templates fs: %w", err)
|
||||
}
|
||||
|
||||
baseHTML, err := fs.ReadFile(root, "base.html")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read base.html: %w", err)
|
||||
}
|
||||
|
||||
out := &Templates{
|
||||
html: make(map[string]*htmltemplate.Template),
|
||||
text: texttemplate.New("__root__"),
|
||||
}
|
||||
|
||||
walk := func(p string, d fs.DirEntry, _ error) error {
|
||||
if d == nil || d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
base := path.Base(p)
|
||||
if base == "base.html" {
|
||||
return nil // partial — folded into each page template below
|
||||
}
|
||||
b, err := fs.ReadFile(root, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case strings.HasSuffix(p, ".html"):
|
||||
name := strings.TrimSuffix(base, ".html")
|
||||
t := htmltemplate.New(name)
|
||||
if _, err := t.Parse(string(baseHTML)); err != nil {
|
||||
return fmt.Errorf("parse _base for %s: %w", p, err)
|
||||
}
|
||||
if _, err := t.Parse(string(b)); err != nil {
|
||||
return fmt.Errorf("parse %s: %w", p, err)
|
||||
}
|
||||
out.html[name] = t
|
||||
case strings.HasSuffix(p, ".txt"):
|
||||
if _, err := out.text.New(base).Parse(string(b)); err != nil {
|
||||
return fmt.Errorf("parse %s: %w", p, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := fs.WalkDir(root, ".", walk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Render returns (htmlBody, textBody) for the named template using data.
|
||||
func (t *Templates) Render(name TemplateName, data map[string]any) (htmlBody, textBody string, err error) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
if _, ok := data["Subject"]; !ok {
|
||||
data["Subject"] = "GuestGuard"
|
||||
}
|
||||
|
||||
htmlTpl, ok := t.html[string(name)]
|
||||
if !ok {
|
||||
return "", "", fmt.Errorf("unknown html template %q", name)
|
||||
}
|
||||
|
||||
var hBuf, tBuf bytes.Buffer
|
||||
// Each page-root template's entry point is "base" (defined by base.html).
|
||||
if err := htmlTpl.ExecuteTemplate(&hBuf, "base", data); err != nil {
|
||||
return "", "", fmt.Errorf("render html %s: %w", name, err)
|
||||
}
|
||||
if err := t.text.ExecuteTemplate(&tBuf, string(name)+".txt", data); err != nil {
|
||||
return "", "", fmt.Errorf("render text %s: %w", name, err)
|
||||
}
|
||||
return hBuf.String(), tBuf.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
{{define "base"}}<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>{{.Subject}}</title>
|
||||
</head>
|
||||
<body style="margin:0;padding:0;background:#f7f7f8;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,Helvetica,Arial,sans-serif;color:#0f172a;">
|
||||
<table role="presentation" width="100%" cellpadding="0" cellspacing="0" style="background:#f7f7f8;padding:32px 16px;">
|
||||
<tr><td align="center">
|
||||
<table role="presentation" width="560" cellpadding="0" cellspacing="0" style="background:#ffffff;border-radius:12px;overflow:hidden;box-shadow:0 1px 2px rgba(15,23,42,0.05);">
|
||||
<tr>
|
||||
<td style="background:#0a0a0a;padding:20px 28px;">
|
||||
<span style="display:inline-block;width:10px;height:10px;background:#22c55e;border-radius:50%;vertical-align:middle;"></span>
|
||||
<span style="color:#fafafa;font-weight:600;font-size:16px;margin-left:8px;vertical-align:middle;">GuestGuard</span>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:32px 28px 24px;font-size:15px;line-height:1.55;color:#0f172a;">
|
||||
{{block "body" .}}{{end}}
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:0 28px 28px;">
|
||||
<hr style="border:none;border-top:1px solid #e5e7eb;margin:0 0 16px;">
|
||||
<p style="font-size:12px;color:#64748b;margin:0;">
|
||||
You're receiving this because of activity on your GuestGuard account.
|
||||
{{if .UnsubscribeLink}}<br>If you'd rather not get emails like this, <a href="{{.UnsubscribeLink}}" style="color:#16a34a;">unsubscribe here</a>.{{end}}
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td></tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>{{end}}
|
||||
@@ -0,0 +1,11 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">RSVP received</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.GuestName}},</p>
|
||||
<p style="margin:0 0 16px;">Thanks for letting {{.HostName}} know — your RSVP for <strong>{{.EventName}}</strong> is confirmed as <strong>{{.Response}}</strong>{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.</p>
|
||||
{{if or .Venue .EventDate}}
|
||||
<p style="margin:0 0 16px;color:#64748b;font-size:14px;">
|
||||
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
|
||||
</p>
|
||||
{{end}}
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">You'll get a reminder closer to the date. If your plans change, use the same invitation link to update your reply.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,10 @@
|
||||
Hi {{.GuestName}},
|
||||
|
||||
Your RSVP for "{{.EventName}}" is confirmed as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.
|
||||
|
||||
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
|
||||
|
||||
You'll get a reminder closer to the date. If your plans change, use the
|
||||
same invitation link to update your reply.
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,15 @@
|
||||
{{define "body"}}
|
||||
<p style="font-size:13px;letter-spacing:0.2em;text-transform:uppercase;color:#16a34a;margin:0 0 16px;">✦ You're invited</p>
|
||||
<h1 style="font-size:24px;margin:0 0 4px;color:#0a0a0a;">{{.EventName}}</h1>
|
||||
{{if or .Venue .EventDate}}
|
||||
<p style="margin:0 0 24px;color:#64748b;font-size:14px;">
|
||||
{{if .Venue}}{{.Venue}}{{end}}{{if and .Venue .EventDate}} · {{end}}{{if .EventDate}}{{.EventDate}}{{end}}
|
||||
</p>
|
||||
{{end}}
|
||||
<p style="margin:0 0 20px;">Hi {{.GuestName}}, {{.HostName}} would love to know if you can make it. Use the personal link below to RSVP.</p>
|
||||
<p style="margin:0 0 24px;text-align:center;">
|
||||
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">RSVP now</a>
|
||||
</p>
|
||||
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
|
||||
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,11 @@
|
||||
You're invited — {{.EventName}}
|
||||
|
||||
Hi {{.GuestName}},
|
||||
|
||||
{{.HostName}} would love to know if you can make it{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.
|
||||
|
||||
RSVP here:
|
||||
|
||||
{{.Link}}
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,7 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">Reminder: {{.EventName}}</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.GuestName}},</p>
|
||||
<p style="margin:0 0 16px;">Just a quick reminder that <strong>{{.EventName}}</strong> is coming up{{if .EventDate}} on {{.EventDate}}{{end}}{{if .Venue}}, at {{.Venue}}{{end}}.</p>
|
||||
{{if .Response}}<p style="margin:0 0 16px;">You're down as <strong>{{.Response}}</strong>{{if gt .PlusOnes 0}} with {{.PlusOnes}} plus-one{{if ne .PlusOnes 1}}s{{end}}{{end}}.</p>{{end}}
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">Need to change your plans? Use your invitation link.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,7 @@
|
||||
Hi {{.GuestName}},
|
||||
|
||||
Reminder: {{.EventName}}{{if .EventDate}} — {{.EventDate}}{{end}}{{if .Venue}} at {{.Venue}}{{end}}.
|
||||
|
||||
{{if .Response}}You're down as {{.Response}}{{if gt .PlusOnes 0}} (+{{.PlusOnes}}){{end}}.{{end}}
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,11 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">Reset your password</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.Name}},</p>
|
||||
<p style="margin:0 0 20px;">We received a request to reset the password on your GuestGuard account. Tap the button to choose a new one — the link is valid for {{.ExpiryHumane}}.</p>
|
||||
<p style="margin:0 0 24px;text-align:center;">
|
||||
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">Choose a new password</a>
|
||||
</p>
|
||||
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
|
||||
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">If you didn't ask to reset your password, you can ignore this email — your current password is unchanged.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,11 @@
|
||||
Hi {{.Name}},
|
||||
|
||||
We received a request to reset your GuestGuard password. The link is valid
|
||||
for {{.ExpiryHumane}}:
|
||||
|
||||
{{.Link}}
|
||||
|
||||
If you didn't ask to reset your password, you can ignore this email — your
|
||||
current password is unchanged.
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,11 @@
|
||||
{{define "body"}}
|
||||
<h1 style="font-size:20px;margin:0 0 12px;">Verify your email</h1>
|
||||
<p style="margin:0 0 16px;">Hi {{.Name}}, welcome to GuestGuard.</p>
|
||||
<p style="margin:0 0 20px;">To finish setting up your account, please confirm this is your email address.</p>
|
||||
<p style="margin:0 0 24px;text-align:center;">
|
||||
<a href="{{.Link}}" style="background:#22c55e;color:#0a0a0a;padding:12px 22px;border-radius:8px;font-weight:600;text-decoration:none;display:inline-block;">Verify email</a>
|
||||
</p>
|
||||
<p style="margin:0 0 8px;color:#64748b;font-size:13px;">Or paste this URL into your browser:</p>
|
||||
<p style="margin:0;word-break:break-all;font-size:12px;color:#0f172a;">{{.Link}}</p>
|
||||
<p style="margin:24px 0 0;font-size:13px;color:#64748b;">If you didn't sign up for GuestGuard, you can ignore this email.</p>
|
||||
{{end}}
|
||||
@@ -0,0 +1,9 @@
|
||||
Hi {{.Name}}, welcome to GuestGuard.
|
||||
|
||||
Please verify your email by visiting:
|
||||
|
||||
{{.Link}}
|
||||
|
||||
If you didn't sign up for GuestGuard, you can ignore this email.
|
||||
|
||||
— GuestGuard
|
||||
@@ -0,0 +1,101 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRenderAllTemplates(t *testing.T) {
|
||||
tpls, err := NewTemplates()
|
||||
if err != nil {
|
||||
t.Fatalf("NewTemplates: %v", err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name TemplateName
|
||||
data map[string]any
|
||||
wantHTML []string // substrings expected in HTML body
|
||||
wantText []string // substrings expected in text body
|
||||
}{
|
||||
{
|
||||
name: TmplVerification,
|
||||
data: map[string]any{
|
||||
"Name": "Kay",
|
||||
"Link": "https://example.test/verify-email?token=x",
|
||||
"Subject": "Verify your GuestGuard email",
|
||||
"UnsubscribeLink": "https://example.test/unsubscribe/abc",
|
||||
},
|
||||
wantHTML: []string{"Verify your email", "Kay", "Verify email", "https://example.test/verify-email?token=x", "unsubscribe here"},
|
||||
wantText: []string{"Hi Kay", "https://example.test/verify-email?token=x"},
|
||||
},
|
||||
{
|
||||
name: TmplPasswordReset,
|
||||
data: map[string]any{
|
||||
"Name": "Kay",
|
||||
"Link": "https://example.test/reset-password/abc",
|
||||
"ExpiryHumane": "1 hour",
|
||||
},
|
||||
wantHTML: []string{"Reset your password", "1 hour", "https://example.test/reset-password/abc"},
|
||||
wantText: []string{"reset your GuestGuard password", "1 hour"},
|
||||
},
|
||||
{
|
||||
name: TmplInvitation,
|
||||
data: map[string]any{
|
||||
"GuestName": "Mira",
|
||||
"HostName": "Kay",
|
||||
"EventName": "Beach Day",
|
||||
"Venue": "Ocean Park",
|
||||
"EventDate": "Sat 14 Jun, 4pm",
|
||||
"Link": "https://example.test/rsvp/tok_x",
|
||||
},
|
||||
wantHTML: []string{"You're invited", "Beach Day", "Mira", "Ocean Park", "RSVP now", "https://example.test/rsvp/tok_x"},
|
||||
wantText: []string{"Beach Day", "Mira", "RSVP here", "https://example.test/rsvp/tok_x"},
|
||||
},
|
||||
{
|
||||
name: TmplConfirmation,
|
||||
data: map[string]any{
|
||||
"GuestName": "Mira",
|
||||
"HostName": "Kay",
|
||||
"EventName": "Beach Day",
|
||||
"Venue": "Ocean Park",
|
||||
"EventDate": "Sat 14 Jun, 4pm",
|
||||
"Response": "attending",
|
||||
"PlusOnes": 2,
|
||||
},
|
||||
wantHTML: []string{"RSVP received", "Beach Day", "attending", "2 plus-ones"},
|
||||
wantText: []string{"Beach Day", "attending", "+2"},
|
||||
},
|
||||
{
|
||||
name: TmplReminder,
|
||||
data: map[string]any{
|
||||
"GuestName": "Mira",
|
||||
"EventName": "Beach Day",
|
||||
"EventDate": "Sat 14 Jun, 4pm",
|
||||
"Venue": "Ocean Park",
|
||||
"Response": "attending",
|
||||
"PlusOnes": 1,
|
||||
},
|
||||
wantHTML: []string{"Reminder", "Beach Day", "Ocean Park", "1 plus-one"},
|
||||
wantText: []string{"Reminder", "Beach Day", "(+1)"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(string(tc.name), func(t *testing.T) {
|
||||
html, text, err := tpls.Render(tc.name, tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("render: %v", err)
|
||||
}
|
||||
for _, s := range tc.wantHTML {
|
||||
if !strings.Contains(html, s) {
|
||||
t.Errorf("html missing %q\n---\n%s", s, html)
|
||||
}
|
||||
}
|
||||
for _, s := range tc.wantText {
|
||||
if !strings.Contains(text, s) {
|
||||
t.Errorf("text missing %q\n---\n%s", s, text)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UnsubscribeSigner mints + verifies tamper-proof unsubscribe tokens. Each
|
||||
// token encodes the email address and an HMAC-SHA256 of the address under
|
||||
// a server-side secret. The token has no TTL — unsubscribe links should
|
||||
// keep working forever.
|
||||
//
|
||||
// Token shape: base64url(email) + "." + base64url(hmac)
|
||||
type UnsubscribeSigner struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func NewUnsubscribeSigner(secret string) *UnsubscribeSigner {
|
||||
return &UnsubscribeSigner{secret: []byte(secret)}
|
||||
}
|
||||
|
||||
// Sign returns a URL-safe token for the email address. Empty input → empty
|
||||
// token (caller should validate input first).
|
||||
func (s *UnsubscribeSigner) Sign(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
email = normaliseEmail(email)
|
||||
mac := hmac.New(sha256.New, s.secret)
|
||||
mac.Write([]byte(email))
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(email)) +
|
||||
"." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// Verify decodes the token and confirms the HMAC matches, returning the
|
||||
// owning email address.
|
||||
func (s *UnsubscribeSigner) Verify(token string) (string, error) {
|
||||
dot := strings.IndexByte(token, '.')
|
||||
if dot < 0 {
|
||||
return "", errors.New("malformed unsubscribe token")
|
||||
}
|
||||
emailB, err := base64.RawURLEncoding.DecodeString(token[:dot])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sigB, err := base64.RawURLEncoding.DecodeString(token[dot+1:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
mac := hmac.New(sha256.New, s.secret)
|
||||
mac.Write(emailB)
|
||||
want := mac.Sum(nil)
|
||||
if !hmac.Equal(sigB, want) {
|
||||
return "", errors.New("unsubscribe signature mismatch")
|
||||
}
|
||||
return string(emailB), nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KeyFunc derives the rate-limit key from a request — e.g. IP, IP+email,
|
||||
// authenticated user id, path token, etc. An empty return value bypasses
|
||||
// the limiter (handy when a key isn't available yet, like an email that
|
||||
// hasn't been parsed out of the body).
|
||||
type KeyFunc func(r *http.Request) string
|
||||
|
||||
// Rule names a single sliding-window budget.
|
||||
type Rule struct {
|
||||
Name string
|
||||
Limit int
|
||||
Window time.Duration
|
||||
}
|
||||
|
||||
// Middleware returns an http middleware that applies `rule` to incoming
|
||||
// requests using `keyFn`. On limit, it writes 429 with a Retry-After header
|
||||
// and a JSON body. If the limiter itself errors, requests are allowed
|
||||
// (fail-open) — degraded protection is better than total outage.
|
||||
func (l *Limiter) Middleware(rule Rule, keyFn KeyFunc, logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := keyFn(r)
|
||||
if key == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
res, err := l.Allow(r.Context(), rule.Name, key, rule.Limit, rule.Window)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("ratelimit error (failing open)",
|
||||
"rule", rule.Name, "err", err)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if !res.Allowed {
|
||||
retrySecs := int(res.RetryAfter.Round(time.Second).Seconds())
|
||||
if retrySecs < 1 {
|
||||
retrySecs = 1
|
||||
}
|
||||
w.Header().Set("Retry-After", strconv.Itoa(retrySecs))
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": "rate limit exceeded",
|
||||
"retry_after": retrySecs,
|
||||
})
|
||||
if logger != nil {
|
||||
logger.Info("ratelimit blocked",
|
||||
"rule", rule.Name,
|
||||
"key", key,
|
||||
"retry_after", retrySecs,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
// Package ratelimit implements a sliding-window rate limiter backed by
|
||||
// Redis sorted sets. Each call is atomic via a Lua script: it sweeps
|
||||
// entries older than the window, returns the current count, and (when
|
||||
// under the limit) records the new hit. Block C's "INCR + EXPIRE or a Lua
|
||||
// script for atomicity" requirement.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Result is the outcome of one Allow check.
|
||||
type Result struct {
|
||||
Allowed bool
|
||||
Count int // current count within the window (post-increment when allowed)
|
||||
Limit int
|
||||
RetryAfter time.Duration // populated when Allowed=false
|
||||
}
|
||||
|
||||
// Limiter checks rate-limit budgets against Redis.
|
||||
type Limiter struct {
|
||||
client *redis.Client
|
||||
script *redis.Script
|
||||
prefix string
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// New builds a Limiter against the given Redis client. The prefix namespaces
|
||||
// all keys (defaults to "rl").
|
||||
func New(client *redis.Client, prefix string) *Limiter {
|
||||
if prefix == "" {
|
||||
prefix = "rl"
|
||||
}
|
||||
return &Limiter{
|
||||
client: client,
|
||||
script: redis.NewScript(slidingWindowScript),
|
||||
prefix: prefix,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow consumes one unit of budget under (name, key) against `limit` events
|
||||
// per `window`. Returns Allowed=true and the new count, or Allowed=false
|
||||
// with RetryAfter set to roughly the duration until the oldest hit ages out.
|
||||
func (l *Limiter) Allow(ctx context.Context, name, key string, limit int, window time.Duration) (Result, error) {
|
||||
if limit <= 0 {
|
||||
return Result{}, errors.New("ratelimit: limit must be positive")
|
||||
}
|
||||
if window <= 0 {
|
||||
return Result{}, errors.New("ratelimit: window must be positive")
|
||||
}
|
||||
member, err := randomMember()
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
now := l.now().UnixMilli()
|
||||
windowMS := window.Milliseconds()
|
||||
redisKey := fmt.Sprintf("%s:%s:%s", l.prefix, name, key)
|
||||
|
||||
out, err := l.script.Run(ctx, l.client,
|
||||
[]string{redisKey},
|
||||
now, windowMS, limit, member,
|
||||
).Int64Slice()
|
||||
if err != nil {
|
||||
return Result{}, fmt.Errorf("ratelimit: redis: %w", err)
|
||||
}
|
||||
if len(out) != 3 {
|
||||
return Result{}, fmt.Errorf("ratelimit: bad lua reply: %v", out)
|
||||
}
|
||||
|
||||
r := Result{
|
||||
Count: int(out[1]),
|
||||
Limit: limit,
|
||||
}
|
||||
if out[0] == 0 {
|
||||
r.Allowed = true
|
||||
} else {
|
||||
r.RetryAfter = time.Duration(out[2]) * time.Millisecond
|
||||
if r.RetryAfter <= 0 {
|
||||
r.RetryAfter = time.Second
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func randomMember() (string, error) {
|
||||
var buf [12]byte
|
||||
if _, err := rand.Read(buf[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf[:]), nil
|
||||
}
|
||||
|
||||
// Sliding-window check + record, atomic in Redis.
|
||||
//
|
||||
// KEYS[1] = bucket key
|
||||
// ARGV[1] = now (unix ms)
|
||||
// ARGV[2] = window (ms)
|
||||
// ARGV[3] = limit
|
||||
// ARGV[4] = unique member to insert when allowed
|
||||
// returns { blocked, count, retryAfterMs }
|
||||
const slidingWindowScript = `
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local limit = tonumber(ARGV[3])
|
||||
local member = ARGV[4]
|
||||
local cutoff = now - window
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
|
||||
local count = redis.call('ZCARD', key)
|
||||
|
||||
if count >= limit then
|
||||
local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
||||
local retry = window
|
||||
if oldest[2] then
|
||||
retry = (tonumber(oldest[2]) + window) - now
|
||||
if retry < 1 then retry = 1 end
|
||||
end
|
||||
return {1, count, retry}
|
||||
end
|
||||
|
||||
redis.call('ZADD', key, now, member)
|
||||
redis.call('PEXPIRE', key, window)
|
||||
return {0, count + 1, 0}
|
||||
`
|
||||
@@ -0,0 +1,110 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func newTestLimiter(t *testing.T) (*Limiter, *miniredis.Miniredis) {
|
||||
t.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("miniredis: %v", err)
|
||||
}
|
||||
t.Cleanup(mr.Close)
|
||||
cli := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
t.Cleanup(func() { _ = cli.Close() })
|
||||
return New(cli, "test"), mr
|
||||
}
|
||||
|
||||
func TestLimiterAllowsBelowLimit(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
for i := 1; i <= 3; i++ {
|
||||
r, err := l.Allow(ctx, "signup", "1.2.3.4", 3, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("allow #%d: %v", i, err)
|
||||
}
|
||||
if !r.Allowed {
|
||||
t.Fatalf("hit %d should be allowed, got %+v", i, r)
|
||||
}
|
||||
if r.Count != i {
|
||||
t.Fatalf("hit %d count: got %d want %d", i, r.Count, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterBlocksAtLimit(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := l.Allow(ctx, "signup", "ip", 3, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
r, err := l.Allow(ctx, "signup", "ip", 3, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.Allowed {
|
||||
t.Fatalf("4th hit should be blocked: %+v", r)
|
||||
}
|
||||
if r.RetryAfter <= 0 || r.RetryAfter > time.Minute {
|
||||
t.Fatalf("retry-after out of range: %v", r.RetryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterWindowSlides(t *testing.T) {
|
||||
l, mr := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
// Inject a controllable clock so we can advance time in miniredis +
|
||||
// the limiter consistently.
|
||||
base := time.Unix(1_700_000_000, 0)
|
||||
l.now = func() time.Time { return base }
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
// Slide past the window. miniredis honours TTLs we already set so
|
||||
// FastForward is the trustworthy primitive here.
|
||||
l.now = func() time.Time { return base.Add(2 * time.Minute) }
|
||||
mr.FastForward(2 * time.Minute)
|
||||
|
||||
r, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !r.Allowed {
|
||||
t.Fatalf("expected allow after window: %+v", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimiterIsolatesKeys(t *testing.T) {
|
||||
l, _ := newTestLimiter(t)
|
||||
ctx := context.Background()
|
||||
// Exhaust budget for one key — others should not be affected.
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
blocked, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if blocked.Allowed {
|
||||
t.Fatal("key a should be blocked")
|
||||
}
|
||||
other, err := l.Allow(ctx, "login", "b@x.com", 2, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !other.Allowed {
|
||||
t.Fatalf("unrelated key b should still be allowed: %+v", other)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
)
|
||||
|
||||
// EmailVerificationRepo manages single-use email verification tokens.
|
||||
type EmailVerificationRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewEmailVerificationRepo(db *DB) *EmailVerificationRepo {
|
||||
return &EmailVerificationRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
func (r *EmailVerificationRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO email_verification_tokens (token_hash, user_id, expires_at)
|
||||
VALUES ($1, $2, $3)
|
||||
`, hash, userID, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// Consume atomically marks the token as used and returns the owning user_id.
|
||||
// Returns ErrAuthTokenNotFound / ErrAuthTokenConsumed / ErrAuthTokenExpired.
|
||||
func (r *EmailVerificationRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) {
|
||||
const q = `
|
||||
UPDATE email_verification_tokens
|
||||
SET consumed_at = now()
|
||||
WHERE token_hash = $1
|
||||
AND consumed_at IS NULL
|
||||
AND expires_at > now()
|
||||
RETURNING user_id
|
||||
`
|
||||
var uid uuid.UUID
|
||||
if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool,
|
||||
"SELECT consumed_at, expires_at FROM email_verification_tokens WHERE token_hash=$1",
|
||||
hash)
|
||||
}
|
||||
return uuid.Nil, err
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// PasswordResetRepo manages single-use password-reset tokens.
|
||||
type PasswordResetRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPasswordResetRepo(db *DB) *PasswordResetRepo {
|
||||
return &PasswordResetRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
func (r *PasswordResetRepo) Create(ctx context.Context, userID uuid.UUID, hash string, expiresAt time.Time) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO password_reset_tokens (token_hash, user_id, expires_at)
|
||||
VALUES ($1, $2, $3)
|
||||
`, hash, userID, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PasswordResetRepo) Consume(ctx context.Context, hash string) (uuid.UUID, error) {
|
||||
const q = `
|
||||
UPDATE password_reset_tokens
|
||||
SET consumed_at = now()
|
||||
WHERE token_hash = $1
|
||||
AND consumed_at IS NULL
|
||||
AND expires_at > now()
|
||||
RETURNING user_id
|
||||
`
|
||||
var uid uuid.UUID
|
||||
if err := r.pool.QueryRow(ctx, q, hash).Scan(&uid); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return uuid.Nil, classifyAuthTokenLookup(ctx, r.pool,
|
||||
"SELECT consumed_at, expires_at FROM password_reset_tokens WHERE token_hash=$1",
|
||||
hash)
|
||||
}
|
||||
return uuid.Nil, err
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// RefreshTokenRepo manages refresh-token rows. Refresh tokens are rotated:
|
||||
// every refresh issues a new token and revokes the old one, recording the
|
||||
// chain in `replaced_by` so we can detect replay (a revoked token being
|
||||
// presented again triggers a family-wide revocation).
|
||||
type RefreshTokenRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewRefreshTokenRepo(db *DB) *RefreshTokenRepo {
|
||||
return &RefreshTokenRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
type RefreshToken struct {
|
||||
Hash string
|
||||
UserID uuid.UUID
|
||||
ExpiresAt time.Time
|
||||
RevokedAt *time.Time
|
||||
ReplacedBy *string
|
||||
UserAgent string
|
||||
IPAddress *netip.Addr
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type CreateRefreshTokenParams struct {
|
||||
Hash string
|
||||
UserID uuid.UUID
|
||||
ExpiresAt time.Time
|
||||
UserAgent string
|
||||
IPAddress string
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) Create(ctx context.Context, p CreateRefreshTokenParams) error {
|
||||
ip := parseIP(p.IPAddress)
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, NULLIF($4, ''), $5)
|
||||
`, p.Hash, p.UserID, p.ExpiresAt, p.UserAgent, ip)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) Get(ctx context.Context, hash string) (*RefreshToken, error) {
|
||||
const q = `
|
||||
SELECT token_hash, user_id, expires_at, revoked_at, replaced_by,
|
||||
COALESCE(user_agent, ''), host(ip_address), created_at
|
||||
FROM refresh_tokens WHERE token_hash = $1
|
||||
`
|
||||
var rt RefreshToken
|
||||
var ipText *string
|
||||
if err := r.pool.QueryRow(ctx, q, hash).Scan(
|
||||
&rt.Hash, &rt.UserID, &rt.ExpiresAt, &rt.RevokedAt, &rt.ReplacedBy,
|
||||
&rt.UserAgent, &ipText, &rt.CreatedAt,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ipText != nil && *ipText != "" {
|
||||
if addr, err := netip.ParseAddr(*ipText); err == nil {
|
||||
rt.IPAddress = &addr
|
||||
}
|
||||
}
|
||||
return &rt, nil
|
||||
}
|
||||
|
||||
// Rotate atomically (in a transaction) marks the old token revoked and
|
||||
// inserts the new one with replaced_by set. Returns ErrAuthTokenNotFound or
|
||||
// ErrRefreshTokenRevoked if the old token is missing or already revoked.
|
||||
func (r *RefreshTokenRepo) Rotate(ctx context.Context, oldHash string, next CreateRefreshTokenParams) error {
|
||||
tx, err := r.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
var revokedAt *time.Time
|
||||
var userID uuid.UUID
|
||||
var expiresAt time.Time
|
||||
err = tx.QueryRow(ctx, `
|
||||
SELECT user_id, expires_at, revoked_at
|
||||
FROM refresh_tokens WHERE token_hash = $1 FOR UPDATE
|
||||
`, oldHash).Scan(&userID, &expiresAt, &revokedAt)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if revokedAt != nil {
|
||||
// Replay of a revoked refresh token — revoke the entire family.
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
`, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return domain.ErrRefreshTokenRevoked
|
||||
}
|
||||
if time.Now().After(expiresAt) {
|
||||
return domain.ErrAuthTokenExpired
|
||||
}
|
||||
if next.UserID != userID {
|
||||
return errors.New("refresh token user mismatch")
|
||||
}
|
||||
|
||||
ip := parseIP(next.IPAddress)
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO refresh_tokens (token_hash, user_id, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, NULLIF($4, ''), $5)
|
||||
`, next.Hash, next.UserID, next.ExpiresAt, next.UserAgent, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now(), replaced_by = $2
|
||||
WHERE token_hash = $1
|
||||
`, oldHash, next.Hash); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) Revoke(ctx context.Context, hash string) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now()
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL
|
||||
`, hash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRepo) RevokeAllForUser(ctx context.Context, userID uuid.UUID) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
UPDATE refresh_tokens SET revoked_at = now()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func parseIP(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
addr, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
func classifyAuthTokenLookup(ctx context.Context, pool *pgxpool.Pool, q, hash string) error {
|
||||
var consumedAt *time.Time
|
||||
var expiresAt time.Time
|
||||
if err := pool.QueryRow(ctx, q, hash).Scan(&consumedAt, &expiresAt); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if consumedAt != nil {
|
||||
return domain.ErrAuthTokenConsumed
|
||||
}
|
||||
if time.Now().After(expiresAt) {
|
||||
return domain.ErrAuthTokenExpired
|
||||
}
|
||||
return domain.ErrAuthTokenNotFound
|
||||
}
|
||||
+30
-12
@@ -78,6 +78,24 @@ func (r *EventRepo) Get(ctx context.Context, id uuid.UUID) (*domain.Event, error
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// GetForHost is the authz-aware variant of Get. It returns ErrEventNotFound
|
||||
// when the event either doesn't exist or doesn't belong to the host — by
|
||||
// merging both cases we avoid leaking existence on cross-tenant lookups.
|
||||
func (r *EventRepo) GetForHost(ctx context.Context, id, hostID uuid.UUID) (*domain.Event, error) {
|
||||
const q = `
|
||||
SELECT id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at
|
||||
FROM events WHERE id = $1 AND host_id = $2
|
||||
`
|
||||
ev, err := scanEvent(r.pool.QueryRow(ctx, q, id, hostID))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrEventNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *EventRepo) List(ctx context.Context, hostID uuid.UUID, limit, offset int) ([]*domain.Event, error) {
|
||||
if limit <= 0 || limit > 200 {
|
||||
limit = 50
|
||||
@@ -132,18 +150,18 @@ type UpdateEventParams struct {
|
||||
Status *domain.EventStatus
|
||||
}
|
||||
|
||||
func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParams) (*domain.Event, error) {
|
||||
func (r *EventRepo) Update(ctx context.Context, id, hostID uuid.UUID, p UpdateEventParams) (*domain.Event, error) {
|
||||
const q = `
|
||||
UPDATE events SET
|
||||
name = COALESCE($2, name),
|
||||
slug = COALESCE($3, slug),
|
||||
event_date = COALESCE($4, event_date),
|
||||
venue = COALESCE($5, venue),
|
||||
max_capacity = COALESCE($6, max_capacity),
|
||||
settings = COALESCE($7, settings),
|
||||
status = COALESCE($8, status),
|
||||
name = COALESCE($3, name),
|
||||
slug = COALESCE($4, slug),
|
||||
event_date = COALESCE($5, event_date),
|
||||
venue = COALESCE($6, venue),
|
||||
max_capacity = COALESCE($7, max_capacity),
|
||||
settings = COALESCE($8, settings),
|
||||
status = COALESCE($9, status),
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND host_id = $2
|
||||
RETURNING id, host_id, name, slug, event_date, venue, max_capacity, settings, status, created_at, updated_at
|
||||
`
|
||||
|
||||
@@ -156,7 +174,7 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam
|
||||
settingsJSON = b
|
||||
}
|
||||
|
||||
row := r.pool.QueryRow(ctx, q, id,
|
||||
row := r.pool.QueryRow(ctx, q, id, hostID,
|
||||
p.Name, p.Slug, p.EventDate, p.Venue, p.MaxCapacity, settingsJSON, p.Status,
|
||||
)
|
||||
ev, err := scanEvent(row)
|
||||
@@ -173,8 +191,8 @@ func (r *EventRepo) Update(ctx context.Context, id uuid.UUID, p UpdateEventParam
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *EventRepo) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1`, id)
|
||||
func (r *EventRepo) Delete(ctx context.Context, id, hostID uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `DELETE FROM events WHERE id = $1 AND host_id = $2`, id, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package storage
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -88,6 +90,87 @@ func (r *GuestRepo) ListByEvent(ctx context.Context, eventID uuid.UUID, limit, o
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// BulkImportRow is one normalised guest in a CSV import batch.
|
||||
type BulkImportRow struct {
|
||||
Name string
|
||||
Email string // empty if absent
|
||||
Phone string // empty if absent
|
||||
PlusOnes int
|
||||
}
|
||||
|
||||
// BulkImportResult reports the outcome of a single import call. The
|
||||
// SkippedEmails slice records the addresses we silently dropped because a
|
||||
// guest already exists on the event — useful for the success summary.
|
||||
type BulkImportResult struct {
|
||||
Added int
|
||||
Skipped int
|
||||
SkippedEmails []string
|
||||
}
|
||||
|
||||
// BulkImportGuests inserts up to len(rows) guest rows into the event in a
|
||||
// single transaction. Rows whose email matches an existing guest on the
|
||||
// event are skipped (idempotent re-imports). Within the batch, duplicate
|
||||
// emails after the first are also skipped. Either the entire batch
|
||||
// commits or none of it does.
|
||||
//
|
||||
// Empty email is treated as "no email" and not deduped — those rows
|
||||
// always insert (the host might be entering phone-only guests).
|
||||
func (r *GuestRepo) BulkImportGuests(ctx context.Context, eventID uuid.UUID, rows []BulkImportRow) (BulkImportResult, error) {
|
||||
res := BulkImportResult{}
|
||||
if len(rows) == 0 {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
tx, err := r.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Fetch existing emails on the event into a set for O(1) dedup.
|
||||
existing := map[string]struct{}{}
|
||||
exRows, err := tx.Query(ctx,
|
||||
`SELECT lower(email) FROM guests WHERE event_id = $1 AND email IS NOT NULL AND email <> ''`,
|
||||
eventID)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("load existing emails: %w", err)
|
||||
}
|
||||
for exRows.Next() {
|
||||
var e string
|
||||
if err := exRows.Scan(&e); err != nil {
|
||||
exRows.Close()
|
||||
return res, err
|
||||
}
|
||||
existing[e] = struct{}{}
|
||||
}
|
||||
exRows.Close()
|
||||
|
||||
const ins = `
|
||||
INSERT INTO guests (event_id, name, email, phone, plus_ones)
|
||||
VALUES ($1, $2, NULLIF($3, ''), NULLIF($4, ''), $5)
|
||||
`
|
||||
for _, row := range rows {
|
||||
email := strings.ToLower(strings.TrimSpace(row.Email))
|
||||
if email != "" {
|
||||
if _, dup := existing[email]; dup {
|
||||
res.Skipped++
|
||||
res.SkippedEmails = append(res.SkippedEmails, email)
|
||||
continue
|
||||
}
|
||||
existing[email] = struct{}{}
|
||||
}
|
||||
if _, err := tx.Exec(ctx, ins, eventID, row.Name, email, row.Phone, row.PlusOnes); err != nil {
|
||||
return BulkImportResult{}, fmt.Errorf("insert guest %q: %w", row.Name, err)
|
||||
}
|
||||
res.Added++
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return BulkImportResult{}, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func scanGuest(s rowScanner) (*domain.Guest, error) {
|
||||
var g domain.Guest
|
||||
err := s.Scan(
|
||||
@@ -100,6 +183,135 @@ func scanGuest(s rowScanner) (*domain.Guest, error) {
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
// UpdateGuestParams patches a guest. Nil fields are left untouched.
|
||||
// An empty string for Email / Phone clears the column to NULL, matching
|
||||
// the frontend "clear this field" UX.
|
||||
type UpdateGuestParams struct {
|
||||
Name *string
|
||||
Email *string
|
||||
Phone *string
|
||||
PlusOnes *int
|
||||
}
|
||||
|
||||
// Update applies the patch to (guestID, eventID). Event scoping in the
|
||||
// WHERE clause prevents a host from patching guests on another host's
|
||||
// event even if they guess the guest_id. Returns ErrGuestNotFound when
|
||||
// the guest doesn't exist on the event.
|
||||
func (r *GuestRepo) Update(ctx context.Context, eventID, guestID uuid.UUID, p UpdateGuestParams) (*domain.Guest, error) {
|
||||
sets := []string{}
|
||||
args := []any{guestID, eventID}
|
||||
add := func(col string, val any) {
|
||||
args = append(args, val)
|
||||
sets = append(sets, fmt.Sprintf("%s = $%d", col, len(args)))
|
||||
}
|
||||
if p.Name != nil {
|
||||
add("name", strings.TrimSpace(*p.Name))
|
||||
}
|
||||
if p.Email != nil {
|
||||
if strings.TrimSpace(*p.Email) == "" {
|
||||
sets = append(sets, "email = NULL")
|
||||
} else {
|
||||
add("email", strings.ToLower(strings.TrimSpace(*p.Email)))
|
||||
}
|
||||
}
|
||||
if p.Phone != nil {
|
||||
if strings.TrimSpace(*p.Phone) == "" {
|
||||
sets = append(sets, "phone = NULL")
|
||||
} else {
|
||||
add("phone", strings.TrimSpace(*p.Phone))
|
||||
}
|
||||
}
|
||||
if p.PlusOnes != nil {
|
||||
add("plus_ones", *p.PlusOnes)
|
||||
}
|
||||
if len(sets) == 0 {
|
||||
return r.Get(ctx, guestID)
|
||||
}
|
||||
q := fmt.Sprintf(`
|
||||
UPDATE guests SET %s
|
||||
WHERE id = $1 AND event_id = $2
|
||||
RETURNING id, event_id, name, email, phone, plus_ones, dietary_notes, table_number, created_at
|
||||
`, strings.Join(sets, ", "))
|
||||
g, err := scanGuest(r.pool.QueryRow(ctx, q, args...))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrGuestNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// Delete removes a guest from an event. Cascade-deletes any tokens,
|
||||
// rsvps, access_logs, and notifications tied to the guest. Event scoping
|
||||
// in the WHERE clause stops cross-tenant deletes.
|
||||
func (r *GuestRepo) Delete(ctx context.Context, eventID, guestID uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx,
|
||||
`DELETE FROM guests WHERE id = $1 AND event_id = $2`,
|
||||
guestID, eventID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrGuestNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GuestForInvitation is the minimum data the bulk-invite path needs about
|
||||
// each candidate. Pulled in a single query joined against tokens so the
|
||||
// caller knows up-front who's already received an invitation.
|
||||
type GuestForInvitation struct {
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Email string // empty when the guest has no email on file
|
||||
HasToken bool
|
||||
}
|
||||
|
||||
// ListGuestsForInvitation returns every guest on `eventID`, joined with
|
||||
// the tokens table so the caller can skip guests that already have one.
|
||||
// When `onlyIDs` is non-nil/non-empty, the result is filtered to that
|
||||
// subset (used for explicit selection in the bulk-send UI).
|
||||
func (r *GuestRepo) ListGuestsForInvitation(ctx context.Context, eventID uuid.UUID, onlyIDs []uuid.UUID) ([]GuestForInvitation, error) {
|
||||
var (
|
||||
rows pgx.Rows
|
||||
err error
|
||||
)
|
||||
if len(onlyIDs) == 0 {
|
||||
rows, err = r.pool.Query(ctx, `
|
||||
SELECT g.id, g.name, COALESCE(g.email,''),
|
||||
(t.id IS NOT NULL) AS has_token
|
||||
FROM guests g
|
||||
LEFT JOIN tokens t ON t.guest_id = g.id
|
||||
WHERE g.event_id = $1
|
||||
ORDER BY g.created_at ASC
|
||||
`, eventID)
|
||||
} else {
|
||||
rows, err = r.pool.Query(ctx, `
|
||||
SELECT g.id, g.name, COALESCE(g.email,''),
|
||||
(t.id IS NOT NULL) AS has_token
|
||||
FROM guests g
|
||||
LEFT JOIN tokens t ON t.guest_id = g.id
|
||||
WHERE g.event_id = $1 AND g.id = ANY($2)
|
||||
ORDER BY g.created_at ASC
|
||||
`, eventID, onlyIDs)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []GuestForInvitation
|
||||
for rows.Next() {
|
||||
var g GuestForInvitation
|
||||
if err := rows.Scan(&g.ID, &g.Name, &g.Email, &g.HasToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, g)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GuestWithRSVP is the dashboard view: a guest plus the RSVP submitted
|
||||
// against their token, if any. RSVP fields are nil when no response yet.
|
||||
type GuestWithRSVP struct {
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP TABLE IF EXISTS refresh_tokens;
|
||||
DROP TABLE IF EXISTS password_reset_tokens;
|
||||
DROP TABLE IF EXISTS email_verification_tokens;
|
||||
|
||||
ALTER TABLE users
|
||||
DROP COLUMN IF EXISTS email_verified_at,
|
||||
DROP COLUMN IF EXISTS email_verified,
|
||||
DROP COLUMN IF EXISTS password_hash;
|
||||
@@ -0,0 +1,37 @@
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS password_hash TEXT,
|
||||
ADD COLUMN IF NOT EXISTS email_verified BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
ADD COLUMN IF NOT EXISTS email_verified_at TIMESTAMPTZ;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS email_verification_tokens (
|
||||
token_hash TEXT PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
consumed_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_user ON email_verification_tokens(user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS password_reset_tokens (
|
||||
token_hash TEXT PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
consumed_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_password_reset_tokens_user ON password_reset_tokens(user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
token_hash TEXT PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
revoked_at TIMESTAMPTZ,
|
||||
replaced_by TEXT REFERENCES refresh_tokens(token_hash) ON DELETE SET NULL,
|
||||
user_agent TEXT,
|
||||
ip_address INET,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_active ON refresh_tokens(user_id) WHERE revoked_at IS NULL;
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP TABLE IF EXISTS unsubscribes;
|
||||
|
||||
ALTER TABLE notifications
|
||||
DROP COLUMN IF EXISTS complained,
|
||||
DROP COLUMN IF EXISTS bounce_type,
|
||||
DROP COLUMN IF EXISTS provider_message_id;
|
||||
|
||||
DROP INDEX IF EXISTS idx_notifications_provider_message_id;
|
||||
@@ -0,0 +1,23 @@
|
||||
-- Block D — real notifications: bounce / complaint tracking + suppression list.
|
||||
-- The `delivered_at` column already exists from 0001.
|
||||
|
||||
CREATE EXTENSION IF NOT EXISTS "citext";
|
||||
|
||||
ALTER TABLE notifications
|
||||
ADD COLUMN IF NOT EXISTS provider_message_id TEXT,
|
||||
ADD COLUMN IF NOT EXISTS bounce_type TEXT, -- 'permanent' | 'transient' | NULL
|
||||
ADD COLUMN IF NOT EXISTS complained BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_notifications_provider_message_id
|
||||
ON notifications(provider_message_id)
|
||||
WHERE provider_message_id IS NOT NULL;
|
||||
|
||||
-- Suppression list: any email present here gets a silent no-op on send.
|
||||
-- Populated by bounce / complaint webhooks and by guest-initiated
|
||||
-- unsubscribe clicks.
|
||||
CREATE TABLE IF NOT EXISTS unsubscribes (
|
||||
email CITEXT PRIMARY KEY,
|
||||
reason TEXT,
|
||||
source TEXT NOT NULL DEFAULT 'manual', -- 'bounce' | 'complaint' | 'manual' | 'user'
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP INDEX IF EXISTS idx_subscriptions_subscription;
|
||||
DROP INDEX IF EXISTS idx_subscriptions_customer;
|
||||
DROP INDEX IF EXISTS uniq_subscriptions_active_user;
|
||||
DROP TABLE IF EXISTS subscriptions;
|
||||
@@ -0,0 +1,30 @@
|
||||
-- Block F — Stripe subscriptions. One row per Stripe customer + (optional)
|
||||
-- active subscription. Free-tier hosts never get a row; their tier is
|
||||
-- inferred at read time.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
stripe_customer_id TEXT NOT NULL,
|
||||
stripe_subscription_id TEXT,
|
||||
tier TEXT NOT NULL CHECK (tier IN ('free','pro','business')),
|
||||
status TEXT NOT NULL CHECK (status IN ('active','past_due','canceled','incomplete','trialing','unpaid')),
|
||||
current_period_end TIMESTAMPTZ,
|
||||
cancel_at_period_end BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
-- A user may have at most one *granting* subscription at a time. We
|
||||
-- include trialing + past_due because those still convey access (past_due
|
||||
-- is the grace period before Stripe gives up on the card).
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uniq_subscriptions_active_user
|
||||
ON subscriptions(user_id)
|
||||
WHERE status IN ('active','past_due','trialing');
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_customer
|
||||
ON subscriptions(stripe_customer_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_subscription
|
||||
ON subscriptions(stripe_subscription_id)
|
||||
WHERE stripe_subscription_id IS NOT NULL;
|
||||
@@ -0,0 +1,6 @@
|
||||
DROP INDEX IF EXISTS idx_users_active_email;
|
||||
|
||||
ALTER TABLE users
|
||||
DROP COLUMN IF EXISTS privacy_policy_accepted_at,
|
||||
DROP COLUMN IF EXISTS terms_accepted_at,
|
||||
DROP COLUMN IF EXISTS deleted_at;
|
||||
@@ -0,0 +1,20 @@
|
||||
-- Block H — privacy compliance.
|
||||
--
|
||||
-- Adds the columns needed for:
|
||||
-- - Right to erasure (DELETE /me): soft-delete first, hard-delete via
|
||||
-- a future cron after a 30-day grace window so an accidental click
|
||||
-- is recoverable.
|
||||
-- - Terms / privacy-policy acceptance gate (set on signup; older
|
||||
-- accounts re-prompted via the frontend on next login).
|
||||
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS terms_accepted_at TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS privacy_policy_accepted_at TIMESTAMPTZ;
|
||||
|
||||
-- Most lookups (login, /me, etc.) want to exclude soft-deleted users.
|
||||
-- A partial index keeps the active subset fast without bloating writes
|
||||
-- for the rare deleted rows.
|
||||
CREATE INDEX IF NOT EXISTS idx_users_active_email
|
||||
ON users(email)
|
||||
WHERE deleted_at IS NULL;
|
||||
@@ -0,0 +1,223 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Subscription mirrors the subscriptions table row. Stored as a thin
|
||||
// projection of the Stripe state — we don't try to mirror every field,
|
||||
// just what middleware + handlers need to decide access.
|
||||
type Subscription struct {
|
||||
ID uuid.UUID
|
||||
UserID uuid.UUID
|
||||
StripeCustomerID string
|
||||
StripeSubscriptionID *string
|
||||
Tier string
|
||||
Status string
|
||||
CurrentPeriodEnd *time.Time
|
||||
CancelAtPeriodEnd bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ErrSubscriptionNotFound is returned when no row matches the lookup.
|
||||
var ErrSubscriptionNotFound = errors.New("subscription not found")
|
||||
|
||||
type SubscriptionRepo struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewSubscriptionRepo(db *DB) *SubscriptionRepo {
|
||||
return &SubscriptionRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
const subscriptionColumns = `
|
||||
id, user_id, stripe_customer_id, stripe_subscription_id,
|
||||
tier, status, current_period_end, cancel_at_period_end,
|
||||
created_at, updated_at
|
||||
`
|
||||
|
||||
// GetActiveByUser returns the user's currently-granting subscription
|
||||
// (active / trialing / past_due). Returns ErrSubscriptionNotFound when
|
||||
// the user has no row at all — caller treats that as free tier.
|
||||
func (r *SubscriptionRepo) GetActiveByUser(ctx context.Context, userID uuid.UUID) (*Subscription, error) {
|
||||
const q = `
|
||||
SELECT ` + subscriptionColumns + `
|
||||
FROM subscriptions
|
||||
WHERE user_id = $1
|
||||
AND status IN ('active','past_due','trialing')
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1
|
||||
`
|
||||
return r.scanOne(ctx, q, userID)
|
||||
}
|
||||
|
||||
// GetByCustomer fetches by Stripe customer id — webhooks use this since
|
||||
// the event payload identifies the customer, not the user.
|
||||
func (r *SubscriptionRepo) GetByCustomer(ctx context.Context, customerID string) (*Subscription, error) {
|
||||
const q = `
|
||||
SELECT ` + subscriptionColumns + `
|
||||
FROM subscriptions WHERE stripe_customer_id = $1
|
||||
ORDER BY updated_at DESC LIMIT 1
|
||||
`
|
||||
return r.scanOne(ctx, q, customerID)
|
||||
}
|
||||
|
||||
// FindCustomerID returns the Stripe customer id we've already created
|
||||
// for this user, or "" if none exists yet. Avoids creating duplicate
|
||||
// Stripe customers across checkout sessions.
|
||||
func (r *SubscriptionRepo) FindCustomerID(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
const q = `
|
||||
SELECT stripe_customer_id FROM subscriptions
|
||||
WHERE user_id = $1 ORDER BY created_at ASC LIMIT 1
|
||||
`
|
||||
var id string
|
||||
if err := r.pool.QueryRow(ctx, q, userID).Scan(&id); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// UpsertParams collects everything an upsert needs. Pointer types denote
|
||||
// "skip writing this column" (used when a webhook only carries partial
|
||||
// data — we never want to clobber tier or period info we don't have).
|
||||
type UpsertParams struct {
|
||||
UserID uuid.UUID
|
||||
StripeCustomerID string
|
||||
StripeSubscriptionID *string
|
||||
Tier *string
|
||||
Status *string
|
||||
CurrentPeriodEnd *time.Time
|
||||
CancelAtPeriodEnd *bool
|
||||
}
|
||||
|
||||
// Upsert inserts a new row or updates an existing one keyed by
|
||||
// stripe_customer_id. Used by both the checkout-success handler and the
|
||||
// webhook subscription-lifecycle handler.
|
||||
func (r *SubscriptionRepo) Upsert(ctx context.Context, p UpsertParams) (*Subscription, error) {
|
||||
const q = `
|
||||
INSERT INTO subscriptions (
|
||||
user_id, stripe_customer_id, stripe_subscription_id,
|
||||
tier, status, current_period_end, cancel_at_period_end
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, $3,
|
||||
COALESCE($4, 'free'), COALESCE($5, 'incomplete'),
|
||||
$6, COALESCE($7, FALSE)
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
RETURNING ` + subscriptionColumns + `
|
||||
`
|
||||
|
||||
row := r.pool.QueryRow(ctx, q,
|
||||
p.UserID, p.StripeCustomerID, p.StripeSubscriptionID,
|
||||
p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd,
|
||||
)
|
||||
sub, err := scanSubscription(row)
|
||||
if err == nil {
|
||||
return sub, nil
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Race or duplicate insert — fall back to an explicit update on the
|
||||
// stripe_customer_id (the FK to Stripe's source of truth).
|
||||
const upd = `
|
||||
UPDATE subscriptions SET
|
||||
stripe_subscription_id = COALESCE($3, stripe_subscription_id),
|
||||
tier = COALESCE($4, tier),
|
||||
status = COALESCE($5, status),
|
||||
current_period_end = COALESCE($6, current_period_end),
|
||||
cancel_at_period_end = COALESCE($7, cancel_at_period_end),
|
||||
updated_at = now()
|
||||
WHERE user_id = $1 AND stripe_customer_id = $2
|
||||
RETURNING ` + subscriptionColumns + `
|
||||
`
|
||||
row = r.pool.QueryRow(ctx, upd,
|
||||
p.UserID, p.StripeCustomerID, p.StripeSubscriptionID,
|
||||
p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd,
|
||||
)
|
||||
return scanSubscription(row)
|
||||
}
|
||||
|
||||
// UpdateByCustomer patches the subscription row keyed by Stripe customer
|
||||
// id. Used by webhooks where we have the customer reference but not
|
||||
// always the user id.
|
||||
func (r *SubscriptionRepo) UpdateByCustomer(ctx context.Context, customerID string, p UpsertParams) error {
|
||||
const q = `
|
||||
UPDATE subscriptions SET
|
||||
stripe_subscription_id = COALESCE($2, stripe_subscription_id),
|
||||
tier = COALESCE($3, tier),
|
||||
status = COALESCE($4, status),
|
||||
current_period_end = COALESCE($5, current_period_end),
|
||||
cancel_at_period_end = COALESCE($6, cancel_at_period_end),
|
||||
updated_at = now()
|
||||
WHERE stripe_customer_id = $1
|
||||
`
|
||||
_, err := r.pool.Exec(ctx, q,
|
||||
customerID, p.StripeSubscriptionID,
|
||||
p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// CountEventsInCurrentMonth returns how many events the user has created
|
||||
// since the 1st of the current UTC month. Used for free-tier "1 event /
|
||||
// month" and Pro-tier "10 events / month" enforcement.
|
||||
func (r *SubscriptionRepo) CountEventsInCurrentMonth(ctx context.Context, userID uuid.UUID) (int, error) {
|
||||
const q = `
|
||||
SELECT count(*) FROM events
|
||||
WHERE host_id = $1
|
||||
AND created_at >= date_trunc('month', now() AT TIME ZONE 'UTC')
|
||||
`
|
||||
var n int
|
||||
if err := r.pool.QueryRow(ctx, q, userID).Scan(&n); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// CountGuestsByEvent returns the current guest count for an event. Used
|
||||
// for per-event guest cap enforcement.
|
||||
func (r *SubscriptionRepo) CountGuestsByEvent(ctx context.Context, eventID uuid.UUID) (int, error) {
|
||||
var n int
|
||||
if err := r.pool.QueryRow(ctx,
|
||||
`SELECT count(*) FROM guests WHERE event_id = $1`, eventID,
|
||||
).Scan(&n); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionRepo) scanOne(ctx context.Context, q string, args ...any) (*Subscription, error) {
|
||||
sub, err := scanSubscription(r.pool.QueryRow(ctx, q, args...))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func scanSubscription(s rowScanner) (*Subscription, error) {
|
||||
var sub Subscription
|
||||
if err := s.Scan(
|
||||
&sub.ID, &sub.UserID, &sub.StripeCustomerID, &sub.StripeSubscriptionID,
|
||||
&sub.Tier, &sub.Status, &sub.CurrentPeriodEnd, &sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt, &sub.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
@@ -60,6 +60,38 @@ func (r *TokenRepo) GetByHash(ctx context.Context, hash string) (*domain.Token,
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
// RotateForGuest replaces the guest's existing token with a freshly-minted
|
||||
// one in a single transaction. The old token row is hard-deleted (the
|
||||
// guests.id UNIQUE constraint requires it, and "the old link must stop
|
||||
// working" is the point). Cascade-deletes the old access_logs rows that
|
||||
// reference it via the token_id FK with ON DELETE SET NULL — those rows
|
||||
// stay, with token_id nulled.
|
||||
func (r *TokenRepo) RotateForGuest(ctx context.Context, p CreateTokenParams) (*domain.Token, error) {
|
||||
tx, err := r.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
if _, err := tx.Exec(ctx, `DELETE FROM tokens WHERE guest_id = $1`, p.GuestID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
const q = `
|
||||
INSERT INTO tokens (guest_id, token_hash, expires_at, status)
|
||||
VALUES ($1, $2, $3, 'active')
|
||||
RETURNING id, guest_id, token_hash, expires_at, status, used_at, created_at
|
||||
`
|
||||
row := tx.QueryRow(ctx, q, p.GuestID, p.TokenHash, p.ExpiresAt)
|
||||
tk, err := scanToken(row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
func (r *TokenRepo) MarkUsed(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE tokens SET status = 'used', used_at = now()
|
||||
|
||||
+141
-9
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -20,12 +21,38 @@ func NewUserRepo(db *DB) *UserRepo {
|
||||
return &UserRepo{pool: db.Pool}
|
||||
}
|
||||
|
||||
func (r *UserRepo) Create(ctx context.Context, email, name string) (*domain.User, error) {
|
||||
const userColumns = `id, email, name,
|
||||
COALESCE(password_hash, '') AS password_hash,
|
||||
email_verified, email_verified_at,
|
||||
deleted_at,
|
||||
terms_accepted_at, privacy_policy_accepted_at,
|
||||
created_at, updated_at`
|
||||
|
||||
type CreateUserParams struct {
|
||||
Email string
|
||||
Name string
|
||||
PasswordHash string
|
||||
AcceptTerms bool // when true, records terms + privacy acceptance now
|
||||
}
|
||||
|
||||
func (r *UserRepo) Create(ctx context.Context, p CreateUserParams) (*domain.User, error) {
|
||||
const q = `
|
||||
INSERT INTO users (email, name) VALUES ($1, $2)
|
||||
RETURNING id, email, name, created_at, updated_at
|
||||
`
|
||||
row := r.pool.QueryRow(ctx, q, strings.ToLower(strings.TrimSpace(email)), strings.TrimSpace(name))
|
||||
INSERT INTO users (
|
||||
email, name, password_hash,
|
||||
terms_accepted_at, privacy_policy_accepted_at
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, NULLIF($3, ''),
|
||||
CASE WHEN $4 THEN now() ELSE NULL END,
|
||||
CASE WHEN $4 THEN now() ELSE NULL END
|
||||
)
|
||||
RETURNING ` + userColumns
|
||||
row := r.pool.QueryRow(ctx, q,
|
||||
normaliseEmail(p.Email),
|
||||
strings.TrimSpace(p.Name),
|
||||
p.PasswordHash,
|
||||
p.AcceptTerms,
|
||||
)
|
||||
u, err := scanUser(row)
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
@@ -37,9 +64,12 @@ func (r *UserRepo) Create(ctx context.Context, email, name string) (*domain.User
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (r *UserRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
const q = `SELECT id, email, name, created_at, updated_at FROM users WHERE email = $1`
|
||||
u, err := scanUser(r.pool.QueryRow(ctx, q, strings.ToLower(strings.TrimSpace(email))))
|
||||
// GetByID returns an active (non-soft-deleted) user. Soft-deleted users
|
||||
// are treated as "not found" by the API surface — keeps the deletion
|
||||
// flow safe by default.
|
||||
func (r *UserRepo) GetByID(ctx context.Context, id uuid.UUID) (*domain.User, error) {
|
||||
const q = `SELECT ` + userColumns + ` FROM users WHERE id = $1 AND deleted_at IS NULL`
|
||||
u, err := scanUser(r.pool.QueryRow(ctx, q, id))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrUserNotFound
|
||||
@@ -49,10 +79,112 @@ func (r *UserRepo) GetByEmail(ctx context.Context, email string) (*domain.User,
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// GetByEmail mirrors GetByID — soft-deleted users vanish from email
|
||||
// lookups (so signup/login don't match a tombstoned record).
|
||||
func (r *UserRepo) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
const q = `SELECT ` + userColumns + ` FROM users WHERE email = $1 AND deleted_at IS NULL`
|
||||
u, err := scanUser(r.pool.QueryRow(ctx, q, normaliseEmail(email)))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// SoftDelete marks the user as deleted and clears their PII-bearing
|
||||
// fields. A nightly cron (TBD in ops) will hard-delete rows older than
|
||||
// 30 days. Until then the row exists for audit + recovery if the user
|
||||
// changes their mind.
|
||||
func (r *UserRepo) SoftDelete(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE users SET
|
||||
deleted_at = now(),
|
||||
updated_at = now(),
|
||||
-- Tombstone PII so the soft-deleted row can sit for 30 days
|
||||
-- without holding the user's real email + name in cleartext.
|
||||
-- The original values are gone from the API surface from the
|
||||
-- moment SoftDelete returns.
|
||||
email = 'deleted-' || id::text || '@deleted.local',
|
||||
name = 'Deleted user',
|
||||
password_hash = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AcceptTerms records that the user has consented to the current terms
|
||||
// of service and privacy policy. Idempotent — re-accepting just resets
|
||||
// the timestamp.
|
||||
func (r *UserRepo) AcceptTerms(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE users SET
|
||||
terms_accepted_at = now(),
|
||||
privacy_policy_accepted_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepo) MarkEmailVerified(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE users
|
||||
SET email_verified = TRUE,
|
||||
email_verified_at = COALESCE(email_verified_at, now()),
|
||||
updated_at = now()
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepo) UpdatePasswordHash(ctx context.Context, id uuid.UUID, hash string) error {
|
||||
tag, err := r.pool.Exec(ctx, `
|
||||
UPDATE users SET password_hash = $2, updated_at = now() WHERE id = $1
|
||||
`, id, hash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanUser(s rowScanner) (*domain.User, error) {
|
||||
var u domain.User
|
||||
if err := s.Scan(&u.ID, &u.Email, &u.Name, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
if err := s.Scan(
|
||||
&u.ID, &u.Email, &u.Name,
|
||||
&u.PasswordHash,
|
||||
&u.EmailVerified, &u.EmailVerifiedAt,
|
||||
&u.DeletedAt,
|
||||
&u.TermsAcceptedAt, &u.PrivacyPolicyAcceptedAt,
|
||||
&u.CreatedAt, &u.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func normaliseEmail(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user