feat: ship Tier 1 — auth, authz, rate limits, real notifications, CSV import, billing, backups/DR, privacy

Closes every block in docs/TIER1_PLAN.md from the Claude-scope side. The
homelab / cloud setup steps (SES verification, restore drill, lawyer-
drafted ToS) remain operator-owned but are unblocked.

Block A — Authentication
- Migration 0003: password_hash, email_verified, email_verification_tokens,
  password_reset_tokens, refresh_tokens (with replaced_by family chain).
- Bcrypt hasher, HS256 JWT signer, single-use refresh tokens with rotation
  + replay-detection (revokes the family on reuse).
- /auth/signup, /login, /refresh, /logout, /verify-email,
  /forgot-password, /reset-password — enumeration-safe.
- requireAuth middleware + GET /me.
- Frontend useAuth/useApi with auto-refresh-on-401, login/signup/verify/
  forgot/reset pages, route-guard middleware.

Block B — Authorisation
- EventRepo.GetForHost; Update/Delete scoped by host_id.
- All host routes behind requireAuth + ownership; cross-tenant returns
  404 (no enumeration). ?host_id removed.
- WS auth via short-lived single-use tickets (POST /auth/ws-ticket).
- Tests: TestCrossTenantIsolation — 9 probes.

Block C — Rate limiting
- Redis sliding-window via Lua (atomic ZADD+ZCARD+PEXPIRE).
- Per-route limits matching the plan (signup IP, login IP+email, RSVP/
  access by token, events/guests/tokens by user_id).
- 429 with Retry-After header and JSON body.
- Auth lockout: 5 failed logins → account locked, only password reset
  clears it.
- Frontend: useErrMessage normalises 429 + locked messaging.

Block D — Real notifications
- Migration 0004: provider_message_id, bounce_type, complained columns
  + unsubscribes (CITEXT) suppression table.
- Branded HTML + plaintext templates for verification, reset, invitation,
  confirmation, reminder. Per-page templates avoid html/template's
  contextual-escape collisions.
- Senders: SESv2, Twilio (SMS), SMTP (Mailpit-friendly), Resend HTTP.
- PickEmailSender priority Resend > SMTP > SES > Log — system boots
  cleanly in dev with Mailpit; production flips one env var.
- Webhook endpoints (Twilio status + SES SNS) — bounces add to suppression;
  signature verification stubbed pending creds.
- Auto-send: POST /tokens publishes invitation.send; notifier renders +
  delivers via the configured backend; suppression list honoured.
- Bulk + per-row invitation flow: POST /events/{id}/guests/invitations/bulk
  returns per-guest tokens so phone-only guests can be SMS'd manually.
- Unsubscribe: signed HMAC token (no TTL) + /unsubscribe/[token] page.
- WhatsApp Option A+: wa.me click-to-chat wizard with per-guest progress
  tracking, isLikelyE164 validation, edit-from-wizard.
- Token rotate (POST /tokens/rotate) invalidates the old URL — used by
  the regenerate-link flow.
- Mailpit added to docker-compose for dev inbox.

Block E — CSV import
- Streaming parser: tolerant header detection, UTF-8 BOM + UTF-16 LE/BE
  decoding, row-level validation, 5,000-row cap.
- Strict E.164 phone validation with helpful error message.
- POST /preview + /import + GET /template; preview UI on event page;
  atomic per-batch with dedup on existing emails.

Phone capture across UI
- PhoneInput component: country picker (~50 ISO codes) + national input +
  live E.164 preview + inline length validation.
- Used in Add Guest and Edit Guest modals. Smart paste-handling extracts
  country code from full E.164 strings.

Block F — Billing (Stripe)
- Migration 0005: subscriptions table (user_id → tier/status/period_end +
  Stripe customer/sub ids). Partial unique index keeps one granting sub
  per user.
- internal/billing: Tier + Limits model (Free 1/50, Pro 10/1000, Business
  ∞/5000), Stripe SDK wrapper with IgnoreAPIVersionMismatch for newer
  account API versions.
- /billing/checkout-session, /billing/portal, /billing/status,
  /webhooks/stripe (signature-verified, lifecycle events).
- Tier enforcement: 402 on POST /events, /guests, /import with
  {error, reason, tier, used, limit, upgrade_url} body.
- Frontend: useBilling composable, /dashboard/billing page (current plan,
  usage bars, tier cards), global UpgradeModal triggered by useApi's
  402 interceptor.
- Customer portal kept for self-service cancel/payment-method changes.

Block G — Backups & DR (application side)
- Every migration has a tested .down.sql.
- TestMigrationRoundtrip applies all ups → all downs → all ups against a
  fresh container; catches asymmetric down migrations.
- cmd/restore-verify: 28-check post-restore invariant tool (schema
  presence, no orphans across 10 FK relationships, email uniqueness,
  single-active subscription, row-count snapshot).
- docs/RUNBOOK_RESTORE.md: 9-step restore procedure with RTO/RPO
  targets, drill instructions, rollback path.

Block H — Privacy compliance (application side)
- Migration 0006: deleted_at + terms_accepted_at + privacy_policy_accepted_at
  on users. Partial index on email for live-only uniqueness.
- GET /me/data-export — synchronous JSON dump (user, events, guests,
  tokens, rsvps, access_logs, notifications).
- DELETE /me — soft-delete with PII scrub + refresh-token revocation;
  re-signup with same email works.
- POST /me/accept-terms — idempotent consent recording.
- Frontend /privacy + /terms placeholder pages with substantive (pending
  legal review) copy; footer links; signup terms checkbox; TermsGateModal
  for accounts created before the rollout; export + delete buttons on
  /dashboard/billing.

Tests
- All migrations verified up/down/up.
- Integration suite: TestE2EHappyPath, TestAuthFlow, TestCrossTenantIsolation,
  TestRateLimitSignup, TestLoginLockout, TestUnsubscribeFlow,
  TestSESBounceWebhook, TestTwilioStatusWebhook, TestCsvImportFlow,
  TestCsvImportAtomicRollback, TestBulkIssueInvitations, TestBulkIssueExplicitSubset,
  TestTokenIssuePublishesInvitation, TestTokenIssueWithoutGuestEmailSkipsInvitation,
  TestGuestUpdate, TestGuestDelete, TestTokenRotate, TestSMTPSenderAgainstMailpit,
  TestFreeTierEventLimit, TestFreeTierGuestLimit, TestBusinessTierBypassesLimits,
  TestDataExport, TestDeleteMe, TestAcceptTerms, TestMigrationRoundtrip.
  Full suite runs in ~120s against real Postgres + NATS + Redis + Mailpit.
- Unit suite green across internal/auth, internal/csvimport,
  internal/notification, internal/ratelimit, internal/domain.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Kwaku Danso
2026-05-16 23:54:22 +01:00
parent a0ed34f860
commit 59b8781659
124 changed files with 13702 additions and 445 deletions
+5 -8
View File
@@ -1,12 +1,10 @@
package api
import (
"errors"
"net/http"
"sort"
"time"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
@@ -42,16 +40,15 @@ type activityItem struct {
// for an event, sorted newest first. Frontends use this on dashboard mount
// to backfill the live monitor with history.
func (h *activityHandler) list(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, err := h.events.Get(r.Context(), eventID); err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load event")
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
+557
View File
@@ -0,0 +1,557 @@
package api
import (
"context"
"encoding/json"
"errors"
"log/slog"
"net/http"
"net/mail"
"net/url"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/auth"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/ratelimit"
"github.com/alchemistkay/guestguard/internal/storage"
)
const refreshCookieName = "gg_refresh"
type authHandler struct {
logger *slog.Logger
users *storage.UserRepo
verifications *storage.EmailVerificationRepo
resets *storage.PasswordResetRepo
refreshes *storage.RefreshTokenRepo
hasher *auth.PasswordHasher
signer *auth.JWTSigner
emails auth.EmailSender
lockout *auth.LockoutTracker
limiter *ratelimit.Limiter
publicBaseURL string
emailVerificationTTL time.Duration
passwordResetTTL time.Duration
refreshTTL time.Duration
cookieDomain string
cookieSecure bool
}
type authHandlerDeps struct {
Logger *slog.Logger
Users *storage.UserRepo
Verifications *storage.EmailVerificationRepo
Resets *storage.PasswordResetRepo
Refreshes *storage.RefreshTokenRepo
Hasher *auth.PasswordHasher
Signer *auth.JWTSigner
Emails auth.EmailSender
Lockout *auth.LockoutTracker
Limiter *ratelimit.Limiter
PublicBaseURL string
EmailVerificationTTL time.Duration
PasswordResetTTL time.Duration
RefreshTTL time.Duration
CookieDomain string
CookieSecure bool
}
func newAuthHandler(d authHandlerDeps) *authHandler {
return &authHandler{
logger: d.Logger,
users: d.Users,
verifications: d.Verifications,
resets: d.Resets,
refreshes: d.Refreshes,
hasher: d.Hasher,
signer: d.Signer,
emails: d.Emails,
lockout: d.Lockout,
limiter: d.Limiter,
publicBaseURL: strings.TrimRight(d.PublicBaseURL, "/"),
emailVerificationTTL: d.EmailVerificationTTL,
passwordResetTTL: d.PasswordResetTTL,
refreshTTL: d.RefreshTTL,
cookieDomain: d.CookieDomain,
cookieSecure: d.CookieSecure,
}
}
// --- request/response types ---
type signupRequest struct {
Email string `json:"email"`
Name string `json:"name"`
Password string `json:"password"`
AcceptTerms bool `json:"accept_terms"`
}
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
type verifyEmailRequest struct {
Token string `json:"token"`
}
type forgotPasswordRequest struct {
Email string `json:"email"`
}
type resetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type authSuccess struct {
AccessToken string `json:"access_token"`
ExpiresAt time.Time `json:"expires_at"`
User *domain.User `json:"user"`
}
// --- handlers ---
// POST /auth/signup
func (h *authHandler) signup(w http.ResponseWriter, r *http.Request) {
var req signupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if _, err := mail.ParseAddress(req.Email); err != nil {
writeError(w, http.StatusBadRequest, "email is invalid")
return
}
if strings.TrimSpace(req.Name) == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
hash, err := h.hasher.Hash(req.Password)
if err != nil {
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
writeError(w, http.StatusBadRequest, err.Error())
return
}
h.logger.Error("hash password", "err", err)
writeError(w, http.StatusInternalServerError, "failed to create user")
return
}
u, err := h.users.Create(r.Context(), storage.CreateUserParams{
Email: req.Email,
Name: req.Name,
PasswordHash: hash,
AcceptTerms: req.AcceptTerms,
})
if err != nil {
if errors.Is(err, domain.ErrEmailTaken) {
// Don't leak which addresses are registered. Still return 201 and
// trigger a "if-you-already-have-an-account" email asynchronously
// (skipped for the stub). On real auth this should send a "you
// tried to sign up again, here's a reset link" email.
h.logger.Info("signup attempted with existing email", "email", req.Email)
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
return
}
h.logger.Error("create user", "err", err)
writeError(w, http.StatusInternalServerError, "failed to create user")
return
}
if err := h.sendVerificationEmail(r.Context(), u); err != nil {
h.logger.Error("send verification email", "err", err, "user_id", u.ID)
// Don't fail the signup — user can request a resend.
}
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
}
// POST /auth/login
func (h *authHandler) login(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if req.Email == "" || req.Password == "" {
writeError(w, http.StatusBadRequest, "email and password required")
return
}
// Per-(IP + email) sliding-window — 10 per 5 minutes per the plan.
if !h.checkRate(w, r, "login", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
10, 5*time.Minute) {
return
}
u, err := h.users.GetByEmail(r.Context(), req.Email)
if err != nil || u.PasswordHash == "" {
_, _ = h.lockout.RecordFailure(r.Context(), req.Email, nil)
writeError(w, http.StatusUnauthorized, "invalid email or password")
return
}
// If the account is already locked, reject before doing a bcrypt compare.
locked, _ := h.lockout.IsLocked(r.Context(), u.ID)
if locked {
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
return
}
if err := h.hasher.Verify(u.PasswordHash, req.Password); err != nil {
locked, _ := h.lockout.RecordFailure(r.Context(), req.Email, &u.ID)
if locked {
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
return
}
writeError(w, http.StatusUnauthorized, "invalid email or password")
return
}
if !u.EmailVerified {
writeError(w, http.StatusForbidden, "email not verified")
return
}
h.lockout.ClearOnSuccess(r.Context(), req.Email)
if err := h.issueSession(w, r, u); err != nil {
h.logger.Error("issue session", "err", err, "user_id", u.ID)
writeError(w, http.StatusInternalServerError, "failed to start session")
return
}
}
// checkRate consults the limiter (when one is configured) and writes a 429
// response if the budget is exhausted. Returns false if the caller should
// stop handling the request.
func (h *authHandler) checkRate(w http.ResponseWriter, r *http.Request, name, key string, limit int, window time.Duration) bool {
if h.limiter == nil || key == "" {
return true
}
res, err := h.limiter.Allow(r.Context(), name, key, limit, window)
if err != nil {
h.logger.Warn("ratelimit error (failing open)", "rule", name, "err", err)
return true
}
if !res.Allowed {
retry := int(res.RetryAfter.Round(time.Second).Seconds())
if retry < 1 {
retry = 1
}
w.Header().Set("Retry-After", strconv.Itoa(retry))
writeJSON(w, http.StatusTooManyRequests, map[string]any{
"error": "rate limit exceeded",
"retry_after": retry,
})
return false
}
return true
}
// POST /auth/refresh
func (h *authHandler) refresh(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(refreshCookieName)
if err != nil || cookie.Value == "" {
writeError(w, http.StatusUnauthorized, "missing refresh token")
return
}
oldHash := auth.HashOpaque(cookie.Value)
existing, err := h.refreshes.Get(r.Context(), oldHash)
if err != nil {
if errors.Is(err, domain.ErrAuthTokenNotFound) {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "invalid refresh token")
return
}
h.logger.Error("lookup refresh", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
if existing.RevokedAt != nil {
// Replay of a revoked token. Revoke the family.
_ = h.refreshes.RevokeAllForUser(r.Context(), existing.UserID)
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "refresh token reused")
return
}
if time.Now().After(existing.ExpiresAt) {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "refresh token expired")
return
}
u, err := h.users.GetByID(r.Context(), existing.UserID)
if err != nil {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "user not found")
return
}
newRaw, newHash, err := auth.NewOpaqueToken()
if err != nil {
h.logger.Error("mint refresh", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
exp := time.Now().Add(h.refreshTTL)
if err := h.refreshes.Rotate(r.Context(), oldHash, storage.CreateRefreshTokenParams{
Hash: newHash,
UserID: u.ID,
ExpiresAt: exp,
UserAgent: r.UserAgent(),
IPAddress: clientIP(r),
}); err != nil {
if errors.Is(err, domain.ErrRefreshTokenRevoked) {
h.clearRefreshCookie(w)
writeError(w, http.StatusUnauthorized, "refresh token reused")
return
}
h.logger.Error("rotate refresh", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
if err != nil {
h.logger.Error("sign access", "err", err)
writeError(w, http.StatusInternalServerError, "refresh failed")
return
}
h.setRefreshCookie(w, newRaw, exp)
writeJSON(w, http.StatusOK, authSuccess{
AccessToken: access,
ExpiresAt: accessExp,
User: u,
})
}
// POST /auth/logout
func (h *authHandler) logout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(refreshCookieName)
if err == nil && cookie.Value != "" {
_ = h.refreshes.Revoke(r.Context(), auth.HashOpaque(cookie.Value))
}
h.clearRefreshCookie(w)
w.WriteHeader(http.StatusNoContent)
}
// POST /auth/verify-email
func (h *authHandler) verifyEmail(w http.ResponseWriter, r *http.Request) {
var req verifyEmailRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
writeError(w, http.StatusBadRequest, "token required")
return
}
uid, err := h.verifications.Consume(r.Context(), auth.HashOpaque(req.Token))
if err != nil {
switch {
case errors.Is(err, domain.ErrAuthTokenNotFound):
writeError(w, http.StatusBadRequest, "invalid token")
case errors.Is(err, domain.ErrAuthTokenConsumed):
writeError(w, http.StatusBadRequest, "token already used")
case errors.Is(err, domain.ErrAuthTokenExpired):
writeError(w, http.StatusBadRequest, "token expired")
default:
h.logger.Error("consume verification", "err", err)
writeError(w, http.StatusInternalServerError, "verification failed")
}
return
}
if err := h.users.MarkEmailVerified(r.Context(), uid); err != nil {
h.logger.Error("mark verified", "err", err, "user_id", uid)
writeError(w, http.StatusInternalServerError, "verification failed")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "verified"})
}
// POST /auth/forgot-password
func (h *authHandler) forgotPassword(w http.ResponseWriter, r *http.Request) {
var req forgotPasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if !h.checkRate(w, r, "forgot_password", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
3, time.Hour) {
return
}
// Always respond 202 to avoid leaking whether the email exists.
defer func() { writeJSON(w, http.StatusAccepted, map[string]string{"status": "if_known_email_sent"}) }()
u, err := h.users.GetByEmail(r.Context(), req.Email)
if err != nil {
return
}
raw, hash, err := auth.NewOpaqueToken()
if err != nil {
h.logger.Error("mint reset", "err", err)
return
}
exp := time.Now().Add(h.passwordResetTTL)
if err := h.resets.Create(r.Context(), u.ID, hash, exp); err != nil {
h.logger.Error("persist reset", "err", err)
return
}
link := h.publicBaseURL + "/reset-password/" + url.PathEscape(raw)
if err := h.emails.SendPasswordReset(r.Context(), u.Email, u.Name, link); err != nil {
h.logger.Error("send reset email", "err", err)
}
}
// POST /auth/reset-password
func (h *authHandler) resetPassword(w http.ResponseWriter, r *http.Request) {
var req resetPasswordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
writeError(w, http.StatusBadRequest, "token and new_password required")
return
}
newHash, err := h.hasher.Hash(req.NewPassword)
if err != nil {
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
writeError(w, http.StatusBadRequest, err.Error())
return
}
h.logger.Error("hash password", "err", err)
writeError(w, http.StatusInternalServerError, "reset failed")
return
}
uid, err := h.resets.Consume(r.Context(), auth.HashOpaque(req.Token))
if err != nil {
switch {
case errors.Is(err, domain.ErrAuthTokenNotFound):
writeError(w, http.StatusBadRequest, "invalid token")
case errors.Is(err, domain.ErrAuthTokenConsumed):
writeError(w, http.StatusBadRequest, "token already used")
case errors.Is(err, domain.ErrAuthTokenExpired):
writeError(w, http.StatusBadRequest, "token expired")
default:
h.logger.Error("consume reset", "err", err)
writeError(w, http.StatusInternalServerError, "reset failed")
}
return
}
if err := h.users.UpdatePasswordHash(r.Context(), uid, newHash); err != nil {
h.logger.Error("update password", "err", err, "user_id", uid)
writeError(w, http.StatusInternalServerError, "reset failed")
return
}
// Invalidate all existing sessions.
_ = h.refreshes.RevokeAllForUser(r.Context(), uid)
// Resetting the password is the canonical "unlock" path for the
// account lockout that triggers after repeated bad-credential attempts.
if u, err := h.users.GetByID(r.Context(), uid); err == nil {
_ = h.lockout.ClearForUser(r.Context(), uid, u.Email)
}
writeJSON(w, http.StatusOK, map[string]string{"status": "password_reset"})
}
// --- helpers ---
func (h *authHandler) sendVerificationEmail(ctx context.Context, u *domain.User) error {
raw, hash, err := auth.NewOpaqueToken()
if err != nil {
return err
}
if err := h.verifications.Create(ctx, u.ID, hash, time.Now().Add(h.emailVerificationTTL)); err != nil {
return err
}
link := h.publicBaseURL + "/verify-email?token=" + url.QueryEscape(raw)
return h.emails.SendVerification(ctx, u.Email, u.Name, link)
}
func (h *authHandler) issueSession(w http.ResponseWriter, r *http.Request, u *domain.User) error {
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
if err != nil {
return err
}
raw, hash, err := auth.NewOpaqueToken()
if err != nil {
return err
}
refreshExp := time.Now().Add(h.refreshTTL)
if err := h.refreshes.Create(r.Context(), storage.CreateRefreshTokenParams{
Hash: hash,
UserID: u.ID,
ExpiresAt: refreshExp,
UserAgent: r.UserAgent(),
IPAddress: clientIP(r),
}); err != nil {
return err
}
h.setRefreshCookie(w, raw, refreshExp)
writeJSON(w, http.StatusOK, authSuccess{
AccessToken: access,
ExpiresAt: accessExp,
User: u,
})
return nil
}
func (h *authHandler) setRefreshCookie(w http.ResponseWriter, value string, expires time.Time) {
http.SetCookie(w, &http.Cookie{
Name: refreshCookieName,
Value: value,
Path: "/auth",
Domain: h.cookieDomain,
Expires: expires,
MaxAge: int(time.Until(expires).Seconds()),
HttpOnly: true,
Secure: h.cookieSecure,
SameSite: http.SameSiteLaxMode,
})
}
func (h *authHandler) clearRefreshCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: refreshCookieName,
Value: "",
Path: "/auth",
Domain: h.cookieDomain,
MaxAge: -1,
HttpOnly: true,
Secure: h.cookieSecure,
SameSite: http.SameSiteLaxMode,
})
}
// --- requireAuth middleware ---
type ctxKey int
const userIDCtxKey ctxKey = iota
func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) {
v, ok := ctx.Value(userIDCtxKey).(uuid.UUID)
return v, ok
}
func requireAuth(signer *auth.JWTSigner) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := r.Header.Get("Authorization")
if !strings.HasPrefix(h, "Bearer ") {
writeError(w, http.StatusUnauthorized, "missing bearer token")
return
}
raw := strings.TrimSpace(strings.TrimPrefix(h, "Bearer "))
claims, err := signer.Parse(raw)
if err != nil {
if errors.Is(err, auth.ErrExpiredJWT) {
writeError(w, http.StatusUnauthorized, "token expired")
return
}
writeError(w, http.StatusUnauthorized, "invalid token")
return
}
ctx := context.WithValue(r.Context(), userIDCtxKey, claims.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
+44
View File
@@ -0,0 +1,44 @@
package api
import (
"errors"
"net/http"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
// hostFromContext returns the authed user's id, or writes 401 and returns
// false. Used by host-facing handlers as the first line in the function.
func hostFromContext(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
uid, ok := UserIDFromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthenticated")
return uuid.Nil, false
}
return uid, true
}
// requireEventOwner fetches the event and confirms the authed user owns it.
// On mismatch (or missing event) it returns 404 — never 403 — so a cross-
// tenant probe cannot tell the difference between "event doesn't exist" and
// "exists but belongs to someone else".
func requireEventOwner(
w http.ResponseWriter,
r *http.Request,
events *storage.EventRepo,
eventID, hostID uuid.UUID,
) (*domain.Event, bool) {
ev, err := events.GetForHost(r.Context(), eventID, hostID)
if err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return nil, false
}
writeError(w, http.StatusInternalServerError, "failed to load event")
return nil, false
}
return ev, true
}
+206
View File
@@ -0,0 +1,206 @@
package api
import (
"encoding/json"
"errors"
"log/slog"
"net/http"
"strings"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/storage"
)
type billingHandler struct {
logger *slog.Logger
stripe *billing.Client
users *storage.UserRepo
subscriptions *storage.SubscriptionRepo
publicBaseURL string
}
type checkoutSessionRequest struct {
Tier string `json:"tier"`
}
type checkoutSessionResponse struct {
URL string `json:"url"`
}
// POST /billing/checkout-session — returns the Stripe Checkout URL the
// frontend redirects the host to. Mints a Stripe customer on first use
// and persists it so repeat calls reuse the same customer.
func (h *billingHandler) checkoutSession(w http.ResponseWriter, r *http.Request) {
if !h.stripeEnabled(w) {
return
}
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
var req checkoutSessionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
tier := billing.Tier(strings.ToLower(req.Tier))
if tier != billing.TierPro && tier != billing.TierBusiness {
writeError(w, http.StatusBadRequest, "tier must be 'pro' or 'business'")
return
}
price, err := h.stripe.PriceFor(tier)
if err != nil {
writeError(w, http.StatusServiceUnavailable, "this tier is not configured yet — contact support")
return
}
user, err := h.users.GetByID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
existingCustomerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load billing record")
return
}
customerID, err := h.stripe.CreateOrGetCustomer(hostID.String(), user.Email, user.Name, existingCustomerID)
if err != nil {
h.logger.Error("stripe customer", "err", err)
writeError(w, http.StatusBadGateway, "stripe customer error")
return
}
if existingCustomerID == "" {
// First time — write a placeholder row so the customer id sticks.
if _, err := h.subscriptions.Upsert(r.Context(), storage.UpsertParams{
UserID: hostID,
StripeCustomerID: customerID,
}); err != nil {
h.logger.Error("upsert sub placeholder", "err", err)
}
}
base := strings.TrimRight(h.publicBaseURL, "/")
url, err := h.stripe.CreateCheckoutSession(billing.CheckoutSessionParams{
CustomerID: customerID,
PriceID: price,
SuccessURL: base + "/dashboard?billing=success",
CancelURL: base + "/dashboard?billing=cancelled",
})
if err != nil {
h.logger.Error("stripe checkout session", "err", err)
writeError(w, http.StatusBadGateway, "stripe checkout error")
return
}
writeJSON(w, http.StatusOK, checkoutSessionResponse{URL: url})
}
type portalSessionResponse struct {
URL string `json:"url"`
}
// POST /billing/portal — returns the customer portal URL so the user
// can manage their payment method, view invoices, or cancel. 404 when
// the user has no Stripe customer yet (they're still on free).
func (h *billingHandler) portalSession(w http.ResponseWriter, r *http.Request) {
if !h.stripeEnabled(w) {
return
}
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
customerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load billing record")
return
}
if customerID == "" {
writeError(w, http.StatusNotFound, "no billing account yet — subscribe first")
return
}
url, err := h.stripe.CreatePortalSession(customerID, strings.TrimRight(h.publicBaseURL, "/")+"/dashboard")
if err != nil {
h.logger.Error("stripe portal", "err", err)
writeError(w, http.StatusBadGateway, "stripe portal error")
return
}
writeJSON(w, http.StatusOK, portalSessionResponse{URL: url})
}
type subscriptionStatusResponse struct {
Tier string `json:"tier"`
Status string `json:"status"`
CurrentPeriodEnd string `json:"current_period_end,omitempty"`
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
Limits struct {
EventsPerMonth int `json:"events_per_month"`
GuestsPerEvent int `json:"guests_per_event"`
} `json:"limits"`
Usage struct {
EventsThisMonth int `json:"events_this_month"`
} `json:"usage"`
PortalAvailable bool `json:"portal_available"`
}
// GET /billing/status — returns the host's current tier + limits +
// usage. The frontend uses this to render the billing page and the
// 402-modal copy ("you used X of Y events this month").
func (h *billingHandler) status(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
tier := billing.TierFree
status := "active"
var periodEnd string
cancelAtPeriodEnd := false
portalAvailable := false
if h.subscriptions != nil {
sub, err := h.subscriptions.GetActiveByUser(r.Context(), hostID)
switch {
case err == nil:
tier = billing.Tier(sub.Tier)
status = sub.Status
cancelAtPeriodEnd = sub.CancelAtPeriodEnd
if sub.CurrentPeriodEnd != nil {
periodEnd = sub.CurrentPeriodEnd.Format("2006-01-02T15:04:05Z")
}
portalAvailable = sub.StripeCustomerID != ""
case errors.Is(err, storage.ErrSubscriptionNotFound):
// Free tier — leave defaults.
default:
writeError(w, http.StatusInternalServerError, "failed to load subscription")
return
}
}
limits := billing.LimitsFor(tier)
events, _ := h.subscriptions.CountEventsInCurrentMonth(r.Context(), hostID)
resp := subscriptionStatusResponse{
Tier: string(tier),
Status: status,
CurrentPeriodEnd: periodEnd,
CancelAtPeriodEnd: cancelAtPeriodEnd,
PortalAvailable: portalAvailable,
}
resp.Limits.EventsPerMonth = limits.EventsPerMonth
resp.Limits.GuestsPerEvent = limits.GuestsPerEvent
resp.Usage.EventsThisMonth = events
writeJSON(w, http.StatusOK, resp)
}
// stripeEnabled returns true if the billing client is configured, else
// writes 503 and returns false. The /billing/status path skips this so
// the frontend can render a "free tier" page in dev environments.
func (h *billingHandler) stripeEnabled(w http.ResponseWriter) bool {
if h.stripe == nil || !h.stripe.Enabled() {
writeError(w, http.StatusServiceUnavailable, "billing is not configured on this instance")
return false
}
return true
}
+157
View File
@@ -0,0 +1,157 @@
package api
import (
"context"
"errors"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/storage"
)
// tierEnforcer wraps the SubscriptionRepo with policy decisions. Lives
// here (not in storage) because the policy is HTTP-shaped: we map the
// outcome to 402 + an upgrade URL.
type tierEnforcer struct {
subs *storage.SubscriptionRepo
publicBaseURL string
}
func newTierEnforcer(subs *storage.SubscriptionRepo, publicBaseURL string) *tierEnforcer {
return &tierEnforcer{subs: subs, publicBaseURL: publicBaseURL}
}
// currentTier returns the host's effective tier. ErrSubscriptionNotFound
// means "no granting subscription on file" → free. Other DB errors
// bubble up.
func (e *tierEnforcer) currentTier(ctx context.Context, hostID uuid.UUID) (billing.Tier, error) {
if e == nil || e.subs == nil {
return billing.TierFree, nil
}
sub, err := e.subs.GetActiveByUser(ctx, hostID)
if err != nil {
if errors.Is(err, storage.ErrSubscriptionNotFound) {
return billing.TierFree, nil
}
return "", err
}
if !billing.StatusGrantsAccess(sub.Status) {
return billing.TierFree, nil
}
tier := billing.Tier(sub.Tier)
if !tier.Valid() {
return billing.TierFree, nil
}
return tier, nil
}
// allowEventCreate verifies the host's monthly event budget. Returns
// true when the request may proceed. On denial it writes a 402 with the
// upgrade hint and returns false.
func (e *tierEnforcer) allowEventCreate(w http.ResponseWriter, r *http.Request, hostID uuid.UUID) bool {
if e == nil || e.subs == nil {
return true
}
tier, err := e.currentTier(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to check plan")
return false
}
limit := billing.LimitsFor(tier).EventsPerMonth
if limit < 0 {
return true
}
used, err := e.subs.CountEventsInCurrentMonth(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to count events")
return false
}
if used >= limit {
e.writePaymentRequired(w, "events_per_month", tier, used, limit,
"You've reached your monthly event limit on the "+strings.ToUpper(string(tier))+" plan.")
return false
}
return true
}
// allowGuestCreate verifies the per-event guest cap. Same shape as
// allowEventCreate.
func (e *tierEnforcer) allowGuestCreate(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID) bool {
if e == nil || e.subs == nil {
return true
}
tier, err := e.currentTier(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to check plan")
return false
}
limit := billing.LimitsFor(tier).GuestsPerEvent
if limit < 0 {
return true
}
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to count guests")
return false
}
if used >= limit {
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
"This event has reached the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
return false
}
return true
}
// allowGuestImport is the CSV-import variant: check the cap against
// existing + incoming row count up-front, before we start the
// transaction. Dedup may shrink the actual insert count later — that's
// OK, we just stay on the safe side.
func (e *tierEnforcer) allowGuestImport(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID, incoming int) bool {
if e == nil || e.subs == nil || incoming == 0 {
return true
}
tier, err := e.currentTier(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to check plan")
return false
}
limit := billing.LimitsFor(tier).GuestsPerEvent
if limit < 0 {
return true
}
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to count guests")
return false
}
if used+incoming > limit {
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
"This import would exceed the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
return false
}
return true
}
type paymentRequiredBody struct {
Error string `json:"error"`
Reason string `json:"reason"`
Tier string `json:"tier"`
Used int `json:"used"`
Limit int `json:"limit"`
UpgradeURL string `json:"upgrade_url"`
}
func (e *tierEnforcer) writePaymentRequired(w http.ResponseWriter, reason string, tier billing.Tier, used, limit int, msg string) {
body := paymentRequiredBody{
Error: msg,
Reason: reason,
Tier: string(tier),
Used: used,
Limit: limit,
UpgradeURL: strings.TrimRight(e.publicBaseURL, "/") + "/dashboard/billing",
}
writeJSON(w, http.StatusPaymentRequired, body)
}
+177
View File
@@ -0,0 +1,177 @@
package api
import (
"errors"
"io"
"net/http"
"github.com/alchemistkay/guestguard/internal/csvimport"
"github.com/alchemistkay/guestguard/internal/storage"
)
const (
// 1 MB cap on uploads. With ~200 bytes per row that's ~5,000 guests —
// matches the row cap in csvimport.DefaultMaxRows.
csvMaxBytes = 1 << 20
)
type csvImportHandler struct {
guests *storage.GuestRepo
events *storage.EventRepo
enforcer *tierEnforcer
}
type importResponse struct {
Added int `json:"added"`
Skipped int `json:"skipped"`
SkippedEmails []string `json:"skipped_emails,omitempty"`
Errors []csvimport.RowError `json:"errors,omitempty"`
TotalCount int `json:"total_count"`
}
type previewResponse struct {
Rows []csvimport.Row `json:"rows"`
Errors []csvimport.RowError `json:"errors,omitempty"`
TotalCount int `json:"total_count"`
}
// POST /events/{id}/guests/import/preview — parse + validate but don't write.
// Used by the frontend to show a "is this what you meant?" table before commit.
func (h *csvImportHandler) preview(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
body, ok := readCSVUpload(w, r)
if !ok {
return
}
defer body.Close()
res, err := csvimport.Parse(body, csvimport.Options{})
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
writeJSON(w, http.StatusOK, previewResponse{
Rows: res.Rows,
Errors: res.Errors,
TotalCount: res.TotalCount,
})
}
// POST /events/{id}/guests/import — parse, validate, and commit valid rows
// in a single transaction. Rows with row-level errors are reported back
// but don't prevent the rest from importing.
func (h *csvImportHandler) commit(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
body, ok := readCSVUpload(w, r)
if !ok {
return
}
defer body.Close()
parsed, err := csvimport.Parse(body, csvimport.Options{})
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
// Plan enforcement: prevent an import that, even if perfectly
// dedup-free, would exceed the per-event guest cap. We check upfront
// (current count + parsed-row count) and reject with 402 if it'd
// overflow. False positives on dedup-heavy CSVs are acceptable —
// host can dedupe and re-upload.
if !h.enforcer.allowGuestImport(w, r, hostID, eventID, len(parsed.Rows)) {
return
}
rows := make([]storage.BulkImportRow, 0, len(parsed.Rows))
for _, r := range parsed.Rows {
rows = append(rows, storage.BulkImportRow{
Name: r.Name,
Email: r.Email,
Phone: r.Phone,
PlusOnes: r.PlusOnes,
})
}
res, err := h.guests.BulkImportGuests(r.Context(), eventID, rows)
if err != nil {
writeError(w, http.StatusInternalServerError, "import failed")
return
}
writeJSON(w, http.StatusOK, importResponse{
Added: res.Added,
Skipped: res.Skipped,
SkippedEmails: res.SkippedEmails,
Errors: parsed.Errors,
TotalCount: parsed.TotalCount,
})
}
// GET /events/{id}/guests/import/template — download a sample CSV. Auth is
// applied at the route level; ownership is verified so an attacker can't
// probe the existence of an event by hitting this endpoint.
func (h *csvImportHandler) template(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-import-template.csv"`)
_, _ = w.Write([]byte(csvimport.TemplateCSV))
}
// readCSVUpload returns the multipart file body (capped at csvMaxBytes) or
// writes an error and returns (nil, false). Accepted shapes:
//
// - multipart/form-data with a "file" field (the drag-drop UI uses this)
// - any other Content-Type — the raw body is treated as the CSV (curl-friendly)
func readCSVUpload(w http.ResponseWriter, r *http.Request) (io.ReadCloser, bool) {
r.Body = http.MaxBytesReader(w, r.Body, csvMaxBytes)
if err := r.ParseMultipartForm(csvMaxBytes); err == nil && r.MultipartForm != nil {
files := r.MultipartForm.File["file"]
if len(files) == 0 {
writeError(w, http.StatusBadRequest, `form field "file" is required`)
return nil, false
}
f, err := files[0].Open()
if err != nil {
writeError(w, http.StatusBadRequest, "cannot read uploaded file")
return nil, false
}
return f, true
} else if err != nil && !errors.Is(err, http.ErrNotMultipart) {
writeError(w, http.StatusBadRequest, "invalid multipart body")
return nil, false
}
// Fall through: raw body as CSV.
return r.Body, true
}
+39 -28
View File
@@ -15,11 +15,11 @@ import (
)
type eventHandler struct {
repo *storage.EventRepo
repo *storage.EventRepo
enforcer *tierEnforcer
}
type createEventRequest struct {
HostID string `json:"host_id"`
Name string `json:"name"`
Slug string `json:"slug"`
EventDate time.Time `json:"event_date"`
@@ -32,6 +32,14 @@ type createEventRequest struct {
var slugRe = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
if !h.enforcer.allowEventCreate(w, r, hostID) {
return
}
var req createEventRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
@@ -51,12 +59,6 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
return
}
hostID, err := uuid.Parse(req.HostID)
if err != nil {
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
return
}
status := domain.EventStatus(req.Status)
if status == "" {
status = domain.EventStatusDraft
@@ -89,37 +91,31 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
}
func (h *eventHandler) get(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
id, ok := parseIDParam(w, r, "id")
if !ok {
return
}
ev, err := h.repo.Get(r.Context(), id)
if err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load event")
ev, ok := requireEventOwner(w, r, h.repo, id, hostID)
if !ok {
return
}
writeJSON(w, http.StatusOK, ev)
}
func (h *eventHandler) list(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
q := r.URL.Query()
limit := atoiOr(q.Get("limit"), 50)
offset := atoiOr(q.Get("offset"), 0)
var hostID uuid.UUID
if v := q.Get("host_id"); v != "" {
parsed, err := uuid.Parse(v)
if err != nil {
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
return
}
hostID = parsed
}
events, err := h.repo.List(r.Context(), hostID, limit, offset)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to list events")
@@ -146,6 +142,10 @@ type updateEventRequest struct {
}
func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
id, ok := parseIDParam(w, r, "id")
if !ok {
return
@@ -180,7 +180,7 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
params.Status = &s
}
ev, err := h.repo.Update(r.Context(), id, params)
ev, err := h.repo.Update(r.Context(), id, hostID, params)
if err != nil {
switch {
case errors.Is(err, domain.ErrEventNotFound):
@@ -196,11 +196,15 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
}
func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
id, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if err := h.repo.Delete(r.Context(), id); err != nil {
if err := h.repo.Delete(r.Context(), id, hostID); err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
@@ -212,7 +216,14 @@ func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
}
func parseIDParam(w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, bool) {
raw := r.PathValue(name)
return parseRawUUID(w, name, r.PathValue(name))
}
func parseRawUUID(w http.ResponseWriter, name, raw string) (uuid.UUID, bool) {
if raw == "" {
writeError(w, http.StatusBadRequest, name+" is required")
return uuid.Nil, false
}
id, err := uuid.Parse(raw)
if err != nil {
writeError(w, http.StatusBadRequest, name+" must be a valid uuid")
+107 -9
View File
@@ -10,8 +10,9 @@ import (
)
type guestHandler struct {
guests *storage.GuestRepo
events *storage.EventRepo
guests *storage.GuestRepo
events *storage.EventRepo
enforcer *tierEnforcer
}
type createGuestRequest struct {
@@ -24,16 +25,18 @@ type createGuestRequest struct {
}
func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, err := h.events.Get(r.Context(), eventID); err != nil {
if errors.Is(err, domain.ErrEventNotFound) {
writeError(w, http.StatusNotFound, "event not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load event")
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
if !h.enforcer.allowGuestCreate(w, r, hostID, eventID) {
return
}
@@ -67,11 +70,106 @@ func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusCreated, g)
}
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
type updateGuestRequest struct {
Name *string `json:"name"`
Email *string `json:"email"`
Phone *string `json:"phone"`
PlusOnes *int `json:"plus_ones"`
}
// PATCH /events/{id}/guests/{guest_id} — patch a guest's contact info.
// Fields omitted from the body are left untouched. Empty strings for
// email/phone clear those columns to NULL.
func (h *guestHandler) update(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
guestID, ok := parseIDParam(w, r, "guest_id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
var req updateGuestRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if req.Name != nil && *req.Name == "" {
writeError(w, http.StatusBadRequest, "name cannot be empty")
return
}
if req.PlusOnes != nil && *req.PlusOnes < 0 {
writeError(w, http.StatusBadRequest, "plus_ones must be >= 0")
return
}
g, err := h.guests.Update(r.Context(), eventID, guestID, storage.UpdateGuestParams{
Name: req.Name,
Email: req.Email,
Phone: req.Phone,
PlusOnes: req.PlusOnes,
})
if err != nil {
if errors.Is(err, domain.ErrGuestNotFound) {
writeError(w, http.StatusNotFound, "guest not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to update guest")
return
}
writeJSON(w, http.StatusOK, g)
}
// DELETE /events/{id}/guests/{guest_id} — remove a guest from an event.
// Cascade-deletes their token, rsvp, access logs, notifications.
func (h *guestHandler) delete(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
guestID, ok := parseIDParam(w, r, "guest_id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
if err := h.guests.Delete(r.Context(), eventID, guestID); err != nil {
if errors.Is(err, domain.ErrGuestNotFound) {
writeError(w, http.StatusNotFound, "guest not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to delete guest")
return
}
w.WriteHeader(http.StatusNoContent)
}
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
q := r.URL.Query()
limit := atoiOr(q.Get("limit"), 100)
offset := atoiOr(q.Get("offset"), 0)
+32
View File
@@ -0,0 +1,32 @@
package api
import (
"net/http"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
type meHandler struct {
users *storage.UserRepo
}
// GET /me — returns the authenticated user. Used by the frontend to bootstrap
// after a page reload (with a fresh access token from /auth/refresh).
func (h *meHandler) get(w http.ResponseWriter, r *http.Request) {
uid, ok := UserIDFromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthenticated")
return
}
u, err := h.users.GetByID(r.Context(), uid)
if err != nil {
if err == domain.ErrUserNotFound {
writeError(w, http.StatusUnauthorized, "user not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
writeJSON(w, http.StatusOK, u)
}
+255
View File
@@ -0,0 +1,255 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
"time"
"github.com/google/uuid"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
// privacyHandler holds the GDPR-style "your data, your choice" endpoints:
// data export, account deletion, and terms-acceptance recording.
type privacyHandler struct {
logger *slog.Logger
users *storage.UserRepo
events *storage.EventRepo
guests *storage.GuestRepo
tokens *storage.TokenRepo
rsvps *storage.RSVPRepo
access *storage.AccessLogRepo
notifs *storage.DB // raw pool access for the export queries
refresh *storage.RefreshTokenRepo
}
// DataExport is the shape of the JSON the host downloads from
// GET /me/data-export. We don't paginate or stream — for the scale
// GuestGuard hosts have, a single response is reasonable. If a host
// ever has 100k+ access logs we'll switch to async + email-a-link.
type DataExport struct {
ExportedAt time.Time `json:"exported_at"`
Format string `json:"format"`
User *domain.User `json:"user"`
Events []*domain.Event `json:"events"`
Guests []*domain.Guest `json:"guests"`
Tokens []exportedToken `json:"tokens"`
RSVPs []exportedRSVP `json:"rsvps"`
AccessLogs []exportedAccess `json:"access_logs"`
Notifs []exportedNotif `json:"notifications"`
}
type exportedToken struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
type exportedRSVP struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
Response string `json:"response"`
PlusOnes int `json:"plus_ones"`
SubmittedAt time.Time `json:"submitted_at"`
}
type exportedAccess struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
RiskScore *int `json:"risk_score,omitempty"`
Flagged bool `json:"flagged"`
CreatedAt time.Time `json:"created_at"`
}
type exportedNotif struct {
ID uuid.UUID `json:"id"`
GuestID uuid.UUID `json:"guest_id"`
Channel string `json:"channel"`
Type string `json:"type"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// GET /me/data-export — returns every record the system holds about the
// authenticated user. The Content-Disposition header makes browsers
// offer a download rather than rendering inline.
func (h *privacyHandler) dataExport(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
user, err := h.users.GetByID(r.Context(), hostID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
export := DataExport{
ExportedAt: time.Now().UTC(),
Format: "guestguard.v1",
User: user,
}
// Events the user hosts.
events, err := h.events.List(r.Context(), hostID, 1000, 0)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load events")
return
}
export.Events = events
// For each event, pull guests + tokens + rsvps + access_logs + notifications.
for _, ev := range events {
guests, err := h.guests.ListByEvent(r.Context(), ev.ID, 5000, 0)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load guests")
return
}
export.Guests = append(export.Guests, guests...)
for _, g := range guests {
// Token (at most one per guest, but query as a list for symmetry).
if err := h.appendTokens(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: tokens", "err", err)
}
if err := h.appendRSVPs(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: rsvps", "err", err)
}
if err := h.appendAccess(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: access", "err", err)
}
if err := h.appendNotifs(r.Context(), g.ID, &export); err != nil {
h.logger.Warn("export: notifs", "err", err)
}
}
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-data-export.json"`)
writeJSON(w, http.StatusOK, export)
}
func (h *privacyHandler) appendTokens(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, expires_at, status, created_at
FROM tokens WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var t exportedToken
if err := rows.Scan(&t.ID, &t.GuestID, &t.ExpiresAt, &t.Status, &t.CreatedAt); err != nil {
return err
}
out.Tokens = append(out.Tokens, t)
}
return rows.Err()
}
func (h *privacyHandler) appendRSVPs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, response::text, plus_ones, submitted_at
FROM rsvps WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var r exportedRSVP
if err := rows.Scan(&r.ID, &r.GuestID, &r.Response, &r.PlusOnes, &r.SubmittedAt); err != nil {
return err
}
out.RSVPs = append(out.RSVPs, r)
}
return rows.Err()
}
func (h *privacyHandler) appendAccess(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, risk_score, flagged, created_at
FROM access_logs WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var a exportedAccess
var rs *int
if err := rows.Scan(&a.ID, &a.GuestID, &rs, &a.Flagged, &a.CreatedAt); err != nil {
return err
}
a.RiskScore = rs
out.AccessLogs = append(out.AccessLogs, a)
}
return rows.Err()
}
func (h *privacyHandler) appendNotifs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
rows, err := h.notifs.Pool.Query(ctx, `
SELECT id, guest_id, channel::text, type::text, status::text, created_at
FROM notifications WHERE guest_id = $1
`, guestID)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var n exportedNotif
if err := rows.Scan(&n.ID, &n.GuestID, &n.Channel, &n.Type, &n.Status, &n.CreatedAt); err != nil {
return err
}
out.Notifs = append(out.Notifs, n)
}
return rows.Err()
}
// DELETE /me — soft-deletes the host's account. All sessions are
// revoked immediately. A hard delete happens via a separate cron 30
// days later (TBD ops work). The user is logged out from all devices
// as a side effect of revoking the refresh tokens.
func (h *privacyHandler) deleteMe(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
if err := h.users.SoftDelete(r.Context(), hostID); err != nil {
if errors.Is(err, domain.ErrUserNotFound) {
writeError(w, http.StatusNotFound, "user not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to delete account")
return
}
// Best-effort: revoke refresh tokens so other sessions log out too.
// Failure here is logged but doesn't roll back the soft-delete — the
// access tokens (JWT) will still expire on their own ~15 minute TTL.
if err := h.refresh.RevokeAllForUser(r.Context(), hostID); err != nil {
h.logger.Warn("delete-me: revoke refresh tokens", "err", err, "user_id", hostID)
}
w.WriteHeader(http.StatusNoContent)
}
// POST /me/accept-terms — records that the authenticated user accepts
// the current ToS + privacy policy. Idempotent. Used by both the
// onboarding gate (existing accounts created before T&C were enforced)
// and any future "we updated our terms" re-acceptance flow.
func (h *privacyHandler) acceptTerms(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
if err := h.users.AcceptTerms(r.Context(), hostID); err != nil {
if errors.Is(err, domain.ErrUserNotFound) {
writeError(w, http.StatusNotFound, "user not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to record acceptance")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "accepted"})
}
+37
View File
@@ -0,0 +1,37 @@
package api
import (
"net/http"
)
// ipKey is the rate-limit key for endpoints scoped by source IP only
// (e.g. POST /auth/signup). XFF/X-Real-IP are honoured because in the
// homelab the API sits behind Traefik.
func ipKey(r *http.Request) string {
return clientIP(r)
}
// pathKey returns a path-parameter as the rate-limit key — used for the
// token-scoped endpoints so an attacker brute-forcing a single token is
// limited regardless of the IPs they rotate through.
func pathKey(name string) KeyFunc {
return func(r *http.Request) string {
return r.PathValue(name)
}
}
// userIDKey extracts the authenticated user id from the request context.
// Returns "" when the route isn't behind requireAuth, in which case the
// middleware bypasses (fail-open) — the route's own auth layer handles
// rejection.
func userIDKey(r *http.Request) string {
uid, ok := UserIDFromContext(r.Context())
if !ok {
return ""
}
return uid.String()
}
// KeyFunc mirrors ratelimit.KeyFunc so call sites don't have to import the
// inner package.
type KeyFunc = func(r *http.Request) string
+264 -43
View File
@@ -5,63 +5,167 @@ import (
"net/http"
"time"
"github.com/redis/go-redis/v9"
"github.com/alchemistkay/guestguard/internal/auth"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/notification"
"github.com/alchemistkay/guestguard/internal/ratelimit"
"github.com/alchemistkay/guestguard/internal/storage"
)
type Server struct {
logger *slog.Logger
db *storage.DB
hub *Hub
users *userHandler
events *eventHandler
guests *guestHandler
tokens *tokenHandler
rsvps *rsvpHandler
activity *activityHandler
ws *wsHandler
health *healthHandler
logger *slog.Logger
db *storage.DB
hub *Hub
authH *authHandler
me *meHandler
events *eventHandler
guests *guestHandler
tokens *tokenHandler
rsvps *rsvpHandler
activity *activityHandler
ws *wsHandler
wsTicket *wsTicketHandler
health *healthHandler
signer *auth.JWTSigner
limiter *ratelimit.Limiter
unsub *unsubscribeHandler
webhooks *webhookHandler
csv *csvImportHandler
billing *billingHandler
stripeWH *stripeWebhookHandler
privacy *privacyHandler
}
type ServerDeps struct {
Logger *slog.Logger
DB *storage.DB
Hub *Hub
AccessPublisher accessPublisher
RSVPPublisher rsvpPublisher
FraudScorer fraudScorer
TokenTTL time.Duration
AccessPublisher accessPublisher
RSVPPublisher rsvpPublisher
InvitationPublisher invitationPublisher
FraudScorer fraudScorer
TokenTTL time.Duration
// Auth
JWTSecret string
JWTIssuer string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
EmailVerificationTTL time.Duration
PasswordResetTTL time.Duration
PublicBaseURL string
RefreshCookieDomain string
RefreshCookieSecure bool
EmailSender auth.EmailSender
WSTicketTTL time.Duration
// Rate limiting / abuse controls
Redis *redis.Client
LoginLockoutMax int // failed attempts before account lockout (default 5)
LoginFailWindow time.Duration // counter TTL (default 15 min)
// Notifications / unsubscribe
NotificationRepo *notification.Repo
SuppressionRepo *notification.SuppressionRepo
UnsubscribeSigner *notification.UnsubscribeSigner
// Billing (Block F). Nil StripeClient leaves billing disabled — the
// system still boots and runs, all users sit on the free tier with
// its limits enforced; /billing/* returns 503.
StripeClient *billing.Client
}
func NewServer(deps ServerDeps) *Server {
func NewServer(deps ServerDeps) (*Server, error) {
eventRepo := storage.NewEventRepo(deps.DB)
guestRepo := storage.NewGuestRepo(deps.DB)
tokenRepo := storage.NewTokenRepo(deps.DB)
rsvpRepo := storage.NewRSVPRepo(deps.DB)
accessRepo := storage.NewAccessLogRepo(deps.DB)
userRepo := storage.NewUserRepo(deps.DB)
verifRepo := storage.NewEmailVerificationRepo(deps.DB)
resetRepo := storage.NewPasswordResetRepo(deps.DB)
refreshRepo := storage.NewRefreshTokenRepo(deps.DB)
subRepo := storage.NewSubscriptionRepo(deps.DB)
enforcer := newTierEnforcer(subRepo, deps.PublicBaseURL)
signer, err := auth.NewJWTSigner(deps.JWTSecret, deps.AccessTokenTTL, deps.JWTIssuer)
if err != nil {
return nil, err
}
hasher := auth.NewPasswordHasher()
emails := deps.EmailSender
if emails == nil {
emails = auth.LogEmailSender{Logger: deps.Logger}
}
hub := deps.Hub
if hub == nil {
hub = NewHub(deps.Logger)
}
wsTicketTTL := deps.WSTicketTTL
if wsTicketTTL <= 0 {
wsTicketTTL = 60 * time.Second
}
wsTickets := newWSTicketStore(wsTicketTTL)
var limiter *ratelimit.Limiter
var lockout *auth.LockoutTracker
if deps.Redis != nil {
limiter = ratelimit.New(deps.Redis, "gg:rl")
lockoutMax := deps.LoginLockoutMax
if lockoutMax <= 0 {
lockoutMax = 5
}
failWindow := deps.LoginFailWindow
if failWindow <= 0 {
failWindow = 15 * time.Minute
}
lockout = auth.NewLockoutTracker(deps.Redis, lockoutMax, failWindow)
}
authH := newAuthHandler(authHandlerDeps{
Logger: deps.Logger,
Users: userRepo,
Verifications: verifRepo,
Resets: resetRepo,
Refreshes: refreshRepo,
Hasher: hasher,
Signer: signer,
Emails: emails,
Lockout: lockout,
Limiter: limiter,
PublicBaseURL: deps.PublicBaseURL,
EmailVerificationTTL: deps.EmailVerificationTTL,
PasswordResetTTL: deps.PasswordResetTTL,
RefreshTTL: deps.RefreshTokenTTL,
CookieDomain: deps.RefreshCookieDomain,
CookieSecure: deps.RefreshCookieSecure,
})
return &Server{
logger: deps.Logger,
db: deps.DB,
hub: hub,
users: &userHandler{repo: userRepo},
events: &eventHandler{repo: eventRepo},
guests: &guestHandler{guests: guestRepo, events: eventRepo},
authH: authH,
me: &meHandler{users: userRepo},
events: &eventHandler{repo: eventRepo, enforcer: enforcer},
guests: &guestHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
tokens: &tokenHandler{
logger: deps.Logger,
guests: guestRepo,
tokens: tokenRepo,
events: eventRepo,
accessLogs: accessRepo,
gen: auth.NewGenerator(),
ttl: deps.TokenTTL,
pub: deps.AccessPublisher,
logger: deps.Logger,
guests: guestRepo,
tokens: tokenRepo,
events: eventRepo,
users: userRepo,
accessLogs: accessRepo,
gen: auth.NewGenerator(),
ttl: deps.TokenTTL,
pub: deps.AccessPublisher,
invitations: deps.InvitationPublisher,
publicBaseURL: deps.PublicBaseURL,
},
rsvps: &rsvpHandler{
logger: deps.Logger,
@@ -78,9 +182,46 @@ func NewServer(deps ServerDeps) *Server {
rsvps: rsvpRepo,
accessLogs: accessRepo,
},
ws: &wsHandler{logger: deps.Logger, hub: hub},
health: &healthHandler{pool: deps.DB.Pool},
}
ws: &wsHandler{logger: deps.Logger, hub: hub, tickets: wsTickets},
wsTicket: &wsTicketHandler{tickets: wsTickets, events: eventRepo},
health: &healthHandler{pool: deps.DB.Pool},
signer: signer,
limiter: limiter,
unsub: &unsubscribeHandler{
logger: deps.Logger,
signer: deps.UnsubscribeSigner,
suppress: deps.SuppressionRepo,
},
webhooks: &webhookHandler{
logger: deps.Logger,
notifs: deps.NotificationRepo,
suppress: deps.SuppressionRepo,
},
csv: &csvImportHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
billing: &billingHandler{
logger: deps.Logger,
stripe: deps.StripeClient,
users: userRepo,
subscriptions: subRepo,
publicBaseURL: deps.PublicBaseURL,
},
stripeWH: &stripeWebhookHandler{
logger: deps.Logger,
stripe: deps.StripeClient,
subs: subRepo,
},
privacy: &privacyHandler{
logger: deps.Logger,
users: userRepo,
events: eventRepo,
guests: guestRepo,
tokens: tokenRepo,
rsvps: rsvpRepo,
access: accessRepo,
notifs: deps.DB,
refresh: refreshRepo,
},
}, nil
}
func (s *Server) Hub() *Hub { return s.hub }
@@ -91,25 +232,104 @@ func (s *Server) Handler() http.Handler {
mux.HandleFunc("GET /health", s.health.live)
mux.HandleFunc("GET /health/ready", s.health.ready)
mux.HandleFunc("POST /users", s.users.upsert)
// Per-route rate limiters (no-op when Redis isn't wired).
authed := requireAuth(s.signer)
rl := func(name string, limit int, window time.Duration, keyFn KeyFunc, h http.Handler) http.Handler {
if s.limiter == nil {
return h
}
return s.limiter.Middleware(
ratelimit.Rule{Name: name, Limit: limit, Window: window},
keyFn,
s.logger,
)(h)
}
mux.HandleFunc("POST /events", s.events.create)
mux.HandleFunc("GET /events", s.events.list)
mux.HandleFunc("GET /events/{id}", s.events.get)
mux.HandleFunc("PATCH /events/{id}", s.events.update)
mux.HandleFunc("DELETE /events/{id}", s.events.delete)
// Anonymous auth endpoints — POST /auth/login + /auth/forgot-password
// rate-limit inside the handler (key includes the email body field).
mux.Handle("POST /auth/signup",
rl("auth_signup", 5, time.Hour, ipKey, http.HandlerFunc(s.authH.signup)))
mux.HandleFunc("POST /auth/login", s.authH.login)
mux.HandleFunc("POST /auth/refresh", s.authH.refresh)
mux.HandleFunc("POST /auth/logout", s.authH.logout)
mux.HandleFunc("POST /auth/verify-email", s.authH.verifyEmail)
mux.HandleFunc("POST /auth/forgot-password", s.authH.forgotPassword)
mux.HandleFunc("POST /auth/reset-password", s.authH.resetPassword)
mux.HandleFunc("POST /events/{id}/guests", s.guests.create)
mux.HandleFunc("GET /events/{id}/guests", s.guests.list)
mux.Handle("GET /me", authed(http.HandlerFunc(s.me.get)))
mux.Handle("POST /auth/ws-ticket", authed(http.HandlerFunc(s.wsTicket.issue)))
mux.HandleFunc("GET /events/{id}/activity", s.activity.list)
// Privacy / GDPR-style endpoints — host can export their data,
// delete their account, and record terms acceptance from the
// onboarding gate.
mux.Handle("GET /me/data-export", authed(http.HandlerFunc(s.privacy.dataExport)))
mux.Handle("DELETE /me", authed(http.HandlerFunc(s.privacy.deleteMe)))
mux.Handle("POST /me/accept-terms", authed(http.HandlerFunc(s.privacy.acceptTerms)))
mux.HandleFunc("POST /events/{id}/guests/{guest_id}/tokens", s.tokens.issue)
mux.HandleFunc("GET /access/{token}", s.tokens.access)
mux.HandleFunc("POST /rsvp/{token}", s.rsvps.submit)
// Host-facing event/guest/token writes are limited by user_id.
mux.Handle("POST /events",
authed(rl("events_create", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.events.create))))
mux.Handle("GET /events", authed(http.HandlerFunc(s.events.list)))
mux.Handle("GET /events/{id}", authed(http.HandlerFunc(s.events.get)))
mux.Handle("PATCH /events/{id}", authed(http.HandlerFunc(s.events.update)))
mux.Handle("DELETE /events/{id}", authed(http.HandlerFunc(s.events.delete)))
mux.Handle("POST /events/{id}/guests",
authed(rl("guests_create", 1000, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.create))))
mux.Handle("GET /events/{id}/guests", authed(http.HandlerFunc(s.guests.list)))
mux.Handle("PATCH /events/{id}/guests/{guest_id}",
authed(rl("guests_update", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.update))))
mux.Handle("DELETE /events/{id}/guests/{guest_id}",
authed(rl("guests_delete", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.delete))))
// CSV import (Block E). Preview is cheap (no DB writes), so we keep
// its budget separate from commit's daily-row-add limit.
mux.Handle("POST /events/{id}/guests/import/preview",
authed(rl("guests_import_preview", 30, time.Hour, userIDKey, http.HandlerFunc(s.csv.preview))))
mux.Handle("POST /events/{id}/guests/import",
authed(rl("guests_import_commit", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.csv.commit))))
mux.Handle("GET /events/{id}/guests/import/template", authed(http.HandlerFunc(s.csv.template)))
mux.Handle("GET /events/{id}/activity", authed(http.HandlerFunc(s.activity.list)))
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens",
authed(rl("tokens_issue", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.issue))))
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens/rotate",
authed(rl("tokens_rotate", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.rotate))))
mux.Handle("POST /events/{id}/guests/invitations/bulk",
authed(rl("tokens_bulk", 10, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.bulkIssue))))
// Guest-facing endpoints — rate-limited by the access token in the URL
// path so an attacker hammering a single invitation is slowed regardless
// of their source IP.
mux.Handle("GET /access/{token}",
rl("access", 60, time.Hour, pathKey("token"), http.HandlerFunc(s.tokens.access)))
mux.Handle("POST /rsvp/{token}",
rl("rsvp", 10, time.Hour, pathKey("token"), http.HandlerFunc(s.rsvps.submit)))
// WebSocket endpoint authenticates via single-use ticket on the query
// string (see POST /auth/ws-ticket).
mux.HandleFunc("GET /ws/events/{id}", s.ws.handle)
// Unsubscribe (signed token, no auth required — links live in emails).
mux.HandleFunc("GET /unsubscribe/{token}", s.unsub.preview)
mux.HandleFunc("POST /unsubscribe/{token}", s.unsub.confirm)
// Provider webhooks. Signature verification is enforced in the handler
// once GG_TWILIO_AUTH_TOKEN / GG_SES_WEBHOOK_SECRET are set.
mux.HandleFunc("POST /webhooks/twilio/status", s.webhooks.twilio)
mux.HandleFunc("POST /webhooks/ses/notifications", s.webhooks.ses)
// Billing (Block F). /billing/status is safe for everyone — returns
// free tier defaults when Stripe is unconfigured or the user has no
// subscription, so the frontend's plan page always has something to
// render. The action endpoints (checkout, portal) return 503 in dev
// without Stripe credentials.
mux.Handle("GET /billing/status", authed(http.HandlerFunc(s.billing.status)))
mux.Handle("POST /billing/checkout-session", authed(http.HandlerFunc(s.billing.checkoutSession)))
mux.Handle("POST /billing/portal", authed(http.HandlerFunc(s.billing.portalSession)))
mux.HandleFunc("POST /webhooks/stripe", s.stripeWH.handle)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, "not found")
})
@@ -128,9 +348,10 @@ func corsMiddleware(next http.Handler) http.Handler {
origin := r.Header.Get("Origin")
if origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Vary", "Origin")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Device-Fingerprint")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Device-Fingerprint")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
+163
View File
@@ -0,0 +1,163 @@
package api
import (
"encoding/json"
"io"
"log/slog"
"net/http"
"time"
"github.com/stripe/stripe-go/v82"
"github.com/alchemistkay/guestguard/internal/billing"
"github.com/alchemistkay/guestguard/internal/storage"
)
// stripeWebhookHandler accepts and verifies Stripe events, then
// projects subscription lifecycle changes onto the subscriptions table.
// We track only what middleware needs to decide access — tier + status +
// period bounds. Invoice events (payment failed / succeeded) are logged
// for observability; dunning automation lands in Block F3.
type stripeWebhookHandler struct {
logger *slog.Logger
stripe *billing.Client
subs *storage.SubscriptionRepo
}
// POST /webhooks/stripe — signature-verified Stripe event sink.
func (h *stripeWebhookHandler) handle(w http.ResponseWriter, r *http.Request) {
if h.stripe == nil || !h.stripe.Enabled() {
// Not configured on this instance — reject so a misrouted event
// isn't silently swallowed. Stripe will retry which is harmless.
w.WriteHeader(http.StatusServiceUnavailable)
return
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
writeError(w, http.StatusBadRequest, "read body")
return
}
defer r.Body.Close()
event, err := h.stripe.VerifyWebhook(body, r.Header.Get("Stripe-Signature"))
if err != nil {
h.logger.Warn("stripe webhook signature failed", "err", err)
writeError(w, http.StatusBadRequest, "invalid signature")
return
}
switch event.Type {
case "customer.subscription.created", "customer.subscription.updated":
h.applySubscription(r, event)
case "customer.subscription.deleted":
h.applySubscriptionDeleted(r, event)
case "invoice.payment_succeeded":
// Clear past_due if Stripe says payment caught up. Most flows are
// already covered by the subscription.updated event Stripe also
// fires — this is belt-and-braces.
h.applySubscription(r, event)
case "invoice.payment_failed":
h.logger.Warn("stripe invoice payment failed", "event_id", event.ID)
// Subscription.status will flip to past_due via the
// subscription.updated event Stripe fires alongside.
default:
h.logger.Debug("stripe event ignored", "type", event.Type)
}
w.WriteHeader(http.StatusOK)
}
// applySubscription patches the subscriptions row keyed by Stripe
// customer id. Best-effort — failures here are logged but don't NACK
// the webhook (Stripe would retry forever and the row would never
// converge).
func (h *stripeWebhookHandler) applySubscription(r *http.Request, event stripe.Event) {
var sub stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
// Some invoice events carry an Invoice payload — try to extract
// the subscription id from there and short-circuit on status.
h.logger.Debug("stripe webhook: not a subscription payload", "type", event.Type, "err", err)
return
}
if sub.Customer == nil || sub.Customer.ID == "" {
h.logger.Warn("stripe webhook: subscription has no customer", "subscription", sub.ID)
return
}
tier := tierFromSubscription(&sub)
status := string(sub.Status)
cancelAtPeriodEnd := sub.CancelAtPeriodEnd
// As of API 2024-10-28, current_period_end lives on the subscription
// item, not the subscription. We pick the earliest item's end — for
// single-item subscriptions (our case) that's the canonical one.
var periodEnd *time.Time
for _, item := range sub.Items.Data {
if item.CurrentPeriodEnd > 0 {
t := time.Unix(item.CurrentPeriodEnd, 0).UTC()
if periodEnd == nil || t.Before(*periodEnd) {
periodEnd = &t
}
}
}
subID := sub.ID
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
StripeSubscriptionID: &subID,
Tier: stringPtr(string(tier)),
Status: &status,
CurrentPeriodEnd: periodEnd,
CancelAtPeriodEnd: &cancelAtPeriodEnd,
}); err != nil {
h.logger.Error("stripe webhook: update subscription failed", "err", err)
}
}
func (h *stripeWebhookHandler) applySubscriptionDeleted(r *http.Request, event stripe.Event) {
var sub stripe.Subscription
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
h.logger.Warn("stripe webhook: bad deleted payload", "err", err)
return
}
if sub.Customer == nil {
return
}
status := "canceled"
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
Status: &status,
}); err != nil {
h.logger.Error("stripe webhook: mark canceled failed", "err", err)
}
}
// tierFromSubscription inspects the Stripe price metadata to figure out
// which GuestGuard tier this subscription corresponds to. We read a
// price-level metadata key `gg_tier` (set in the Stripe dashboard when
// you create the Price). Fallback: free.
func tierFromSubscription(sub *stripe.Subscription) billing.Tier {
if sub == nil || len(sub.Items.Data) == 0 {
return billing.TierFree
}
for _, item := range sub.Items.Data {
if item.Price == nil {
continue
}
if v, ok := item.Price.Metadata["gg_tier"]; ok {
t := billing.Tier(v)
if t.Valid() {
return t
}
}
// Heuristic fallback for tests / unconfigured prices: look at the
// recurring interval and amount tier.
if item.Price.Recurring != nil && item.Price.UnitAmount >= 19900 {
return billing.TierBusiness
}
if item.Price.Recurring != nil && item.Price.UnitAmount >= 4900 {
return billing.TierPro
}
}
return billing.TierFree
}
func stringPtr(s string) *string { return &s }
+329 -14
View File
@@ -2,6 +2,7 @@ package api
import (
"context"
"encoding/json"
"errors"
"log/slog"
"net/http"
@@ -20,25 +21,38 @@ type accessPublisher interface {
PublishAccessAttempted(ctx context.Context, evt natspub.AccessAttempted) error
}
type invitationPublisher interface {
PublishInvitationSend(ctx context.Context, evt natspub.InvitationSend) error
}
type tokenHandler struct {
logger *slog.Logger
guests *storage.GuestRepo
tokens *storage.TokenRepo
events *storage.EventRepo
accessLogs *storage.AccessLogRepo
gen *auth.Generator
ttl time.Duration
pub accessPublisher
logger *slog.Logger
guests *storage.GuestRepo
tokens *storage.TokenRepo
events *storage.EventRepo
users *storage.UserRepo
accessLogs *storage.AccessLogRepo
gen *auth.Generator
ttl time.Duration
pub accessPublisher
invitations invitationPublisher
publicBaseURL string
}
type issueTokenResponse struct {
Token string `json:"token"`
TokenID uuid.UUID `json:"token_id"`
Meta *domain.Token `json:"meta"`
Token string `json:"token"`
TokenID uuid.UUID `json:"token_id"`
Meta *domain.Token `json:"meta"`
InvitationQueued bool `json:"invitation_queued"`
InvitationLink string `json:"invitation_link"`
}
// POST /events/{id}/guests/{guest_id}/tokens — issue a token for the guest.
func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
@@ -47,6 +61,10 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
if !ok {
return
}
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
if !ok {
return
}
guest, err := h.guests.Get(r.Context(), guestID)
if err != nil {
@@ -78,10 +96,307 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
return
}
link := h.invitationLink(raw)
invitationQueued := h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
writeJSON(w, http.StatusCreated, issueTokenResponse{
Token: raw,
TokenID: tk.ID,
Meta: tk,
Token: raw,
TokenID: tk.ID,
Meta: tk,
InvitationQueued: invitationQueued,
InvitationLink: link,
})
}
// queueInvitation publishes an invitation.send event so the notifier can
// dispatch a branded email. Best-effort: if any step fails we log and
// return false rather than failing the whole token-issue request — the
// host still has the raw URL in the response and can re-trigger sending.
func (h *tokenHandler) queueInvitation(
ctx context.Context,
event *domain.Event,
guest *domain.Guest,
tk *domain.Token,
hostID uuid.UUID,
rawToken string,
) bool {
if h.invitations == nil {
return false
}
if guest.Email == nil || *guest.Email == "" {
// Phone-only / nameless guests get no email — host shares the link
// manually. Show that on the UI so it's not a silent surprise.
return false
}
hostName := ""
if h.users != nil {
if host, err := h.users.GetByID(ctx, hostID); err == nil && host != nil {
hostName = host.Name
}
}
evt := natspub.InvitationSend{
EventID: event.ID,
GuestID: guest.ID,
TokenID: tk.ID,
GuestName: guest.Name,
GuestEmail: *guest.Email,
HostName: hostName,
EventName: event.Name,
Venue: event.Venue,
EventDate: event.EventDate,
Link: h.invitationLink(rawToken),
IssuedAt: time.Now().UTC(),
}
if err := h.invitations.PublishInvitationSend(ctx, evt); err != nil {
h.logger.Warn("publish invitation.send (continuing)", "err", err, "guest_id", guest.ID)
return false
}
return true
}
// invitationLink renders the public RSVP URL the guest clicks from their
// inbox. publicBaseURL is the externally-reachable host (set via
// GG_PUBLIC_BASE_URL); access via /rsvp/<token> is intentional — the
// frontend page rsvp/[token].vue catches the raw token.
func (h *tokenHandler) invitationLink(raw string) string {
base := h.publicBaseURL
if base == "" {
base = "http://localhost:3000"
}
return base + "/rsvp/" + raw
}
// bulkIssueRequest is the optional JSON body for the bulk-invite call.
// An empty body (or missing GuestIDs) means "every guest on the event
// who doesn't already have a token".
type bulkIssueRequest struct {
GuestIDs []string `json:"guest_ids"`
}
type bulkIssueItemError struct {
GuestID string `json:"guest_id"`
Reason string `json:"reason"`
}
// bulkIssueToken is one minted invitation. The raw token is returned so
// the host's UI can offer a "copy link" affordance after a bulk send
// (especially for guests with no email on file) without making another
// round-trip. Same data the per-guest issue endpoint already exposes,
// scoped to the host who owns the event.
type bulkIssueToken struct {
GuestID string `json:"guest_id"`
Token string `json:"token"`
InvitationQueued bool `json:"invitation_queued"`
InvitationLink string `json:"invitation_link"`
}
type bulkIssueResponse struct {
Issued int `json:"issued"`
Queued int `json:"queued"`
SkippedExisting int `json:"skipped_existing"`
SkippedNoEmail int `json:"skipped_no_email"`
Tokens []bulkIssueToken `json:"tokens,omitempty"`
Errors []bulkIssueItemError `json:"errors,omitempty"`
}
// POST /events/{id}/guests/invitations/bulk — generate tokens for every
// eligible guest (or the explicit subset) on the event and queue an
// invitation email for those with an address. Best-effort: any per-guest
// error is reported in the response and doesn't abort the rest.
func (h *tokenHandler) bulkIssue(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
if !ok {
return
}
var req bulkIssueRequest
if r.ContentLength > 0 {
// Body is optional; only decode when something was actually sent.
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
}
var onlyIDs []uuid.UUID
if len(req.GuestIDs) > 0 {
onlyIDs = make([]uuid.UUID, 0, len(req.GuestIDs))
for _, raw := range req.GuestIDs {
id, err := uuid.Parse(raw)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid guest id: "+raw)
return
}
onlyIDs = append(onlyIDs, id)
}
}
guests, err := h.guests.ListGuestsForInvitation(r.Context(), eventID, onlyIDs)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to load guests")
return
}
hostName := ""
if h.users != nil {
if host, err := h.users.GetByID(r.Context(), hostID); err == nil && host != nil {
hostName = host.Name
}
}
resp := bulkIssueResponse{}
for _, g := range guests {
if g.HasToken {
resp.SkippedExisting++
continue
}
raw, hash, err := h.gen.Generate()
if err != nil {
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "mint token failed"})
continue
}
tk, err := h.tokens.Create(r.Context(), storage.CreateTokenParams{
GuestID: g.ID,
TokenHash: hash,
ExpiresAt: time.Now().UTC().Add(h.ttl),
})
if err != nil {
// Likely a race against the unique constraint (someone else
// issued in parallel) — surface but don't fail the batch.
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: err.Error()})
continue
}
resp.Issued++
link := h.invitationLink(raw)
tokenInfo := bulkIssueToken{
GuestID: g.ID.String(),
Token: raw,
InvitationLink: link,
}
if g.Email == "" {
resp.SkippedNoEmail++
resp.Tokens = append(resp.Tokens, tokenInfo)
continue
}
evt := natspub.InvitationSend{
EventID: event.ID,
GuestID: g.ID,
TokenID: tk.ID,
GuestName: g.Name,
GuestEmail: g.Email,
HostName: hostName,
EventName: event.Name,
Venue: event.Venue,
EventDate: event.EventDate,
Link: h.invitationLink(raw),
IssuedAt: time.Now().UTC(),
}
if h.invitations == nil {
resp.Tokens = append(resp.Tokens, tokenInfo)
continue
}
if err := h.invitations.PublishInvitationSend(r.Context(), evt); err != nil {
h.logger.Warn("publish invitation.send (bulk, continuing)", "err", err, "guest_id", g.ID)
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "publish failed"})
resp.Tokens = append(resp.Tokens, tokenInfo)
continue
}
resp.Queued++
tokenInfo.InvitationQueued = true
resp.Tokens = append(resp.Tokens, tokenInfo)
}
writeJSON(w, http.StatusOK, resp)
}
type rotateTokenRequest struct {
// SendEmail asks the notifier to re-deliver the invitation. False
// means "just give me a fresh link" — typical for phone-only guests
// where the host shares the new URL via SMS.
SendEmail bool `json:"send_email"`
}
// POST /events/{id}/guests/{guest_id}/tokens/rotate — invalidate the
// guest's existing invitation link and mint a fresh one. Optionally
// re-publishes invitation.send so the notifier re-delivers via email.
// The old URL stops working immediately.
func (h *tokenHandler) rotate(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
guestID, ok := parseIDParam(w, r, "guest_id")
if !ok {
return
}
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
if !ok {
return
}
guest, err := h.guests.Get(r.Context(), guestID)
if err != nil {
if errors.Is(err, domain.ErrGuestNotFound) {
writeError(w, http.StatusNotFound, "guest not found")
return
}
writeError(w, http.StatusInternalServerError, "failed to load guest")
return
}
if guest.EventID != eventID {
writeError(w, http.StatusNotFound, "guest not found in event")
return
}
var req rotateTokenRequest
if r.ContentLength > 0 {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
}
raw, hash, err := h.gen.Generate()
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to generate token")
return
}
tk, err := h.tokens.RotateForGuest(r.Context(), storage.CreateTokenParams{
GuestID: guestID,
TokenHash: hash,
ExpiresAt: time.Now().UTC().Add(h.ttl),
})
if err != nil {
h.logger.Error("rotate token", "err", err, "guest_id", guestID)
writeError(w, http.StatusInternalServerError, "failed to rotate token")
return
}
link := h.invitationLink(raw)
invitationQueued := false
if req.SendEmail {
invitationQueued = h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
}
writeJSON(w, http.StatusOK, issueTokenResponse{
Token: raw,
TokenID: tk.ID,
Meta: tk,
InvitationQueued: invitationQueued,
InvitationLink: link,
})
}
+50
View File
@@ -0,0 +1,50 @@
package api
import (
"log/slog"
"net/http"
"github.com/alchemistkay/guestguard/internal/notification"
)
type unsubscribeHandler struct {
logger *slog.Logger
signer *notification.UnsubscribeSigner
suppress *notification.SuppressionRepo
}
// GET /unsubscribe/{token} — surface the email address that the token
// belongs to so the frontend can show a confirmation page. Honoured even
// before the user clicks "Confirm" so they see what's being unsubscribed.
func (h *unsubscribeHandler) preview(w http.ResponseWriter, r *http.Request) {
if h.signer == nil {
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
return
}
email, err := h.signer.Verify(r.PathValue("token"))
if err != nil {
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
return
}
writeJSON(w, http.StatusOK, map[string]string{"email": email})
}
// POST /unsubscribe/{token} — add the email to the suppression list.
// Idempotent: clicking the link twice keeps the existing entry.
func (h *unsubscribeHandler) confirm(w http.ResponseWriter, r *http.Request) {
if h.signer == nil || h.suppress == nil {
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
return
}
email, err := h.signer.Verify(r.PathValue("token"))
if err != nil {
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
return
}
if err := h.suppress.Add(r.Context(), email, "user clicked unsubscribe", notification.SuppressionUser); err != nil {
h.logger.Error("add suppression", "err", err, "email", email)
writeError(w, http.StatusInternalServerError, "failed to unsubscribe")
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "unsubscribed", "email": email})
}
-55
View File
@@ -1,55 +0,0 @@
package api
import (
"encoding/json"
"errors"
"net/http"
"net/mail"
"github.com/alchemistkay/guestguard/internal/domain"
"github.com/alchemistkay/guestguard/internal/storage"
)
type userHandler struct {
repo *storage.UserRepo
}
type upsertUserRequest struct {
Email string `json:"email"`
Name string `json:"name"`
}
// POST /users — idempotent: returns the existing user if the email already
// exists, creates one otherwise. This keeps the demo flow simple without
// requiring real auth.
func (h *userHandler) upsert(w http.ResponseWriter, r *http.Request) {
var req upsertUserRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
if _, err := mail.ParseAddress(req.Email); err != nil {
writeError(w, http.StatusBadRequest, "email is invalid")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
u, err := h.repo.Create(r.Context(), req.Email, req.Name)
if err == nil {
writeJSON(w, http.StatusCreated, u)
return
}
if errors.Is(err, domain.ErrEmailTaken) {
existing, getErr := h.repo.GetByEmail(r.Context(), req.Email)
if getErr != nil {
writeError(w, http.StatusInternalServerError, "failed to load user")
return
}
writeJSON(w, http.StatusOK, existing)
return
}
writeError(w, http.StatusInternalServerError, "failed to create user")
}
+145
View File
@@ -0,0 +1,145 @@
package api
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"github.com/alchemistkay/guestguard/internal/notification"
)
// webhookHandler accepts provider status notifications and reflects them
// onto the notifications table + suppression list.
//
// Signature verification is intentionally a TODO until the user provisions
// real Twilio + SES creds — verifying against test fixtures alone would
// give a false sense of security. The endpoint is therefore *not* exposed
// publicly until the deployment is ready.
type webhookHandler struct {
logger *slog.Logger
notifs *notification.Repo
suppress *notification.SuppressionRepo
}
// POST /webhooks/twilio/status — Twilio status callback (form-encoded).
// Fields we care about: MessageSid, MessageStatus (sent|delivered|
// undelivered|failed), ErrorCode, To.
func (h *webhookHandler) twilio(w http.ResponseWriter, r *http.Request) {
if h.notifs == nil {
w.WriteHeader(http.StatusNoContent)
return
}
// TODO(blockD2): verify X-Twilio-Signature with GG_TWILIO_AUTH_TOKEN.
if err := r.ParseForm(); err != nil {
writeError(w, http.StatusBadRequest, "invalid form")
return
}
sid := r.PostForm.Get("MessageSid")
status := r.PostForm.Get("MessageStatus")
if sid == "" || status == "" {
writeError(w, http.StatusBadRequest, "missing MessageSid / MessageStatus")
return
}
ctx := r.Context()
switch status {
case "delivered":
_ = h.notifs.MarkDelivered(ctx, sid)
case "undelivered", "failed":
_ = h.notifs.MarkBounce(ctx, sid, "permanent")
}
h.logger.Info("twilio status callback", "sid", sid, "status", status)
w.WriteHeader(http.StatusNoContent)
}
// POST /webhooks/ses/notifications — SNS-delivered SES notification (JSON).
// Handles the two shapes SES uses: bounce + complaint events. Each event
// carries the messageId we stored in provider_message_id and an array of
// affected recipients.
func (h *webhookHandler) ses(w http.ResponseWriter, r *http.Request) {
if h.notifs == nil {
w.WriteHeader(http.StatusNoContent)
return
}
// TODO(blockD2): verify SNS signature using the cert URL field.
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
writeError(w, http.StatusBadRequest, "read body")
return
}
defer r.Body.Close()
var envelope struct {
Type string `json:"Type"` // "Notification" | "SubscriptionConfirmation"
Message string `json:"Message"` // stringified JSON for Notification
}
if err := json.Unmarshal(body, &envelope); err != nil {
writeError(w, http.StatusBadRequest, "invalid json envelope")
return
}
if envelope.Type == "SubscriptionConfirmation" {
// Confirmed by visiting SubscribeURL — manual op-side step.
h.logger.Info("ses subscription confirmation received (manual confirm required)")
w.WriteHeader(http.StatusNoContent)
return
}
if envelope.Message == "" {
w.WriteHeader(http.StatusNoContent)
return
}
var inner struct {
NotificationType string `json:"notificationType"` // "Bounce" | "Complaint" | "Delivery"
Mail struct {
MessageID string `json:"messageId"`
} `json:"mail"`
Bounce struct {
BounceType string `json:"bounceType"` // "Permanent" | "Transient"
BouncedRecipients []struct {
EmailAddress string `json:"emailAddress"`
} `json:"bouncedRecipients"`
} `json:"bounce"`
Complaint struct {
ComplainedRecipients []struct {
EmailAddress string `json:"emailAddress"`
} `json:"complainedRecipients"`
} `json:"complaint"`
}
if err := json.Unmarshal([]byte(envelope.Message), &inner); err != nil {
writeError(w, http.StatusBadRequest, "invalid inner json")
return
}
ctx := r.Context()
switch inner.NotificationType {
case "Bounce":
bt := "transient"
if inner.Bounce.BounceType == "Permanent" {
bt = "permanent"
}
_ = h.notifs.MarkBounce(ctx, inner.Mail.MessageID, bt)
if h.suppress != nil && bt == "permanent" {
for _, rcp := range inner.Bounce.BouncedRecipients {
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses permanent bounce", notification.SuppressionBounce)
}
}
case "Complaint":
_ = h.notifs.MarkComplaint(ctx, inner.Mail.MessageID)
if h.suppress != nil {
for _, rcp := range inner.Complaint.ComplainedRecipients {
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses complaint", notification.SuppressionComplaint)
}
}
case "Delivery":
_ = h.notifs.MarkDelivered(ctx, inner.Mail.MessageID)
}
h.logger.Info("ses notification", "type", inner.NotificationType, "message_id", inner.Mail.MessageID)
w.WriteHeader(http.StatusNoContent)
}
// Compile-time check that ctx is unused in package — silences linter on
// some Go versions when the file would otherwise import context only for
// the handler signatures.
var _ = context.Background
+51
View File
@@ -0,0 +1,51 @@
package api
import (
"encoding/json"
"net/http"
"time"
"github.com/alchemistkay/guestguard/internal/storage"
)
type wsTicketHandler struct {
tickets *wsTicketStore
events *storage.EventRepo
}
type wsTicketResponse struct {
Ticket string `json:"ticket"`
ExpiresAt time.Time `json:"expires_at"`
}
// POST /auth/ws-ticket — requireAuth-protected; body { "event_id": "<uuid>" }.
// Returns a single-use ticket valid for ~60 seconds. The frontend appends it
// as `?ticket=…` on the WebSocket URL.
func (h *wsTicketHandler) issue(w http.ResponseWriter, r *http.Request) {
hostID, ok := hostFromContext(w, r)
if !ok {
return
}
var req struct {
EventID string `json:"event_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid json")
return
}
eventID, ok := parseRawUUID(w, "event_id", req.EventID)
if !ok {
return
}
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
return
}
tok, exp, err := h.tickets.Mint(hostID, eventID)
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to mint ticket")
return
}
writeJSON(w, http.StatusOK, wsTicketResponse{Ticket: tok, ExpiresAt: exp})
}
+81
View File
@@ -0,0 +1,81 @@
package api
import (
"crypto/rand"
"encoding/base64"
"sync"
"time"
"github.com/google/uuid"
)
// wsTicketStore mints short-lived single-use tickets that authorise a
// WebSocket handshake. The plan calls this option 3 in Block B: cookies
// don't reach the WS handshake on cross-origin setups and a JWT in the URL
// would leak to logs; a one-shot ticket sidesteps both.
//
// Block B keeps this in-process. When the API runs more than one replica
// this needs to move to Redis (Block C territory).
type wsTicketStore struct {
mu sync.Mutex
entries map[string]wsTicketEntry
ttl time.Duration
now func() time.Time
}
type wsTicketEntry struct {
userID uuid.UUID
eventID uuid.UUID
expiresAt time.Time
}
func newWSTicketStore(ttl time.Duration) *wsTicketStore {
return &wsTicketStore{
entries: make(map[string]wsTicketEntry),
ttl: ttl,
now: time.Now,
}
}
// Mint returns a fresh URL-safe ticket bound to userID + eventID.
func (s *wsTicketStore) Mint(userID, eventID uuid.UUID) (string, time.Time, error) {
buf := make([]byte, 24)
if _, err := rand.Read(buf); err != nil {
return "", time.Time{}, err
}
tok := base64.RawURLEncoding.EncodeToString(buf)
exp := s.now().Add(s.ttl)
s.mu.Lock()
defer s.mu.Unlock()
s.sweepLocked()
s.entries[tok] = wsTicketEntry{userID: userID, eventID: eventID, expiresAt: exp}
return tok, exp, nil
}
// Consume removes the ticket and returns the bound (userID, eventID) if it
// was valid. A ticket is single-use — replaying it fails.
func (s *wsTicketStore) Consume(token string) (uuid.UUID, uuid.UUID, bool) {
s.mu.Lock()
defer s.mu.Unlock()
entry, ok := s.entries[token]
if !ok {
return uuid.Nil, uuid.Nil, false
}
delete(s.entries, token)
if s.now().After(entry.expiresAt) {
return uuid.Nil, uuid.Nil, false
}
return entry.userID, entry.eventID, true
}
// sweepLocked drops expired entries opportunistically. Cheap because we
// usually only hold dozens of tickets at a time.
func (s *wsTicketStore) sweepLocked() {
now := s.now()
for k, v := range s.entries {
if now.After(v.expiresAt) {
delete(s.entries, k)
}
}
}
+23 -3
View File
@@ -89,16 +89,36 @@ func (h *Hub) remove(eventID uuid.UUID, s *subscriber) {
}
type wsHandler struct {
logger *slog.Logger
hub *Hub
logger *slog.Logger
hub *Hub
tickets *wsTicketStore
}
// GET /ws/events/{id} — dashboard live feed for one event.
// GET /ws/events/{id}?ticket=... — dashboard live feed for one event.
//
// The handshake is authorised by a single-use ticket minted via
// POST /auth/ws-ticket (option 3 from the Block B plan). The ticket binds
// the connecting user to a specific event_id; we reject if either is
// missing or doesn't match the URL path.
func (h *wsHandler) handle(w http.ResponseWriter, r *http.Request) {
eventID, ok := parseIDParam(w, r, "id")
if !ok {
return
}
rawTicket := r.URL.Query().Get("ticket")
if rawTicket == "" {
writeError(w, http.StatusUnauthorized, "missing ticket")
return
}
_, ticketEventID, valid := h.tickets.Consume(rawTicket)
if !valid {
writeError(w, http.StatusUnauthorized, "invalid or expired ticket")
return
}
if ticketEventID != eventID {
writeError(w, http.StatusForbidden, "ticket does not match event")
return
}
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
// In dev the frontend runs on a different origin (localhost:3000 → localhost:8080).