feat: ship Tier 1 — auth, authz, rate limits, real notifications, CSV import, billing, backups/DR, privacy
Closes every block in docs/TIER1_PLAN.md from the Claude-scope side. The
homelab / cloud setup steps (SES verification, restore drill, lawyer-
drafted ToS) remain operator-owned but are unblocked.
Block A — Authentication
- Migration 0003: password_hash, email_verified, email_verification_tokens,
password_reset_tokens, refresh_tokens (with replaced_by family chain).
- Bcrypt hasher, HS256 JWT signer, single-use refresh tokens with rotation
+ replay-detection (revokes the family on reuse).
- /auth/signup, /login, /refresh, /logout, /verify-email,
/forgot-password, /reset-password — enumeration-safe.
- requireAuth middleware + GET /me.
- Frontend useAuth/useApi with auto-refresh-on-401, login/signup/verify/
forgot/reset pages, route-guard middleware.
Block B — Authorisation
- EventRepo.GetForHost; Update/Delete scoped by host_id.
- All host routes behind requireAuth + ownership; cross-tenant returns
404 (no enumeration). ?host_id removed.
- WS auth via short-lived single-use tickets (POST /auth/ws-ticket).
- Tests: TestCrossTenantIsolation — 9 probes.
Block C — Rate limiting
- Redis sliding-window via Lua (atomic ZADD+ZCARD+PEXPIRE).
- Per-route limits matching the plan (signup IP, login IP+email, RSVP/
access by token, events/guests/tokens by user_id).
- 429 with Retry-After header and JSON body.
- Auth lockout: 5 failed logins → account locked, only password reset
clears it.
- Frontend: useErrMessage normalises 429 + locked messaging.
Block D — Real notifications
- Migration 0004: provider_message_id, bounce_type, complained columns
+ unsubscribes (CITEXT) suppression table.
- Branded HTML + plaintext templates for verification, reset, invitation,
confirmation, reminder. Per-page templates avoid html/template's
contextual-escape collisions.
- Senders: SESv2, Twilio (SMS), SMTP (Mailpit-friendly), Resend HTTP.
- PickEmailSender priority Resend > SMTP > SES > Log — system boots
cleanly in dev with Mailpit; production flips one env var.
- Webhook endpoints (Twilio status + SES SNS) — bounces add to suppression;
signature verification stubbed pending creds.
- Auto-send: POST /tokens publishes invitation.send; notifier renders +
delivers via the configured backend; suppression list honoured.
- Bulk + per-row invitation flow: POST /events/{id}/guests/invitations/bulk
returns per-guest tokens so phone-only guests can be SMS'd manually.
- Unsubscribe: signed HMAC token (no TTL) + /unsubscribe/[token] page.
- WhatsApp Option A+: wa.me click-to-chat wizard with per-guest progress
tracking, isLikelyE164 validation, edit-from-wizard.
- Token rotate (POST /tokens/rotate) invalidates the old URL — used by
the regenerate-link flow.
- Mailpit added to docker-compose for dev inbox.
Block E — CSV import
- Streaming parser: tolerant header detection, UTF-8 BOM + UTF-16 LE/BE
decoding, row-level validation, 5,000-row cap.
- Strict E.164 phone validation with helpful error message.
- POST /preview + /import + GET /template; preview UI on event page;
atomic per-batch with dedup on existing emails.
Phone capture across UI
- PhoneInput component: country picker (~50 ISO codes) + national input +
live E.164 preview + inline length validation.
- Used in Add Guest and Edit Guest modals. Smart paste-handling extracts
country code from full E.164 strings.
Block F — Billing (Stripe)
- Migration 0005: subscriptions table (user_id → tier/status/period_end +
Stripe customer/sub ids). Partial unique index keeps one granting sub
per user.
- internal/billing: Tier + Limits model (Free 1/50, Pro 10/1000, Business
∞/5000), Stripe SDK wrapper with IgnoreAPIVersionMismatch for newer
account API versions.
- /billing/checkout-session, /billing/portal, /billing/status,
/webhooks/stripe (signature-verified, lifecycle events).
- Tier enforcement: 402 on POST /events, /guests, /import with
{error, reason, tier, used, limit, upgrade_url} body.
- Frontend: useBilling composable, /dashboard/billing page (current plan,
usage bars, tier cards), global UpgradeModal triggered by useApi's
402 interceptor.
- Customer portal kept for self-service cancel/payment-method changes.
Block G — Backups & DR (application side)
- Every migration has a tested .down.sql.
- TestMigrationRoundtrip applies all ups → all downs → all ups against a
fresh container; catches asymmetric down migrations.
- cmd/restore-verify: 28-check post-restore invariant tool (schema
presence, no orphans across 10 FK relationships, email uniqueness,
single-active subscription, row-count snapshot).
- docs/RUNBOOK_RESTORE.md: 9-step restore procedure with RTO/RPO
targets, drill instructions, rollback path.
Block H — Privacy compliance (application side)
- Migration 0006: deleted_at + terms_accepted_at + privacy_policy_accepted_at
on users. Partial index on email for live-only uniqueness.
- GET /me/data-export — synchronous JSON dump (user, events, guests,
tokens, rsvps, access_logs, notifications).
- DELETE /me — soft-delete with PII scrub + refresh-token revocation;
re-signup with same email works.
- POST /me/accept-terms — idempotent consent recording.
- Frontend /privacy + /terms placeholder pages with substantive (pending
legal review) copy; footer links; signup terms checkbox; TermsGateModal
for accounts created before the rollout; export + delete buttons on
/dashboard/billing.
Tests
- All migrations verified up/down/up.
- Integration suite: TestE2EHappyPath, TestAuthFlow, TestCrossTenantIsolation,
TestRateLimitSignup, TestLoginLockout, TestUnsubscribeFlow,
TestSESBounceWebhook, TestTwilioStatusWebhook, TestCsvImportFlow,
TestCsvImportAtomicRollback, TestBulkIssueInvitations, TestBulkIssueExplicitSubset,
TestTokenIssuePublishesInvitation, TestTokenIssueWithoutGuestEmailSkipsInvitation,
TestGuestUpdate, TestGuestDelete, TestTokenRotate, TestSMTPSenderAgainstMailpit,
TestFreeTierEventLimit, TestFreeTierGuestLimit, TestBusinessTierBypassesLimits,
TestDataExport, TestDeleteMe, TestAcceptTerms, TestMigrationRoundtrip.
Full suite runs in ~120s against real Postgres + NATS + Redis + Mailpit.
- Unit suite green across internal/auth, internal/csvimport,
internal/notification, internal/ratelimit, internal/domain.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
@@ -42,16 +40,15 @@ type activityItem struct {
|
||||
// for an event, sorted newest first. Frontends use this on dashboard mount
|
||||
// to backfill the live monitor with history.
|
||||
func (h *activityHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := h.events.Get(r.Context(), eventID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,557 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/ratelimit"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
const refreshCookieName = "gg_refresh"
|
||||
|
||||
type authHandler struct {
|
||||
logger *slog.Logger
|
||||
users *storage.UserRepo
|
||||
verifications *storage.EmailVerificationRepo
|
||||
resets *storage.PasswordResetRepo
|
||||
refreshes *storage.RefreshTokenRepo
|
||||
hasher *auth.PasswordHasher
|
||||
signer *auth.JWTSigner
|
||||
emails auth.EmailSender
|
||||
lockout *auth.LockoutTracker
|
||||
limiter *ratelimit.Limiter
|
||||
|
||||
publicBaseURL string
|
||||
emailVerificationTTL time.Duration
|
||||
passwordResetTTL time.Duration
|
||||
refreshTTL time.Duration
|
||||
cookieDomain string
|
||||
cookieSecure bool
|
||||
}
|
||||
|
||||
type authHandlerDeps struct {
|
||||
Logger *slog.Logger
|
||||
Users *storage.UserRepo
|
||||
Verifications *storage.EmailVerificationRepo
|
||||
Resets *storage.PasswordResetRepo
|
||||
Refreshes *storage.RefreshTokenRepo
|
||||
Hasher *auth.PasswordHasher
|
||||
Signer *auth.JWTSigner
|
||||
Emails auth.EmailSender
|
||||
Lockout *auth.LockoutTracker
|
||||
Limiter *ratelimit.Limiter
|
||||
|
||||
PublicBaseURL string
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
RefreshTTL time.Duration
|
||||
CookieDomain string
|
||||
CookieSecure bool
|
||||
}
|
||||
|
||||
func newAuthHandler(d authHandlerDeps) *authHandler {
|
||||
return &authHandler{
|
||||
logger: d.Logger,
|
||||
users: d.Users,
|
||||
verifications: d.Verifications,
|
||||
resets: d.Resets,
|
||||
refreshes: d.Refreshes,
|
||||
hasher: d.Hasher,
|
||||
signer: d.Signer,
|
||||
emails: d.Emails,
|
||||
lockout: d.Lockout,
|
||||
limiter: d.Limiter,
|
||||
publicBaseURL: strings.TrimRight(d.PublicBaseURL, "/"),
|
||||
emailVerificationTTL: d.EmailVerificationTTL,
|
||||
passwordResetTTL: d.PasswordResetTTL,
|
||||
refreshTTL: d.RefreshTTL,
|
||||
cookieDomain: d.CookieDomain,
|
||||
cookieSecure: d.CookieSecure,
|
||||
}
|
||||
}
|
||||
|
||||
// --- request/response types ---
|
||||
|
||||
type signupRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
AcceptTerms bool `json:"accept_terms"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type verifyEmailRequest struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type forgotPasswordRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type resetPasswordRequest struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
type authSuccess struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
User *domain.User `json:"user"`
|
||||
}
|
||||
|
||||
// --- handlers ---
|
||||
|
||||
// POST /auth/signup
|
||||
func (h *authHandler) signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req signupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "email is invalid")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
hash, err := h.hasher.Hash(req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("hash password", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.Create(r.Context(), storage.CreateUserParams{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
PasswordHash: hash,
|
||||
AcceptTerms: req.AcceptTerms,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEmailTaken) {
|
||||
// Don't leak which addresses are registered. Still return 201 and
|
||||
// trigger a "if-you-already-have-an-account" email asynchronously
|
||||
// (skipped for the stub). On real auth this should send a "you
|
||||
// tried to sign up again, here's a reset link" email.
|
||||
h.logger.Info("signup attempted with existing email", "email", req.Email)
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
|
||||
return
|
||||
}
|
||||
h.logger.Error("create user", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.sendVerificationEmail(r.Context(), u); err != nil {
|
||||
h.logger.Error("send verification email", "err", err, "user_id", u.ID)
|
||||
// Don't fail the signup — user can request a resend.
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, map[string]string{"status": "verification_sent"})
|
||||
}
|
||||
|
||||
// POST /auth/login
|
||||
func (h *authHandler) login(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Email == "" || req.Password == "" {
|
||||
writeError(w, http.StatusBadRequest, "email and password required")
|
||||
return
|
||||
}
|
||||
|
||||
// Per-(IP + email) sliding-window — 10 per 5 minutes per the plan.
|
||||
if !h.checkRate(w, r, "login", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
|
||||
10, 5*time.Minute) {
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.GetByEmail(r.Context(), req.Email)
|
||||
if err != nil || u.PasswordHash == "" {
|
||||
_, _ = h.lockout.RecordFailure(r.Context(), req.Email, nil)
|
||||
writeError(w, http.StatusUnauthorized, "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
// If the account is already locked, reject before doing a bcrypt compare.
|
||||
locked, _ := h.lockout.IsLocked(r.Context(), u.ID)
|
||||
if locked {
|
||||
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.hasher.Verify(u.PasswordHash, req.Password); err != nil {
|
||||
locked, _ := h.lockout.RecordFailure(r.Context(), req.Email, &u.ID)
|
||||
if locked {
|
||||
writeError(w, http.StatusForbidden, "account locked — reset your password to unlock")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid email or password")
|
||||
return
|
||||
}
|
||||
if !u.EmailVerified {
|
||||
writeError(w, http.StatusForbidden, "email not verified")
|
||||
return
|
||||
}
|
||||
|
||||
h.lockout.ClearOnSuccess(r.Context(), req.Email)
|
||||
if err := h.issueSession(w, r, u); err != nil {
|
||||
h.logger.Error("issue session", "err", err, "user_id", u.ID)
|
||||
writeError(w, http.StatusInternalServerError, "failed to start session")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// checkRate consults the limiter (when one is configured) and writes a 429
|
||||
// response if the budget is exhausted. Returns false if the caller should
|
||||
// stop handling the request.
|
||||
func (h *authHandler) checkRate(w http.ResponseWriter, r *http.Request, name, key string, limit int, window time.Duration) bool {
|
||||
if h.limiter == nil || key == "" {
|
||||
return true
|
||||
}
|
||||
res, err := h.limiter.Allow(r.Context(), name, key, limit, window)
|
||||
if err != nil {
|
||||
h.logger.Warn("ratelimit error (failing open)", "rule", name, "err", err)
|
||||
return true
|
||||
}
|
||||
if !res.Allowed {
|
||||
retry := int(res.RetryAfter.Round(time.Second).Seconds())
|
||||
if retry < 1 {
|
||||
retry = 1
|
||||
}
|
||||
w.Header().Set("Retry-After", strconv.Itoa(retry))
|
||||
writeJSON(w, http.StatusTooManyRequests, map[string]any{
|
||||
"error": "rate limit exceeded",
|
||||
"retry_after": retry,
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// POST /auth/refresh
|
||||
func (h *authHandler) refresh(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(refreshCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing refresh token")
|
||||
return
|
||||
}
|
||||
oldHash := auth.HashOpaque(cookie.Value)
|
||||
|
||||
existing, err := h.refreshes.Get(r.Context(), oldHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrAuthTokenNotFound) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "invalid refresh token")
|
||||
return
|
||||
}
|
||||
h.logger.Error("lookup refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
if existing.RevokedAt != nil {
|
||||
// Replay of a revoked token. Revoke the family.
|
||||
_ = h.refreshes.RevokeAllForUser(r.Context(), existing.UserID)
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token reused")
|
||||
return
|
||||
}
|
||||
if time.Now().After(existing.ExpiresAt) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token expired")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.users.GetByID(r.Context(), existing.UserID)
|
||||
if err != nil {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
newRaw, newHash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
h.logger.Error("mint refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
exp := time.Now().Add(h.refreshTTL)
|
||||
if err := h.refreshes.Rotate(r.Context(), oldHash, storage.CreateRefreshTokenParams{
|
||||
Hash: newHash,
|
||||
UserID: u.ID,
|
||||
ExpiresAt: exp,
|
||||
UserAgent: r.UserAgent(),
|
||||
IPAddress: clientIP(r),
|
||||
}); err != nil {
|
||||
if errors.Is(err, domain.ErrRefreshTokenRevoked) {
|
||||
h.clearRefreshCookie(w)
|
||||
writeError(w, http.StatusUnauthorized, "refresh token reused")
|
||||
return
|
||||
}
|
||||
h.logger.Error("rotate refresh", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
|
||||
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
|
||||
if err != nil {
|
||||
h.logger.Error("sign access", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "refresh failed")
|
||||
return
|
||||
}
|
||||
h.setRefreshCookie(w, newRaw, exp)
|
||||
writeJSON(w, http.StatusOK, authSuccess{
|
||||
AccessToken: access,
|
||||
ExpiresAt: accessExp,
|
||||
User: u,
|
||||
})
|
||||
}
|
||||
|
||||
// POST /auth/logout
|
||||
func (h *authHandler) logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(refreshCookieName)
|
||||
if err == nil && cookie.Value != "" {
|
||||
_ = h.refreshes.Revoke(r.Context(), auth.HashOpaque(cookie.Value))
|
||||
}
|
||||
h.clearRefreshCookie(w)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /auth/verify-email
|
||||
func (h *authHandler) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
||||
var req verifyEmailRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token required")
|
||||
return
|
||||
}
|
||||
uid, err := h.verifications.Consume(r.Context(), auth.HashOpaque(req.Token))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrAuthTokenNotFound):
|
||||
writeError(w, http.StatusBadRequest, "invalid token")
|
||||
case errors.Is(err, domain.ErrAuthTokenConsumed):
|
||||
writeError(w, http.StatusBadRequest, "token already used")
|
||||
case errors.Is(err, domain.ErrAuthTokenExpired):
|
||||
writeError(w, http.StatusBadRequest, "token expired")
|
||||
default:
|
||||
h.logger.Error("consume verification", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "verification failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.users.MarkEmailVerified(r.Context(), uid); err != nil {
|
||||
h.logger.Error("mark verified", "err", err, "user_id", uid)
|
||||
writeError(w, http.StatusInternalServerError, "verification failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "verified"})
|
||||
}
|
||||
|
||||
// POST /auth/forgot-password
|
||||
func (h *authHandler) forgotPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req forgotPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if !h.checkRate(w, r, "forgot_password", clientIP(r)+"|"+strings.ToLower(strings.TrimSpace(req.Email)),
|
||||
3, time.Hour) {
|
||||
return
|
||||
}
|
||||
// Always respond 202 to avoid leaking whether the email exists.
|
||||
defer func() { writeJSON(w, http.StatusAccepted, map[string]string{"status": "if_known_email_sent"}) }()
|
||||
|
||||
u, err := h.users.GetByEmail(r.Context(), req.Email)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
h.logger.Error("mint reset", "err", err)
|
||||
return
|
||||
}
|
||||
exp := time.Now().Add(h.passwordResetTTL)
|
||||
if err := h.resets.Create(r.Context(), u.ID, hash, exp); err != nil {
|
||||
h.logger.Error("persist reset", "err", err)
|
||||
return
|
||||
}
|
||||
link := h.publicBaseURL + "/reset-password/" + url.PathEscape(raw)
|
||||
if err := h.emails.SendPasswordReset(r.Context(), u.Email, u.Name, link); err != nil {
|
||||
h.logger.Error("send reset email", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// POST /auth/reset-password
|
||||
func (h *authHandler) resetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req resetPasswordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "token and new_password required")
|
||||
return
|
||||
}
|
||||
newHash, err := h.hasher.Hash(req.NewPassword)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrPasswordTooShort) || errors.Is(err, auth.ErrPasswordTooLong) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("hash password", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
return
|
||||
}
|
||||
uid, err := h.resets.Consume(r.Context(), auth.HashOpaque(req.Token))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrAuthTokenNotFound):
|
||||
writeError(w, http.StatusBadRequest, "invalid token")
|
||||
case errors.Is(err, domain.ErrAuthTokenConsumed):
|
||||
writeError(w, http.StatusBadRequest, "token already used")
|
||||
case errors.Is(err, domain.ErrAuthTokenExpired):
|
||||
writeError(w, http.StatusBadRequest, "token expired")
|
||||
default:
|
||||
h.logger.Error("consume reset", "err", err)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.users.UpdatePasswordHash(r.Context(), uid, newHash); err != nil {
|
||||
h.logger.Error("update password", "err", err, "user_id", uid)
|
||||
writeError(w, http.StatusInternalServerError, "reset failed")
|
||||
return
|
||||
}
|
||||
// Invalidate all existing sessions.
|
||||
_ = h.refreshes.RevokeAllForUser(r.Context(), uid)
|
||||
// Resetting the password is the canonical "unlock" path for the
|
||||
// account lockout that triggers after repeated bad-credential attempts.
|
||||
if u, err := h.users.GetByID(r.Context(), uid); err == nil {
|
||||
_ = h.lockout.ClearForUser(r.Context(), uid, u.Email)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "password_reset"})
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func (h *authHandler) sendVerificationEmail(ctx context.Context, u *domain.User) error {
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.verifications.Create(ctx, u.ID, hash, time.Now().Add(h.emailVerificationTTL)); err != nil {
|
||||
return err
|
||||
}
|
||||
link := h.publicBaseURL + "/verify-email?token=" + url.QueryEscape(raw)
|
||||
return h.emails.SendVerification(ctx, u.Email, u.Name, link)
|
||||
}
|
||||
|
||||
func (h *authHandler) issueSession(w http.ResponseWriter, r *http.Request, u *domain.User) error {
|
||||
access, accessExp, err := h.signer.Issue(u.ID, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw, hash, err := auth.NewOpaqueToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
refreshExp := time.Now().Add(h.refreshTTL)
|
||||
if err := h.refreshes.Create(r.Context(), storage.CreateRefreshTokenParams{
|
||||
Hash: hash,
|
||||
UserID: u.ID,
|
||||
ExpiresAt: refreshExp,
|
||||
UserAgent: r.UserAgent(),
|
||||
IPAddress: clientIP(r),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
h.setRefreshCookie(w, raw, refreshExp)
|
||||
writeJSON(w, http.StatusOK, authSuccess{
|
||||
AccessToken: access,
|
||||
ExpiresAt: accessExp,
|
||||
User: u,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *authHandler) setRefreshCookie(w http.ResponseWriter, value string, expires time.Time) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: refreshCookieName,
|
||||
Value: value,
|
||||
Path: "/auth",
|
||||
Domain: h.cookieDomain,
|
||||
Expires: expires,
|
||||
MaxAge: int(time.Until(expires).Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: h.cookieSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *authHandler) clearRefreshCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: refreshCookieName,
|
||||
Value: "",
|
||||
Path: "/auth",
|
||||
Domain: h.cookieDomain,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: h.cookieSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// --- requireAuth middleware ---
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const userIDCtxKey ctxKey = iota
|
||||
|
||||
func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) {
|
||||
v, ok := ctx.Value(userIDCtxKey).(uuid.UUID)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func requireAuth(signer *auth.JWTSigner) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(h, "Bearer ") {
|
||||
writeError(w, http.StatusUnauthorized, "missing bearer token")
|
||||
return
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(h, "Bearer "))
|
||||
claims, err := signer.Parse(raw)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrExpiredJWT) {
|
||||
writeError(w, http.StatusUnauthorized, "token expired")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid token")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userIDCtxKey, claims.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// hostFromContext returns the authed user's id, or writes 401 and returns
|
||||
// false. Used by host-facing handlers as the first line in the function.
|
||||
func hostFromContext(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthenticated")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
return uid, true
|
||||
}
|
||||
|
||||
// requireEventOwner fetches the event and confirms the authed user owns it.
|
||||
// On mismatch (or missing event) it returns 404 — never 403 — so a cross-
|
||||
// tenant probe cannot tell the difference between "event doesn't exist" and
|
||||
// "exists but belongs to someone else".
|
||||
func requireEventOwner(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
events *storage.EventRepo,
|
||||
eventID, hostID uuid.UUID,
|
||||
) (*domain.Event, bool) {
|
||||
ev, err := events.GetForHost(r.Context(), eventID, hostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return nil, false
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
return nil, false
|
||||
}
|
||||
return ev, true
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type billingHandler struct {
|
||||
logger *slog.Logger
|
||||
stripe *billing.Client
|
||||
users *storage.UserRepo
|
||||
subscriptions *storage.SubscriptionRepo
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
type checkoutSessionRequest struct {
|
||||
Tier string `json:"tier"`
|
||||
}
|
||||
|
||||
type checkoutSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// POST /billing/checkout-session — returns the Stripe Checkout URL the
|
||||
// frontend redirects the host to. Mints a Stripe customer on first use
|
||||
// and persists it so repeat calls reuse the same customer.
|
||||
func (h *billingHandler) checkoutSession(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.stripeEnabled(w) {
|
||||
return
|
||||
}
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req checkoutSessionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
tier := billing.Tier(strings.ToLower(req.Tier))
|
||||
if tier != billing.TierPro && tier != billing.TierBusiness {
|
||||
writeError(w, http.StatusBadRequest, "tier must be 'pro' or 'business'")
|
||||
return
|
||||
}
|
||||
price, err := h.stripe.PriceFor(tier)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "this tier is not configured yet — contact support")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.users.GetByID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
|
||||
existingCustomerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load billing record")
|
||||
return
|
||||
}
|
||||
customerID, err := h.stripe.CreateOrGetCustomer(hostID.String(), user.Email, user.Name, existingCustomerID)
|
||||
if err != nil {
|
||||
h.logger.Error("stripe customer", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe customer error")
|
||||
return
|
||||
}
|
||||
if existingCustomerID == "" {
|
||||
// First time — write a placeholder row so the customer id sticks.
|
||||
if _, err := h.subscriptions.Upsert(r.Context(), storage.UpsertParams{
|
||||
UserID: hostID,
|
||||
StripeCustomerID: customerID,
|
||||
}); err != nil {
|
||||
h.logger.Error("upsert sub placeholder", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
base := strings.TrimRight(h.publicBaseURL, "/")
|
||||
url, err := h.stripe.CreateCheckoutSession(billing.CheckoutSessionParams{
|
||||
CustomerID: customerID,
|
||||
PriceID: price,
|
||||
SuccessURL: base + "/dashboard?billing=success",
|
||||
CancelURL: base + "/dashboard?billing=cancelled",
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("stripe checkout session", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe checkout error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, checkoutSessionResponse{URL: url})
|
||||
}
|
||||
|
||||
type portalSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// POST /billing/portal — returns the customer portal URL so the user
|
||||
// can manage their payment method, view invoices, or cancel. 404 when
|
||||
// the user has no Stripe customer yet (they're still on free).
|
||||
func (h *billingHandler) portalSession(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.stripeEnabled(w) {
|
||||
return
|
||||
}
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
customerID, err := h.subscriptions.FindCustomerID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load billing record")
|
||||
return
|
||||
}
|
||||
if customerID == "" {
|
||||
writeError(w, http.StatusNotFound, "no billing account yet — subscribe first")
|
||||
return
|
||||
}
|
||||
url, err := h.stripe.CreatePortalSession(customerID, strings.TrimRight(h.publicBaseURL, "/")+"/dashboard")
|
||||
if err != nil {
|
||||
h.logger.Error("stripe portal", "err", err)
|
||||
writeError(w, http.StatusBadGateway, "stripe portal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, portalSessionResponse{URL: url})
|
||||
}
|
||||
|
||||
type subscriptionStatusResponse struct {
|
||||
Tier string `json:"tier"`
|
||||
Status string `json:"status"`
|
||||
CurrentPeriodEnd string `json:"current_period_end,omitempty"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
Limits struct {
|
||||
EventsPerMonth int `json:"events_per_month"`
|
||||
GuestsPerEvent int `json:"guests_per_event"`
|
||||
} `json:"limits"`
|
||||
Usage struct {
|
||||
EventsThisMonth int `json:"events_this_month"`
|
||||
} `json:"usage"`
|
||||
PortalAvailable bool `json:"portal_available"`
|
||||
}
|
||||
|
||||
// GET /billing/status — returns the host's current tier + limits +
|
||||
// usage. The frontend uses this to render the billing page and the
|
||||
// 402-modal copy ("you used X of Y events this month").
|
||||
func (h *billingHandler) status(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tier := billing.TierFree
|
||||
status := "active"
|
||||
var periodEnd string
|
||||
cancelAtPeriodEnd := false
|
||||
portalAvailable := false
|
||||
|
||||
if h.subscriptions != nil {
|
||||
sub, err := h.subscriptions.GetActiveByUser(r.Context(), hostID)
|
||||
switch {
|
||||
case err == nil:
|
||||
tier = billing.Tier(sub.Tier)
|
||||
status = sub.Status
|
||||
cancelAtPeriodEnd = sub.CancelAtPeriodEnd
|
||||
if sub.CurrentPeriodEnd != nil {
|
||||
periodEnd = sub.CurrentPeriodEnd.Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
portalAvailable = sub.StripeCustomerID != ""
|
||||
case errors.Is(err, storage.ErrSubscriptionNotFound):
|
||||
// Free tier — leave defaults.
|
||||
default:
|
||||
writeError(w, http.StatusInternalServerError, "failed to load subscription")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limits := billing.LimitsFor(tier)
|
||||
events, _ := h.subscriptions.CountEventsInCurrentMonth(r.Context(), hostID)
|
||||
|
||||
resp := subscriptionStatusResponse{
|
||||
Tier: string(tier),
|
||||
Status: status,
|
||||
CurrentPeriodEnd: periodEnd,
|
||||
CancelAtPeriodEnd: cancelAtPeriodEnd,
|
||||
PortalAvailable: portalAvailable,
|
||||
}
|
||||
resp.Limits.EventsPerMonth = limits.EventsPerMonth
|
||||
resp.Limits.GuestsPerEvent = limits.GuestsPerEvent
|
||||
resp.Usage.EventsThisMonth = events
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// stripeEnabled returns true if the billing client is configured, else
|
||||
// writes 503 and returns false. The /billing/status path skips this so
|
||||
// the frontend can render a "free tier" page in dev environments.
|
||||
func (h *billingHandler) stripeEnabled(w http.ResponseWriter) bool {
|
||||
if h.stripe == nil || !h.stripe.Enabled() {
|
||||
writeError(w, http.StatusServiceUnavailable, "billing is not configured on this instance")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// tierEnforcer wraps the SubscriptionRepo with policy decisions. Lives
|
||||
// here (not in storage) because the policy is HTTP-shaped: we map the
|
||||
// outcome to 402 + an upgrade URL.
|
||||
type tierEnforcer struct {
|
||||
subs *storage.SubscriptionRepo
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
func newTierEnforcer(subs *storage.SubscriptionRepo, publicBaseURL string) *tierEnforcer {
|
||||
return &tierEnforcer{subs: subs, publicBaseURL: publicBaseURL}
|
||||
}
|
||||
|
||||
// currentTier returns the host's effective tier. ErrSubscriptionNotFound
|
||||
// means "no granting subscription on file" → free. Other DB errors
|
||||
// bubble up.
|
||||
func (e *tierEnforcer) currentTier(ctx context.Context, hostID uuid.UUID) (billing.Tier, error) {
|
||||
if e == nil || e.subs == nil {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
sub, err := e.subs.GetActiveByUser(ctx, hostID)
|
||||
if err != nil {
|
||||
if errors.Is(err, storage.ErrSubscriptionNotFound) {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if !billing.StatusGrantsAccess(sub.Status) {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
tier := billing.Tier(sub.Tier)
|
||||
if !tier.Valid() {
|
||||
return billing.TierFree, nil
|
||||
}
|
||||
return tier, nil
|
||||
}
|
||||
|
||||
// allowEventCreate verifies the host's monthly event budget. Returns
|
||||
// true when the request may proceed. On denial it writes a 402 with the
|
||||
// upgrade hint and returns false.
|
||||
func (e *tierEnforcer) allowEventCreate(w http.ResponseWriter, r *http.Request, hostID uuid.UUID) bool {
|
||||
if e == nil || e.subs == nil {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).EventsPerMonth
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountEventsInCurrentMonth(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count events")
|
||||
return false
|
||||
}
|
||||
if used >= limit {
|
||||
e.writePaymentRequired(w, "events_per_month", tier, used, limit,
|
||||
"You've reached your monthly event limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// allowGuestCreate verifies the per-event guest cap. Same shape as
|
||||
// allowEventCreate.
|
||||
func (e *tierEnforcer) allowGuestCreate(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID) bool {
|
||||
if e == nil || e.subs == nil {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).GuestsPerEvent
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count guests")
|
||||
return false
|
||||
}
|
||||
if used >= limit {
|
||||
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
|
||||
"This event has reached the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// allowGuestImport is the CSV-import variant: check the cap against
|
||||
// existing + incoming row count up-front, before we start the
|
||||
// transaction. Dedup may shrink the actual insert count later — that's
|
||||
// OK, we just stay on the safe side.
|
||||
func (e *tierEnforcer) allowGuestImport(w http.ResponseWriter, r *http.Request, hostID, eventID uuid.UUID, incoming int) bool {
|
||||
if e == nil || e.subs == nil || incoming == 0 {
|
||||
return true
|
||||
}
|
||||
tier, err := e.currentTier(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to check plan")
|
||||
return false
|
||||
}
|
||||
limit := billing.LimitsFor(tier).GuestsPerEvent
|
||||
if limit < 0 {
|
||||
return true
|
||||
}
|
||||
used, err := e.subs.CountGuestsByEvent(r.Context(), eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to count guests")
|
||||
return false
|
||||
}
|
||||
if used+incoming > limit {
|
||||
e.writePaymentRequired(w, "guests_per_event", tier, used, limit,
|
||||
"This import would exceed the guest limit on the "+strings.ToUpper(string(tier))+" plan.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type paymentRequiredBody struct {
|
||||
Error string `json:"error"`
|
||||
Reason string `json:"reason"`
|
||||
Tier string `json:"tier"`
|
||||
Used int `json:"used"`
|
||||
Limit int `json:"limit"`
|
||||
UpgradeURL string `json:"upgrade_url"`
|
||||
}
|
||||
|
||||
func (e *tierEnforcer) writePaymentRequired(w http.ResponseWriter, reason string, tier billing.Tier, used, limit int, msg string) {
|
||||
body := paymentRequiredBody{
|
||||
Error: msg,
|
||||
Reason: reason,
|
||||
Tier: string(tier),
|
||||
Used: used,
|
||||
Limit: limit,
|
||||
UpgradeURL: strings.TrimRight(e.publicBaseURL, "/") + "/dashboard/billing",
|
||||
}
|
||||
writeJSON(w, http.StatusPaymentRequired, body)
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/csvimport"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
// 1 MB cap on uploads. With ~200 bytes per row that's ~5,000 guests —
|
||||
// matches the row cap in csvimport.DefaultMaxRows.
|
||||
csvMaxBytes = 1 << 20
|
||||
)
|
||||
|
||||
type csvImportHandler struct {
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type importResponse struct {
|
||||
Added int `json:"added"`
|
||||
Skipped int `json:"skipped"`
|
||||
SkippedEmails []string `json:"skipped_emails,omitempty"`
|
||||
Errors []csvimport.RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"`
|
||||
}
|
||||
|
||||
type previewResponse struct {
|
||||
Rows []csvimport.Row `json:"rows"`
|
||||
Errors []csvimport.RowError `json:"errors,omitempty"`
|
||||
TotalCount int `json:"total_count"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/import/preview — parse + validate but don't write.
|
||||
// Used by the frontend to show a "is this what you meant?" table before commit.
|
||||
func (h *csvImportHandler) preview(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
body, ok := readCSVUpload(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
res, err := csvimport.Parse(body, csvimport.Options{})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, previewResponse{
|
||||
Rows: res.Rows,
|
||||
Errors: res.Errors,
|
||||
TotalCount: res.TotalCount,
|
||||
})
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/import — parse, validate, and commit valid rows
|
||||
// in a single transaction. Rows with row-level errors are reported back
|
||||
// but don't prevent the rest from importing.
|
||||
func (h *csvImportHandler) commit(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
body, ok := readCSVUpload(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
parsed, err := csvimport.Parse(body, csvimport.Options{})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Plan enforcement: prevent an import that, even if perfectly
|
||||
// dedup-free, would exceed the per-event guest cap. We check upfront
|
||||
// (current count + parsed-row count) and reject with 402 if it'd
|
||||
// overflow. False positives on dedup-heavy CSVs are acceptable —
|
||||
// host can dedupe and re-upload.
|
||||
if !h.enforcer.allowGuestImport(w, r, hostID, eventID, len(parsed.Rows)) {
|
||||
return
|
||||
}
|
||||
|
||||
rows := make([]storage.BulkImportRow, 0, len(parsed.Rows))
|
||||
for _, r := range parsed.Rows {
|
||||
rows = append(rows, storage.BulkImportRow{
|
||||
Name: r.Name,
|
||||
Email: r.Email,
|
||||
Phone: r.Phone,
|
||||
PlusOnes: r.PlusOnes,
|
||||
})
|
||||
}
|
||||
|
||||
res, err := h.guests.BulkImportGuests(r.Context(), eventID, rows)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "import failed")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, importResponse{
|
||||
Added: res.Added,
|
||||
Skipped: res.Skipped,
|
||||
SkippedEmails: res.SkippedEmails,
|
||||
Errors: parsed.Errors,
|
||||
TotalCount: parsed.TotalCount,
|
||||
})
|
||||
}
|
||||
|
||||
// GET /events/{id}/guests/import/template — download a sample CSV. Auth is
|
||||
// applied at the route level; ownership is verified so an attacker can't
|
||||
// probe the existence of an event by hitting this endpoint.
|
||||
func (h *csvImportHandler) template(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-import-template.csv"`)
|
||||
_, _ = w.Write([]byte(csvimport.TemplateCSV))
|
||||
}
|
||||
|
||||
// readCSVUpload returns the multipart file body (capped at csvMaxBytes) or
|
||||
// writes an error and returns (nil, false). Accepted shapes:
|
||||
//
|
||||
// - multipart/form-data with a "file" field (the drag-drop UI uses this)
|
||||
// - any other Content-Type — the raw body is treated as the CSV (curl-friendly)
|
||||
func readCSVUpload(w http.ResponseWriter, r *http.Request) (io.ReadCloser, bool) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, csvMaxBytes)
|
||||
if err := r.ParseMultipartForm(csvMaxBytes); err == nil && r.MultipartForm != nil {
|
||||
files := r.MultipartForm.File["file"]
|
||||
if len(files) == 0 {
|
||||
writeError(w, http.StatusBadRequest, `form field "file" is required`)
|
||||
return nil, false
|
||||
}
|
||||
f, err := files[0].Open()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "cannot read uploaded file")
|
||||
return nil, false
|
||||
}
|
||||
return f, true
|
||||
} else if err != nil && !errors.Is(err, http.ErrNotMultipart) {
|
||||
writeError(w, http.StatusBadRequest, "invalid multipart body")
|
||||
return nil, false
|
||||
}
|
||||
// Fall through: raw body as CSV.
|
||||
return r.Body, true
|
||||
}
|
||||
+39
-28
@@ -15,11 +15,11 @@ import (
|
||||
)
|
||||
|
||||
type eventHandler struct {
|
||||
repo *storage.EventRepo
|
||||
repo *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type createEventRequest struct {
|
||||
HostID string `json:"host_id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
EventDate time.Time `json:"event_date"`
|
||||
@@ -32,6 +32,14 @@ type createEventRequest struct {
|
||||
var slugRe = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
|
||||
|
||||
func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !h.enforcer.allowEventCreate(w, r, hostID) {
|
||||
return
|
||||
}
|
||||
|
||||
var req createEventRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
@@ -51,12 +59,6 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hostID, err := uuid.Parse(req.HostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
|
||||
return
|
||||
}
|
||||
|
||||
status := domain.EventStatus(req.Status)
|
||||
if status == "" {
|
||||
status = domain.EventStatusDraft
|
||||
@@ -89,37 +91,31 @@ func (h *eventHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *eventHandler) get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ev, err := h.repo.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
ev, ok := requireEventOwner(w, r, h.repo, id, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ev)
|
||||
}
|
||||
|
||||
func (h *eventHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
limit := atoiOr(q.Get("limit"), 50)
|
||||
offset := atoiOr(q.Get("offset"), 0)
|
||||
|
||||
var hostID uuid.UUID
|
||||
if v := q.Get("host_id"); v != "" {
|
||||
parsed, err := uuid.Parse(v)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "host_id must be a valid uuid")
|
||||
return
|
||||
}
|
||||
hostID = parsed
|
||||
}
|
||||
|
||||
events, err := h.repo.List(r.Context(), hostID, limit, offset)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to list events")
|
||||
@@ -146,6 +142,10 @@ type updateEventRequest struct {
|
||||
}
|
||||
|
||||
func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
@@ -180,7 +180,7 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
params.Status = &s
|
||||
}
|
||||
|
||||
ev, err := h.repo.Update(r.Context(), id, params)
|
||||
ev, err := h.repo.Update(r.Context(), id, hostID, params)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrEventNotFound):
|
||||
@@ -196,11 +196,15 @@ func (h *eventHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
id, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.repo.Delete(r.Context(), id); err != nil {
|
||||
if err := h.repo.Delete(r.Context(), id, hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
@@ -212,7 +216,14 @@ func (h *eventHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func parseIDParam(w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, bool) {
|
||||
raw := r.PathValue(name)
|
||||
return parseRawUUID(w, name, r.PathValue(name))
|
||||
}
|
||||
|
||||
func parseRawUUID(w http.ResponseWriter, name, raw string) (uuid.UUID, bool) {
|
||||
if raw == "" {
|
||||
writeError(w, http.StatusBadRequest, name+" is required")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, name+" must be a valid uuid")
|
||||
|
||||
+107
-9
@@ -10,8 +10,9 @@ import (
|
||||
)
|
||||
|
||||
type guestHandler struct {
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
guests *storage.GuestRepo
|
||||
events *storage.EventRepo
|
||||
enforcer *tierEnforcer
|
||||
}
|
||||
|
||||
type createGuestRequest struct {
|
||||
@@ -24,16 +25,18 @@ type createGuestRequest struct {
|
||||
}
|
||||
|
||||
func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := h.events.Get(r.Context(), eventID); err != nil {
|
||||
if errors.Is(err, domain.ErrEventNotFound) {
|
||||
writeError(w, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load event")
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
if !h.enforcer.allowGuestCreate(w, r, hostID, eventID) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -67,11 +70,106 @@ func (h *guestHandler) create(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusCreated, g)
|
||||
}
|
||||
|
||||
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
type updateGuestRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Email *string `json:"email"`
|
||||
Phone *string `json:"phone"`
|
||||
PlusOnes *int `json:"plus_ones"`
|
||||
}
|
||||
|
||||
// PATCH /events/{id}/guests/{guest_id} — patch a guest's contact info.
|
||||
// Fields omitted from the body are left untouched. Empty strings for
|
||||
// email/phone clear those columns to NULL.
|
||||
func (h *guestHandler) update(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req updateGuestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if req.Name != nil && *req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name cannot be empty")
|
||||
return
|
||||
}
|
||||
if req.PlusOnes != nil && *req.PlusOnes < 0 {
|
||||
writeError(w, http.StatusBadRequest, "plus_ones must be >= 0")
|
||||
return
|
||||
}
|
||||
|
||||
g, err := h.guests.Update(r.Context(), eventID, guestID, storage.UpdateGuestParams{
|
||||
Name: req.Name,
|
||||
Email: req.Email,
|
||||
Phone: req.Phone,
|
||||
PlusOnes: req.PlusOnes,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to update guest")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, g)
|
||||
}
|
||||
|
||||
// DELETE /events/{id}/guests/{guest_id} — remove a guest from an event.
|
||||
// Cascade-deletes their token, rsvp, access logs, notifications.
|
||||
func (h *guestHandler) delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
if err := h.guests.Delete(r.Context(), eventID, guestID); err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete guest")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *guestHandler) list(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
limit := atoiOr(q.Get("limit"), 100)
|
||||
offset := atoiOr(q.Get("offset"), 0)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type meHandler struct {
|
||||
users *storage.UserRepo
|
||||
}
|
||||
|
||||
// GET /me — returns the authenticated user. Used by the frontend to bootstrap
|
||||
// after a page reload (with a fresh access token from /auth/refresh).
|
||||
func (h *meHandler) get(w http.ResponseWriter, r *http.Request) {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthenticated")
|
||||
return
|
||||
}
|
||||
u, err := h.users.GetByID(r.Context(), uid)
|
||||
if err != nil {
|
||||
if err == domain.ErrUserNotFound {
|
||||
writeError(w, http.StatusUnauthorized, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, u)
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// privacyHandler holds the GDPR-style "your data, your choice" endpoints:
|
||||
// data export, account deletion, and terms-acceptance recording.
|
||||
type privacyHandler struct {
|
||||
logger *slog.Logger
|
||||
users *storage.UserRepo
|
||||
events *storage.EventRepo
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
rsvps *storage.RSVPRepo
|
||||
access *storage.AccessLogRepo
|
||||
notifs *storage.DB // raw pool access for the export queries
|
||||
refresh *storage.RefreshTokenRepo
|
||||
}
|
||||
|
||||
// DataExport is the shape of the JSON the host downloads from
|
||||
// GET /me/data-export. We don't paginate or stream — for the scale
|
||||
// GuestGuard hosts have, a single response is reasonable. If a host
|
||||
// ever has 100k+ access logs we'll switch to async + email-a-link.
|
||||
type DataExport struct {
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Format string `json:"format"`
|
||||
User *domain.User `json:"user"`
|
||||
Events []*domain.Event `json:"events"`
|
||||
Guests []*domain.Guest `json:"guests"`
|
||||
Tokens []exportedToken `json:"tokens"`
|
||||
RSVPs []exportedRSVP `json:"rsvps"`
|
||||
AccessLogs []exportedAccess `json:"access_logs"`
|
||||
Notifs []exportedNotif `json:"notifications"`
|
||||
}
|
||||
|
||||
type exportedToken struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type exportedRSVP struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
Response string `json:"response"`
|
||||
PlusOnes int `json:"plus_ones"`
|
||||
SubmittedAt time.Time `json:"submitted_at"`
|
||||
}
|
||||
type exportedAccess struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
RiskScore *int `json:"risk_score,omitempty"`
|
||||
Flagged bool `json:"flagged"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type exportedNotif struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
GuestID uuid.UUID `json:"guest_id"`
|
||||
Channel string `json:"channel"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// GET /me/data-export — returns every record the system holds about the
|
||||
// authenticated user. The Content-Disposition header makes browsers
|
||||
// offer a download rather than rendering inline.
|
||||
func (h *privacyHandler) dataExport(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.users.GetByID(r.Context(), hostID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
|
||||
export := DataExport{
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Format: "guestguard.v1",
|
||||
User: user,
|
||||
}
|
||||
|
||||
// Events the user hosts.
|
||||
events, err := h.events.List(r.Context(), hostID, 1000, 0)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load events")
|
||||
return
|
||||
}
|
||||
export.Events = events
|
||||
|
||||
// For each event, pull guests + tokens + rsvps + access_logs + notifications.
|
||||
for _, ev := range events {
|
||||
guests, err := h.guests.ListByEvent(r.Context(), ev.ID, 5000, 0)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guests")
|
||||
return
|
||||
}
|
||||
export.Guests = append(export.Guests, guests...)
|
||||
|
||||
for _, g := range guests {
|
||||
// Token (at most one per guest, but query as a list for symmetry).
|
||||
if err := h.appendTokens(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: tokens", "err", err)
|
||||
}
|
||||
if err := h.appendRSVPs(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: rsvps", "err", err)
|
||||
}
|
||||
if err := h.appendAccess(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: access", "err", err)
|
||||
}
|
||||
if err := h.appendNotifs(r.Context(), g.ID, &export); err != nil {
|
||||
h.logger.Warn("export: notifs", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="guestguard-data-export.json"`)
|
||||
writeJSON(w, http.StatusOK, export)
|
||||
}
|
||||
|
||||
func (h *privacyHandler) appendTokens(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, expires_at, status, created_at
|
||||
FROM tokens WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var t exportedToken
|
||||
if err := rows.Scan(&t.ID, &t.GuestID, &t.ExpiresAt, &t.Status, &t.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Tokens = append(out.Tokens, t)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendRSVPs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, response::text, plus_ones, submitted_at
|
||||
FROM rsvps WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var r exportedRSVP
|
||||
if err := rows.Scan(&r.ID, &r.GuestID, &r.Response, &r.PlusOnes, &r.SubmittedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.RSVPs = append(out.RSVPs, r)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendAccess(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, risk_score, flagged, created_at
|
||||
FROM access_logs WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var a exportedAccess
|
||||
var rs *int
|
||||
if err := rows.Scan(&a.ID, &a.GuestID, &rs, &a.Flagged, &a.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
a.RiskScore = rs
|
||||
out.AccessLogs = append(out.AccessLogs, a)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
func (h *privacyHandler) appendNotifs(ctx context.Context, guestID uuid.UUID, out *DataExport) error {
|
||||
rows, err := h.notifs.Pool.Query(ctx, `
|
||||
SELECT id, guest_id, channel::text, type::text, status::text, created_at
|
||||
FROM notifications WHERE guest_id = $1
|
||||
`, guestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var n exportedNotif
|
||||
if err := rows.Scan(&n.ID, &n.GuestID, &n.Channel, &n.Type, &n.Status, &n.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Notifs = append(out.Notifs, n)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// DELETE /me — soft-deletes the host's account. All sessions are
|
||||
// revoked immediately. A hard delete happens via a separate cron 30
|
||||
// days later (TBD ops work). The user is logged out from all devices
|
||||
// as a side effect of revoking the refresh tokens.
|
||||
func (h *privacyHandler) deleteMe(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.users.SoftDelete(r.Context(), hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
writeError(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete account")
|
||||
return
|
||||
}
|
||||
// Best-effort: revoke refresh tokens so other sessions log out too.
|
||||
// Failure here is logged but doesn't roll back the soft-delete — the
|
||||
// access tokens (JWT) will still expire on their own ~15 minute TTL.
|
||||
if err := h.refresh.RevokeAllForUser(r.Context(), hostID); err != nil {
|
||||
h.logger.Warn("delete-me: revoke refresh tokens", "err", err, "user_id", hostID)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /me/accept-terms — records that the authenticated user accepts
|
||||
// the current ToS + privacy policy. Idempotent. Used by both the
|
||||
// onboarding gate (existing accounts created before T&C were enforced)
|
||||
// and any future "we updated our terms" re-acceptance flow.
|
||||
func (h *privacyHandler) acceptTerms(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := h.users.AcceptTerms(r.Context(), hostID); err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
writeError(w, http.StatusNotFound, "user not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to record acceptance")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "accepted"})
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ipKey is the rate-limit key for endpoints scoped by source IP only
|
||||
// (e.g. POST /auth/signup). XFF/X-Real-IP are honoured because in the
|
||||
// homelab the API sits behind Traefik.
|
||||
func ipKey(r *http.Request) string {
|
||||
return clientIP(r)
|
||||
}
|
||||
|
||||
// pathKey returns a path-parameter as the rate-limit key — used for the
|
||||
// token-scoped endpoints so an attacker brute-forcing a single token is
|
||||
// limited regardless of the IPs they rotate through.
|
||||
func pathKey(name string) KeyFunc {
|
||||
return func(r *http.Request) string {
|
||||
return r.PathValue(name)
|
||||
}
|
||||
}
|
||||
|
||||
// userIDKey extracts the authenticated user id from the request context.
|
||||
// Returns "" when the route isn't behind requireAuth, in which case the
|
||||
// middleware bypasses (fail-open) — the route's own auth layer handles
|
||||
// rejection.
|
||||
func userIDKey(r *http.Request) string {
|
||||
uid, ok := UserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return uid.String()
|
||||
}
|
||||
|
||||
// KeyFunc mirrors ratelimit.KeyFunc so call sites don't have to import the
|
||||
// inner package.
|
||||
type KeyFunc = func(r *http.Request) string
|
||||
+264
-43
@@ -5,63 +5,167 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/auth"
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
"github.com/alchemistkay/guestguard/internal/ratelimit"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
logger *slog.Logger
|
||||
db *storage.DB
|
||||
hub *Hub
|
||||
users *userHandler
|
||||
events *eventHandler
|
||||
guests *guestHandler
|
||||
tokens *tokenHandler
|
||||
rsvps *rsvpHandler
|
||||
activity *activityHandler
|
||||
ws *wsHandler
|
||||
health *healthHandler
|
||||
logger *slog.Logger
|
||||
db *storage.DB
|
||||
hub *Hub
|
||||
authH *authHandler
|
||||
me *meHandler
|
||||
events *eventHandler
|
||||
guests *guestHandler
|
||||
tokens *tokenHandler
|
||||
rsvps *rsvpHandler
|
||||
activity *activityHandler
|
||||
ws *wsHandler
|
||||
wsTicket *wsTicketHandler
|
||||
health *healthHandler
|
||||
signer *auth.JWTSigner
|
||||
limiter *ratelimit.Limiter
|
||||
unsub *unsubscribeHandler
|
||||
webhooks *webhookHandler
|
||||
csv *csvImportHandler
|
||||
billing *billingHandler
|
||||
stripeWH *stripeWebhookHandler
|
||||
privacy *privacyHandler
|
||||
}
|
||||
|
||||
type ServerDeps struct {
|
||||
Logger *slog.Logger
|
||||
DB *storage.DB
|
||||
Hub *Hub
|
||||
AccessPublisher accessPublisher
|
||||
RSVPPublisher rsvpPublisher
|
||||
FraudScorer fraudScorer
|
||||
TokenTTL time.Duration
|
||||
AccessPublisher accessPublisher
|
||||
RSVPPublisher rsvpPublisher
|
||||
InvitationPublisher invitationPublisher
|
||||
FraudScorer fraudScorer
|
||||
TokenTTL time.Duration
|
||||
|
||||
// Auth
|
||||
JWTSecret string
|
||||
JWTIssuer string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
EmailVerificationTTL time.Duration
|
||||
PasswordResetTTL time.Duration
|
||||
PublicBaseURL string
|
||||
RefreshCookieDomain string
|
||||
RefreshCookieSecure bool
|
||||
EmailSender auth.EmailSender
|
||||
WSTicketTTL time.Duration
|
||||
|
||||
// Rate limiting / abuse controls
|
||||
Redis *redis.Client
|
||||
LoginLockoutMax int // failed attempts before account lockout (default 5)
|
||||
LoginFailWindow time.Duration // counter TTL (default 15 min)
|
||||
|
||||
// Notifications / unsubscribe
|
||||
NotificationRepo *notification.Repo
|
||||
SuppressionRepo *notification.SuppressionRepo
|
||||
UnsubscribeSigner *notification.UnsubscribeSigner
|
||||
|
||||
// Billing (Block F). Nil StripeClient leaves billing disabled — the
|
||||
// system still boots and runs, all users sit on the free tier with
|
||||
// its limits enforced; /billing/* returns 503.
|
||||
StripeClient *billing.Client
|
||||
}
|
||||
|
||||
func NewServer(deps ServerDeps) *Server {
|
||||
func NewServer(deps ServerDeps) (*Server, error) {
|
||||
eventRepo := storage.NewEventRepo(deps.DB)
|
||||
guestRepo := storage.NewGuestRepo(deps.DB)
|
||||
tokenRepo := storage.NewTokenRepo(deps.DB)
|
||||
rsvpRepo := storage.NewRSVPRepo(deps.DB)
|
||||
accessRepo := storage.NewAccessLogRepo(deps.DB)
|
||||
userRepo := storage.NewUserRepo(deps.DB)
|
||||
verifRepo := storage.NewEmailVerificationRepo(deps.DB)
|
||||
resetRepo := storage.NewPasswordResetRepo(deps.DB)
|
||||
refreshRepo := storage.NewRefreshTokenRepo(deps.DB)
|
||||
subRepo := storage.NewSubscriptionRepo(deps.DB)
|
||||
enforcer := newTierEnforcer(subRepo, deps.PublicBaseURL)
|
||||
|
||||
signer, err := auth.NewJWTSigner(deps.JWTSecret, deps.AccessTokenTTL, deps.JWTIssuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hasher := auth.NewPasswordHasher()
|
||||
|
||||
emails := deps.EmailSender
|
||||
if emails == nil {
|
||||
emails = auth.LogEmailSender{Logger: deps.Logger}
|
||||
}
|
||||
|
||||
hub := deps.Hub
|
||||
if hub == nil {
|
||||
hub = NewHub(deps.Logger)
|
||||
}
|
||||
|
||||
wsTicketTTL := deps.WSTicketTTL
|
||||
if wsTicketTTL <= 0 {
|
||||
wsTicketTTL = 60 * time.Second
|
||||
}
|
||||
wsTickets := newWSTicketStore(wsTicketTTL)
|
||||
|
||||
var limiter *ratelimit.Limiter
|
||||
var lockout *auth.LockoutTracker
|
||||
if deps.Redis != nil {
|
||||
limiter = ratelimit.New(deps.Redis, "gg:rl")
|
||||
lockoutMax := deps.LoginLockoutMax
|
||||
if lockoutMax <= 0 {
|
||||
lockoutMax = 5
|
||||
}
|
||||
failWindow := deps.LoginFailWindow
|
||||
if failWindow <= 0 {
|
||||
failWindow = 15 * time.Minute
|
||||
}
|
||||
lockout = auth.NewLockoutTracker(deps.Redis, lockoutMax, failWindow)
|
||||
}
|
||||
|
||||
authH := newAuthHandler(authHandlerDeps{
|
||||
Logger: deps.Logger,
|
||||
Users: userRepo,
|
||||
Verifications: verifRepo,
|
||||
Resets: resetRepo,
|
||||
Refreshes: refreshRepo,
|
||||
Hasher: hasher,
|
||||
Signer: signer,
|
||||
Emails: emails,
|
||||
Lockout: lockout,
|
||||
Limiter: limiter,
|
||||
PublicBaseURL: deps.PublicBaseURL,
|
||||
EmailVerificationTTL: deps.EmailVerificationTTL,
|
||||
PasswordResetTTL: deps.PasswordResetTTL,
|
||||
RefreshTTL: deps.RefreshTokenTTL,
|
||||
CookieDomain: deps.RefreshCookieDomain,
|
||||
CookieSecure: deps.RefreshCookieSecure,
|
||||
})
|
||||
|
||||
return &Server{
|
||||
logger: deps.Logger,
|
||||
db: deps.DB,
|
||||
hub: hub,
|
||||
users: &userHandler{repo: userRepo},
|
||||
events: &eventHandler{repo: eventRepo},
|
||||
guests: &guestHandler{guests: guestRepo, events: eventRepo},
|
||||
authH: authH,
|
||||
me: &meHandler{users: userRepo},
|
||||
events: &eventHandler{repo: eventRepo, enforcer: enforcer},
|
||||
guests: &guestHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
|
||||
tokens: &tokenHandler{
|
||||
logger: deps.Logger,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
events: eventRepo,
|
||||
accessLogs: accessRepo,
|
||||
gen: auth.NewGenerator(),
|
||||
ttl: deps.TokenTTL,
|
||||
pub: deps.AccessPublisher,
|
||||
logger: deps.Logger,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
events: eventRepo,
|
||||
users: userRepo,
|
||||
accessLogs: accessRepo,
|
||||
gen: auth.NewGenerator(),
|
||||
ttl: deps.TokenTTL,
|
||||
pub: deps.AccessPublisher,
|
||||
invitations: deps.InvitationPublisher,
|
||||
publicBaseURL: deps.PublicBaseURL,
|
||||
},
|
||||
rsvps: &rsvpHandler{
|
||||
logger: deps.Logger,
|
||||
@@ -78,9 +182,46 @@ func NewServer(deps ServerDeps) *Server {
|
||||
rsvps: rsvpRepo,
|
||||
accessLogs: accessRepo,
|
||||
},
|
||||
ws: &wsHandler{logger: deps.Logger, hub: hub},
|
||||
health: &healthHandler{pool: deps.DB.Pool},
|
||||
}
|
||||
ws: &wsHandler{logger: deps.Logger, hub: hub, tickets: wsTickets},
|
||||
wsTicket: &wsTicketHandler{tickets: wsTickets, events: eventRepo},
|
||||
health: &healthHandler{pool: deps.DB.Pool},
|
||||
signer: signer,
|
||||
limiter: limiter,
|
||||
unsub: &unsubscribeHandler{
|
||||
logger: deps.Logger,
|
||||
signer: deps.UnsubscribeSigner,
|
||||
suppress: deps.SuppressionRepo,
|
||||
},
|
||||
webhooks: &webhookHandler{
|
||||
logger: deps.Logger,
|
||||
notifs: deps.NotificationRepo,
|
||||
suppress: deps.SuppressionRepo,
|
||||
},
|
||||
csv: &csvImportHandler{guests: guestRepo, events: eventRepo, enforcer: enforcer},
|
||||
billing: &billingHandler{
|
||||
logger: deps.Logger,
|
||||
stripe: deps.StripeClient,
|
||||
users: userRepo,
|
||||
subscriptions: subRepo,
|
||||
publicBaseURL: deps.PublicBaseURL,
|
||||
},
|
||||
stripeWH: &stripeWebhookHandler{
|
||||
logger: deps.Logger,
|
||||
stripe: deps.StripeClient,
|
||||
subs: subRepo,
|
||||
},
|
||||
privacy: &privacyHandler{
|
||||
logger: deps.Logger,
|
||||
users: userRepo,
|
||||
events: eventRepo,
|
||||
guests: guestRepo,
|
||||
tokens: tokenRepo,
|
||||
rsvps: rsvpRepo,
|
||||
access: accessRepo,
|
||||
notifs: deps.DB,
|
||||
refresh: refreshRepo,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Hub() *Hub { return s.hub }
|
||||
@@ -91,25 +232,104 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.HandleFunc("GET /health", s.health.live)
|
||||
mux.HandleFunc("GET /health/ready", s.health.ready)
|
||||
|
||||
mux.HandleFunc("POST /users", s.users.upsert)
|
||||
// Per-route rate limiters (no-op when Redis isn't wired).
|
||||
authed := requireAuth(s.signer)
|
||||
rl := func(name string, limit int, window time.Duration, keyFn KeyFunc, h http.Handler) http.Handler {
|
||||
if s.limiter == nil {
|
||||
return h
|
||||
}
|
||||
return s.limiter.Middleware(
|
||||
ratelimit.Rule{Name: name, Limit: limit, Window: window},
|
||||
keyFn,
|
||||
s.logger,
|
||||
)(h)
|
||||
}
|
||||
|
||||
mux.HandleFunc("POST /events", s.events.create)
|
||||
mux.HandleFunc("GET /events", s.events.list)
|
||||
mux.HandleFunc("GET /events/{id}", s.events.get)
|
||||
mux.HandleFunc("PATCH /events/{id}", s.events.update)
|
||||
mux.HandleFunc("DELETE /events/{id}", s.events.delete)
|
||||
// Anonymous auth endpoints — POST /auth/login + /auth/forgot-password
|
||||
// rate-limit inside the handler (key includes the email body field).
|
||||
mux.Handle("POST /auth/signup",
|
||||
rl("auth_signup", 5, time.Hour, ipKey, http.HandlerFunc(s.authH.signup)))
|
||||
mux.HandleFunc("POST /auth/login", s.authH.login)
|
||||
mux.HandleFunc("POST /auth/refresh", s.authH.refresh)
|
||||
mux.HandleFunc("POST /auth/logout", s.authH.logout)
|
||||
mux.HandleFunc("POST /auth/verify-email", s.authH.verifyEmail)
|
||||
mux.HandleFunc("POST /auth/forgot-password", s.authH.forgotPassword)
|
||||
mux.HandleFunc("POST /auth/reset-password", s.authH.resetPassword)
|
||||
|
||||
mux.HandleFunc("POST /events/{id}/guests", s.guests.create)
|
||||
mux.HandleFunc("GET /events/{id}/guests", s.guests.list)
|
||||
mux.Handle("GET /me", authed(http.HandlerFunc(s.me.get)))
|
||||
mux.Handle("POST /auth/ws-ticket", authed(http.HandlerFunc(s.wsTicket.issue)))
|
||||
|
||||
mux.HandleFunc("GET /events/{id}/activity", s.activity.list)
|
||||
// Privacy / GDPR-style endpoints — host can export their data,
|
||||
// delete their account, and record terms acceptance from the
|
||||
// onboarding gate.
|
||||
mux.Handle("GET /me/data-export", authed(http.HandlerFunc(s.privacy.dataExport)))
|
||||
mux.Handle("DELETE /me", authed(http.HandlerFunc(s.privacy.deleteMe)))
|
||||
mux.Handle("POST /me/accept-terms", authed(http.HandlerFunc(s.privacy.acceptTerms)))
|
||||
|
||||
mux.HandleFunc("POST /events/{id}/guests/{guest_id}/tokens", s.tokens.issue)
|
||||
mux.HandleFunc("GET /access/{token}", s.tokens.access)
|
||||
mux.HandleFunc("POST /rsvp/{token}", s.rsvps.submit)
|
||||
// Host-facing event/guest/token writes are limited by user_id.
|
||||
mux.Handle("POST /events",
|
||||
authed(rl("events_create", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.events.create))))
|
||||
mux.Handle("GET /events", authed(http.HandlerFunc(s.events.list)))
|
||||
mux.Handle("GET /events/{id}", authed(http.HandlerFunc(s.events.get)))
|
||||
mux.Handle("PATCH /events/{id}", authed(http.HandlerFunc(s.events.update)))
|
||||
mux.Handle("DELETE /events/{id}", authed(http.HandlerFunc(s.events.delete)))
|
||||
|
||||
mux.Handle("POST /events/{id}/guests",
|
||||
authed(rl("guests_create", 1000, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.create))))
|
||||
mux.Handle("GET /events/{id}/guests", authed(http.HandlerFunc(s.guests.list)))
|
||||
mux.Handle("PATCH /events/{id}/guests/{guest_id}",
|
||||
authed(rl("guests_update", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.update))))
|
||||
mux.Handle("DELETE /events/{id}/guests/{guest_id}",
|
||||
authed(rl("guests_delete", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.guests.delete))))
|
||||
|
||||
// CSV import (Block E). Preview is cheap (no DB writes), so we keep
|
||||
// its budget separate from commit's daily-row-add limit.
|
||||
mux.Handle("POST /events/{id}/guests/import/preview",
|
||||
authed(rl("guests_import_preview", 30, time.Hour, userIDKey, http.HandlerFunc(s.csv.preview))))
|
||||
mux.Handle("POST /events/{id}/guests/import",
|
||||
authed(rl("guests_import_commit", 20, 24*time.Hour, userIDKey, http.HandlerFunc(s.csv.commit))))
|
||||
mux.Handle("GET /events/{id}/guests/import/template", authed(http.HandlerFunc(s.csv.template)))
|
||||
|
||||
mux.Handle("GET /events/{id}/activity", authed(http.HandlerFunc(s.activity.list)))
|
||||
|
||||
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens",
|
||||
authed(rl("tokens_issue", 500, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.issue))))
|
||||
mux.Handle("POST /events/{id}/guests/{guest_id}/tokens/rotate",
|
||||
authed(rl("tokens_rotate", 200, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.rotate))))
|
||||
mux.Handle("POST /events/{id}/guests/invitations/bulk",
|
||||
authed(rl("tokens_bulk", 10, 24*time.Hour, userIDKey, http.HandlerFunc(s.tokens.bulkIssue))))
|
||||
|
||||
// Guest-facing endpoints — rate-limited by the access token in the URL
|
||||
// path so an attacker hammering a single invitation is slowed regardless
|
||||
// of their source IP.
|
||||
mux.Handle("GET /access/{token}",
|
||||
rl("access", 60, time.Hour, pathKey("token"), http.HandlerFunc(s.tokens.access)))
|
||||
mux.Handle("POST /rsvp/{token}",
|
||||
rl("rsvp", 10, time.Hour, pathKey("token"), http.HandlerFunc(s.rsvps.submit)))
|
||||
|
||||
// WebSocket endpoint authenticates via single-use ticket on the query
|
||||
// string (see POST /auth/ws-ticket).
|
||||
mux.HandleFunc("GET /ws/events/{id}", s.ws.handle)
|
||||
|
||||
// Unsubscribe (signed token, no auth required — links live in emails).
|
||||
mux.HandleFunc("GET /unsubscribe/{token}", s.unsub.preview)
|
||||
mux.HandleFunc("POST /unsubscribe/{token}", s.unsub.confirm)
|
||||
|
||||
// Provider webhooks. Signature verification is enforced in the handler
|
||||
// once GG_TWILIO_AUTH_TOKEN / GG_SES_WEBHOOK_SECRET are set.
|
||||
mux.HandleFunc("POST /webhooks/twilio/status", s.webhooks.twilio)
|
||||
mux.HandleFunc("POST /webhooks/ses/notifications", s.webhooks.ses)
|
||||
|
||||
// Billing (Block F). /billing/status is safe for everyone — returns
|
||||
// free tier defaults when Stripe is unconfigured or the user has no
|
||||
// subscription, so the frontend's plan page always has something to
|
||||
// render. The action endpoints (checkout, portal) return 503 in dev
|
||||
// without Stripe credentials.
|
||||
mux.Handle("GET /billing/status", authed(http.HandlerFunc(s.billing.status)))
|
||||
mux.Handle("POST /billing/checkout-session", authed(http.HandlerFunc(s.billing.checkoutSession)))
|
||||
mux.Handle("POST /billing/portal", authed(http.HandlerFunc(s.billing.portalSession)))
|
||||
mux.HandleFunc("POST /webhooks/stripe", s.stripeWH.handle)
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusNotFound, "not found")
|
||||
})
|
||||
@@ -128,9 +348,10 @@ func corsMiddleware(next http.Handler) http.Handler {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Device-Fingerprint")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Device-Fingerprint")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/stripe/stripe-go/v82"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/billing"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
// stripeWebhookHandler accepts and verifies Stripe events, then
|
||||
// projects subscription lifecycle changes onto the subscriptions table.
|
||||
// We track only what middleware needs to decide access — tier + status +
|
||||
// period bounds. Invoice events (payment failed / succeeded) are logged
|
||||
// for observability; dunning automation lands in Block F3.
|
||||
type stripeWebhookHandler struct {
|
||||
logger *slog.Logger
|
||||
stripe *billing.Client
|
||||
subs *storage.SubscriptionRepo
|
||||
}
|
||||
|
||||
// POST /webhooks/stripe — signature-verified Stripe event sink.
|
||||
func (h *stripeWebhookHandler) handle(w http.ResponseWriter, r *http.Request) {
|
||||
if h.stripe == nil || !h.stripe.Enabled() {
|
||||
// Not configured on this instance — reject so a misrouted event
|
||||
// isn't silently swallowed. Stripe will retry which is harmless.
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "read body")
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
event, err := h.stripe.VerifyWebhook(body, r.Header.Get("Stripe-Signature"))
|
||||
if err != nil {
|
||||
h.logger.Warn("stripe webhook signature failed", "err", err)
|
||||
writeError(w, http.StatusBadRequest, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "customer.subscription.created", "customer.subscription.updated":
|
||||
h.applySubscription(r, event)
|
||||
case "customer.subscription.deleted":
|
||||
h.applySubscriptionDeleted(r, event)
|
||||
case "invoice.payment_succeeded":
|
||||
// Clear past_due if Stripe says payment caught up. Most flows are
|
||||
// already covered by the subscription.updated event Stripe also
|
||||
// fires — this is belt-and-braces.
|
||||
h.applySubscription(r, event)
|
||||
case "invoice.payment_failed":
|
||||
h.logger.Warn("stripe invoice payment failed", "event_id", event.ID)
|
||||
// Subscription.status will flip to past_due via the
|
||||
// subscription.updated event Stripe fires alongside.
|
||||
default:
|
||||
h.logger.Debug("stripe event ignored", "type", event.Type)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// applySubscription patches the subscriptions row keyed by Stripe
|
||||
// customer id. Best-effort — failures here are logged but don't NACK
|
||||
// the webhook (Stripe would retry forever and the row would never
|
||||
// converge).
|
||||
func (h *stripeWebhookHandler) applySubscription(r *http.Request, event stripe.Event) {
|
||||
var sub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
// Some invoice events carry an Invoice payload — try to extract
|
||||
// the subscription id from there and short-circuit on status.
|
||||
h.logger.Debug("stripe webhook: not a subscription payload", "type", event.Type, "err", err)
|
||||
return
|
||||
}
|
||||
if sub.Customer == nil || sub.Customer.ID == "" {
|
||||
h.logger.Warn("stripe webhook: subscription has no customer", "subscription", sub.ID)
|
||||
return
|
||||
}
|
||||
|
||||
tier := tierFromSubscription(&sub)
|
||||
status := string(sub.Status)
|
||||
cancelAtPeriodEnd := sub.CancelAtPeriodEnd
|
||||
|
||||
// As of API 2024-10-28, current_period_end lives on the subscription
|
||||
// item, not the subscription. We pick the earliest item's end — for
|
||||
// single-item subscriptions (our case) that's the canonical one.
|
||||
var periodEnd *time.Time
|
||||
for _, item := range sub.Items.Data {
|
||||
if item.CurrentPeriodEnd > 0 {
|
||||
t := time.Unix(item.CurrentPeriodEnd, 0).UTC()
|
||||
if periodEnd == nil || t.Before(*periodEnd) {
|
||||
periodEnd = &t
|
||||
}
|
||||
}
|
||||
}
|
||||
subID := sub.ID
|
||||
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
|
||||
StripeSubscriptionID: &subID,
|
||||
Tier: stringPtr(string(tier)),
|
||||
Status: &status,
|
||||
CurrentPeriodEnd: periodEnd,
|
||||
CancelAtPeriodEnd: &cancelAtPeriodEnd,
|
||||
}); err != nil {
|
||||
h.logger.Error("stripe webhook: update subscription failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *stripeWebhookHandler) applySubscriptionDeleted(r *http.Request, event stripe.Event) {
|
||||
var sub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
||||
h.logger.Warn("stripe webhook: bad deleted payload", "err", err)
|
||||
return
|
||||
}
|
||||
if sub.Customer == nil {
|
||||
return
|
||||
}
|
||||
status := "canceled"
|
||||
if err := h.subs.UpdateByCustomer(r.Context(), sub.Customer.ID, storage.UpsertParams{
|
||||
Status: &status,
|
||||
}); err != nil {
|
||||
h.logger.Error("stripe webhook: mark canceled failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tierFromSubscription inspects the Stripe price metadata to figure out
|
||||
// which GuestGuard tier this subscription corresponds to. We read a
|
||||
// price-level metadata key `gg_tier` (set in the Stripe dashboard when
|
||||
// you create the Price). Fallback: free.
|
||||
func tierFromSubscription(sub *stripe.Subscription) billing.Tier {
|
||||
if sub == nil || len(sub.Items.Data) == 0 {
|
||||
return billing.TierFree
|
||||
}
|
||||
for _, item := range sub.Items.Data {
|
||||
if item.Price == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := item.Price.Metadata["gg_tier"]; ok {
|
||||
t := billing.Tier(v)
|
||||
if t.Valid() {
|
||||
return t
|
||||
}
|
||||
}
|
||||
// Heuristic fallback for tests / unconfigured prices: look at the
|
||||
// recurring interval and amount tier.
|
||||
if item.Price.Recurring != nil && item.Price.UnitAmount >= 19900 {
|
||||
return billing.TierBusiness
|
||||
}
|
||||
if item.Price.Recurring != nil && item.Price.UnitAmount >= 4900 {
|
||||
return billing.TierPro
|
||||
}
|
||||
}
|
||||
return billing.TierFree
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string { return &s }
|
||||
+329
-14
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -20,25 +21,38 @@ type accessPublisher interface {
|
||||
PublishAccessAttempted(ctx context.Context, evt natspub.AccessAttempted) error
|
||||
}
|
||||
|
||||
type invitationPublisher interface {
|
||||
PublishInvitationSend(ctx context.Context, evt natspub.InvitationSend) error
|
||||
}
|
||||
|
||||
type tokenHandler struct {
|
||||
logger *slog.Logger
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
events *storage.EventRepo
|
||||
accessLogs *storage.AccessLogRepo
|
||||
gen *auth.Generator
|
||||
ttl time.Duration
|
||||
pub accessPublisher
|
||||
logger *slog.Logger
|
||||
guests *storage.GuestRepo
|
||||
tokens *storage.TokenRepo
|
||||
events *storage.EventRepo
|
||||
users *storage.UserRepo
|
||||
accessLogs *storage.AccessLogRepo
|
||||
gen *auth.Generator
|
||||
ttl time.Duration
|
||||
pub accessPublisher
|
||||
invitations invitationPublisher
|
||||
publicBaseURL string
|
||||
}
|
||||
|
||||
type issueTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
Meta *domain.Token `json:"meta"`
|
||||
Token string `json:"token"`
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
Meta *domain.Token `json:"meta"`
|
||||
InvitationQueued bool `json:"invitation_queued"`
|
||||
InvitationLink string `json:"invitation_link"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/{guest_id}/tokens — issue a token for the guest.
|
||||
func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
@@ -47,6 +61,10 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
guest, err := h.guests.Get(r.Context(), guestID)
|
||||
if err != nil {
|
||||
@@ -78,10 +96,307 @@ func (h *tokenHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
link := h.invitationLink(raw)
|
||||
invitationQueued := h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
|
||||
|
||||
writeJSON(w, http.StatusCreated, issueTokenResponse{
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
InvitationQueued: invitationQueued,
|
||||
InvitationLink: link,
|
||||
})
|
||||
}
|
||||
|
||||
// queueInvitation publishes an invitation.send event so the notifier can
|
||||
// dispatch a branded email. Best-effort: if any step fails we log and
|
||||
// return false rather than failing the whole token-issue request — the
|
||||
// host still has the raw URL in the response and can re-trigger sending.
|
||||
func (h *tokenHandler) queueInvitation(
|
||||
ctx context.Context,
|
||||
event *domain.Event,
|
||||
guest *domain.Guest,
|
||||
tk *domain.Token,
|
||||
hostID uuid.UUID,
|
||||
rawToken string,
|
||||
) bool {
|
||||
if h.invitations == nil {
|
||||
return false
|
||||
}
|
||||
if guest.Email == nil || *guest.Email == "" {
|
||||
// Phone-only / nameless guests get no email — host shares the link
|
||||
// manually. Show that on the UI so it's not a silent surprise.
|
||||
return false
|
||||
}
|
||||
hostName := ""
|
||||
if h.users != nil {
|
||||
if host, err := h.users.GetByID(ctx, hostID); err == nil && host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
}
|
||||
evt := natspub.InvitationSend{
|
||||
EventID: event.ID,
|
||||
GuestID: guest.ID,
|
||||
TokenID: tk.ID,
|
||||
GuestName: guest.Name,
|
||||
GuestEmail: *guest.Email,
|
||||
HostName: hostName,
|
||||
EventName: event.Name,
|
||||
Venue: event.Venue,
|
||||
EventDate: event.EventDate,
|
||||
Link: h.invitationLink(rawToken),
|
||||
IssuedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := h.invitations.PublishInvitationSend(ctx, evt); err != nil {
|
||||
h.logger.Warn("publish invitation.send (continuing)", "err", err, "guest_id", guest.ID)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// invitationLink renders the public RSVP URL the guest clicks from their
|
||||
// inbox. publicBaseURL is the externally-reachable host (set via
|
||||
// GG_PUBLIC_BASE_URL); access via /rsvp/<token> is intentional — the
|
||||
// frontend page rsvp/[token].vue catches the raw token.
|
||||
func (h *tokenHandler) invitationLink(raw string) string {
|
||||
base := h.publicBaseURL
|
||||
if base == "" {
|
||||
base = "http://localhost:3000"
|
||||
}
|
||||
return base + "/rsvp/" + raw
|
||||
}
|
||||
|
||||
// bulkIssueRequest is the optional JSON body for the bulk-invite call.
|
||||
// An empty body (or missing GuestIDs) means "every guest on the event
|
||||
// who doesn't already have a token".
|
||||
type bulkIssueRequest struct {
|
||||
GuestIDs []string `json:"guest_ids"`
|
||||
}
|
||||
|
||||
type bulkIssueItemError struct {
|
||||
GuestID string `json:"guest_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// bulkIssueToken is one minted invitation. The raw token is returned so
|
||||
// the host's UI can offer a "copy link" affordance after a bulk send
|
||||
// (especially for guests with no email on file) without making another
|
||||
// round-trip. Same data the per-guest issue endpoint already exposes,
|
||||
// scoped to the host who owns the event.
|
||||
type bulkIssueToken struct {
|
||||
GuestID string `json:"guest_id"`
|
||||
Token string `json:"token"`
|
||||
InvitationQueued bool `json:"invitation_queued"`
|
||||
InvitationLink string `json:"invitation_link"`
|
||||
}
|
||||
|
||||
type bulkIssueResponse struct {
|
||||
Issued int `json:"issued"`
|
||||
Queued int `json:"queued"`
|
||||
SkippedExisting int `json:"skipped_existing"`
|
||||
SkippedNoEmail int `json:"skipped_no_email"`
|
||||
Tokens []bulkIssueToken `json:"tokens,omitempty"`
|
||||
Errors []bulkIssueItemError `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/invitations/bulk — generate tokens for every
|
||||
// eligible guest (or the explicit subset) on the event and queue an
|
||||
// invitation email for those with an address. Best-effort: any per-guest
|
||||
// error is reported in the response and doesn't abort the rest.
|
||||
func (h *tokenHandler) bulkIssue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req bulkIssueRequest
|
||||
if r.ContentLength > 0 {
|
||||
// Body is optional; only decode when something was actually sent.
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var onlyIDs []uuid.UUID
|
||||
if len(req.GuestIDs) > 0 {
|
||||
onlyIDs = make([]uuid.UUID, 0, len(req.GuestIDs))
|
||||
for _, raw := range req.GuestIDs {
|
||||
id, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid guest id: "+raw)
|
||||
return
|
||||
}
|
||||
onlyIDs = append(onlyIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
guests, err := h.guests.ListGuestsForInvitation(r.Context(), eventID, onlyIDs)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guests")
|
||||
return
|
||||
}
|
||||
|
||||
hostName := ""
|
||||
if h.users != nil {
|
||||
if host, err := h.users.GetByID(r.Context(), hostID); err == nil && host != nil {
|
||||
hostName = host.Name
|
||||
}
|
||||
}
|
||||
|
||||
resp := bulkIssueResponse{}
|
||||
for _, g := range guests {
|
||||
if g.HasToken {
|
||||
resp.SkippedExisting++
|
||||
continue
|
||||
}
|
||||
raw, hash, err := h.gen.Generate()
|
||||
if err != nil {
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "mint token failed"})
|
||||
continue
|
||||
}
|
||||
tk, err := h.tokens.Create(r.Context(), storage.CreateTokenParams{
|
||||
GuestID: g.ID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(h.ttl),
|
||||
})
|
||||
if err != nil {
|
||||
// Likely a race against the unique constraint (someone else
|
||||
// issued in parallel) — surface but don't fail the batch.
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: err.Error()})
|
||||
continue
|
||||
}
|
||||
resp.Issued++
|
||||
link := h.invitationLink(raw)
|
||||
tokenInfo := bulkIssueToken{
|
||||
GuestID: g.ID.String(),
|
||||
Token: raw,
|
||||
InvitationLink: link,
|
||||
}
|
||||
|
||||
if g.Email == "" {
|
||||
resp.SkippedNoEmail++
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
evt := natspub.InvitationSend{
|
||||
EventID: event.ID,
|
||||
GuestID: g.ID,
|
||||
TokenID: tk.ID,
|
||||
GuestName: g.Name,
|
||||
GuestEmail: g.Email,
|
||||
HostName: hostName,
|
||||
EventName: event.Name,
|
||||
Venue: event.Venue,
|
||||
EventDate: event.EventDate,
|
||||
Link: h.invitationLink(raw),
|
||||
IssuedAt: time.Now().UTC(),
|
||||
}
|
||||
if h.invitations == nil {
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
if err := h.invitations.PublishInvitationSend(r.Context(), evt); err != nil {
|
||||
h.logger.Warn("publish invitation.send (bulk, continuing)", "err", err, "guest_id", g.ID)
|
||||
resp.Errors = append(resp.Errors, bulkIssueItemError{GuestID: g.ID.String(), Reason: "publish failed"})
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
continue
|
||||
}
|
||||
resp.Queued++
|
||||
tokenInfo.InvitationQueued = true
|
||||
resp.Tokens = append(resp.Tokens, tokenInfo)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type rotateTokenRequest struct {
|
||||
// SendEmail asks the notifier to re-deliver the invitation. False
|
||||
// means "just give me a fresh link" — typical for phone-only guests
|
||||
// where the host shares the new URL via SMS.
|
||||
SendEmail bool `json:"send_email"`
|
||||
}
|
||||
|
||||
// POST /events/{id}/guests/{guest_id}/tokens/rotate — invalidate the
|
||||
// guest's existing invitation link and mint a fresh one. Optionally
|
||||
// re-publishes invitation.send so the notifier re-delivers via email.
|
||||
// The old URL stops working immediately.
|
||||
func (h *tokenHandler) rotate(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
guestID, ok := parseIDParam(w, r, "guest_id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
event, ok := requireEventOwner(w, r, h.events, eventID, hostID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
guest, err := h.guests.Get(r.Context(), guestID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrGuestNotFound) {
|
||||
writeError(w, http.StatusNotFound, "guest not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to load guest")
|
||||
return
|
||||
}
|
||||
if guest.EventID != eventID {
|
||||
writeError(w, http.StatusNotFound, "guest not found in event")
|
||||
return
|
||||
}
|
||||
|
||||
var req rotateTokenRequest
|
||||
if r.ContentLength > 0 {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
raw, hash, err := h.gen.Generate()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to generate token")
|
||||
return
|
||||
}
|
||||
tk, err := h.tokens.RotateForGuest(r.Context(), storage.CreateTokenParams{
|
||||
GuestID: guestID,
|
||||
TokenHash: hash,
|
||||
ExpiresAt: time.Now().UTC().Add(h.ttl),
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("rotate token", "err", err, "guest_id", guestID)
|
||||
writeError(w, http.StatusInternalServerError, "failed to rotate token")
|
||||
return
|
||||
}
|
||||
|
||||
link := h.invitationLink(raw)
|
||||
invitationQueued := false
|
||||
if req.SendEmail {
|
||||
invitationQueued = h.queueInvitation(r.Context(), event, guest, tk, hostID, raw)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, issueTokenResponse{
|
||||
Token: raw,
|
||||
TokenID: tk.ID,
|
||||
Meta: tk,
|
||||
InvitationQueued: invitationQueued,
|
||||
InvitationLink: link,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
)
|
||||
|
||||
type unsubscribeHandler struct {
|
||||
logger *slog.Logger
|
||||
signer *notification.UnsubscribeSigner
|
||||
suppress *notification.SuppressionRepo
|
||||
}
|
||||
|
||||
// GET /unsubscribe/{token} — surface the email address that the token
|
||||
// belongs to so the frontend can show a confirmation page. Honoured even
|
||||
// before the user clicks "Confirm" so they see what's being unsubscribed.
|
||||
func (h *unsubscribeHandler) preview(w http.ResponseWriter, r *http.Request) {
|
||||
if h.signer == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
|
||||
return
|
||||
}
|
||||
email, err := h.signer.Verify(r.PathValue("token"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"email": email})
|
||||
}
|
||||
|
||||
// POST /unsubscribe/{token} — add the email to the suppression list.
|
||||
// Idempotent: clicking the link twice keeps the existing entry.
|
||||
func (h *unsubscribeHandler) confirm(w http.ResponseWriter, r *http.Request) {
|
||||
if h.signer == nil || h.suppress == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "unsubscribe not configured")
|
||||
return
|
||||
}
|
||||
email, err := h.signer.Verify(r.PathValue("token"))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid unsubscribe link")
|
||||
return
|
||||
}
|
||||
if err := h.suppress.Add(r.Context(), email, "user clicked unsubscribe", notification.SuppressionUser); err != nil {
|
||||
h.logger.Error("add suppression", "err", err, "email", email)
|
||||
writeError(w, http.StatusInternalServerError, "failed to unsubscribe")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "unsubscribed", "email": email})
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/domain"
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type userHandler struct {
|
||||
repo *storage.UserRepo
|
||||
}
|
||||
|
||||
type upsertUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// POST /users — idempotent: returns the existing user if the email already
|
||||
// exists, creates one otherwise. This keeps the demo flow simple without
|
||||
// requiring real auth.
|
||||
func (h *userHandler) upsert(w http.ResponseWriter, r *http.Request) {
|
||||
var req upsertUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "email is invalid")
|
||||
return
|
||||
}
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
|
||||
u, err := h.repo.Create(r.Context(), req.Email, req.Name)
|
||||
if err == nil {
|
||||
writeJSON(w, http.StatusCreated, u)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, domain.ErrEmailTaken) {
|
||||
existing, getErr := h.repo.GetByEmail(r.Context(), req.Email)
|
||||
if getErr != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to load user")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, existing)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to create user")
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/notification"
|
||||
)
|
||||
|
||||
// webhookHandler accepts provider status notifications and reflects them
|
||||
// onto the notifications table + suppression list.
|
||||
//
|
||||
// Signature verification is intentionally a TODO until the user provisions
|
||||
// real Twilio + SES creds — verifying against test fixtures alone would
|
||||
// give a false sense of security. The endpoint is therefore *not* exposed
|
||||
// publicly until the deployment is ready.
|
||||
type webhookHandler struct {
|
||||
logger *slog.Logger
|
||||
notifs *notification.Repo
|
||||
suppress *notification.SuppressionRepo
|
||||
}
|
||||
|
||||
// POST /webhooks/twilio/status — Twilio status callback (form-encoded).
|
||||
// Fields we care about: MessageSid, MessageStatus (sent|delivered|
|
||||
// undelivered|failed), ErrorCode, To.
|
||||
func (h *webhookHandler) twilio(w http.ResponseWriter, r *http.Request) {
|
||||
if h.notifs == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// TODO(blockD2): verify X-Twilio-Signature with GG_TWILIO_AUTH_TOKEN.
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid form")
|
||||
return
|
||||
}
|
||||
sid := r.PostForm.Get("MessageSid")
|
||||
status := r.PostForm.Get("MessageStatus")
|
||||
if sid == "" || status == "" {
|
||||
writeError(w, http.StatusBadRequest, "missing MessageSid / MessageStatus")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
switch status {
|
||||
case "delivered":
|
||||
_ = h.notifs.MarkDelivered(ctx, sid)
|
||||
case "undelivered", "failed":
|
||||
_ = h.notifs.MarkBounce(ctx, sid, "permanent")
|
||||
}
|
||||
h.logger.Info("twilio status callback", "sid", sid, "status", status)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// POST /webhooks/ses/notifications — SNS-delivered SES notification (JSON).
|
||||
// Handles the two shapes SES uses: bounce + complaint events. Each event
|
||||
// carries the messageId we stored in provider_message_id and an array of
|
||||
// affected recipients.
|
||||
func (h *webhookHandler) ses(w http.ResponseWriter, r *http.Request) {
|
||||
if h.notifs == nil {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
// TODO(blockD2): verify SNS signature using the cert URL field.
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "read body")
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var envelope struct {
|
||||
Type string `json:"Type"` // "Notification" | "SubscriptionConfirmation"
|
||||
Message string `json:"Message"` // stringified JSON for Notification
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json envelope")
|
||||
return
|
||||
}
|
||||
if envelope.Type == "SubscriptionConfirmation" {
|
||||
// Confirmed by visiting SubscribeURL — manual op-side step.
|
||||
h.logger.Info("ses subscription confirmation received (manual confirm required)")
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if envelope.Message == "" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
var inner struct {
|
||||
NotificationType string `json:"notificationType"` // "Bounce" | "Complaint" | "Delivery"
|
||||
Mail struct {
|
||||
MessageID string `json:"messageId"`
|
||||
} `json:"mail"`
|
||||
Bounce struct {
|
||||
BounceType string `json:"bounceType"` // "Permanent" | "Transient"
|
||||
BouncedRecipients []struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
} `json:"bouncedRecipients"`
|
||||
} `json:"bounce"`
|
||||
Complaint struct {
|
||||
ComplainedRecipients []struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
} `json:"complainedRecipients"`
|
||||
} `json:"complaint"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(envelope.Message), &inner); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid inner json")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
switch inner.NotificationType {
|
||||
case "Bounce":
|
||||
bt := "transient"
|
||||
if inner.Bounce.BounceType == "Permanent" {
|
||||
bt = "permanent"
|
||||
}
|
||||
_ = h.notifs.MarkBounce(ctx, inner.Mail.MessageID, bt)
|
||||
if h.suppress != nil && bt == "permanent" {
|
||||
for _, rcp := range inner.Bounce.BouncedRecipients {
|
||||
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses permanent bounce", notification.SuppressionBounce)
|
||||
}
|
||||
}
|
||||
case "Complaint":
|
||||
_ = h.notifs.MarkComplaint(ctx, inner.Mail.MessageID)
|
||||
if h.suppress != nil {
|
||||
for _, rcp := range inner.Complaint.ComplainedRecipients {
|
||||
_ = h.suppress.Add(ctx, rcp.EmailAddress, "ses complaint", notification.SuppressionComplaint)
|
||||
}
|
||||
}
|
||||
case "Delivery":
|
||||
_ = h.notifs.MarkDelivered(ctx, inner.Mail.MessageID)
|
||||
}
|
||||
h.logger.Info("ses notification", "type", inner.NotificationType, "message_id", inner.Mail.MessageID)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Compile-time check that ctx is unused in package — silences linter on
|
||||
// some Go versions when the file would otherwise import context only for
|
||||
// the handler signatures.
|
||||
var _ = context.Background
|
||||
@@ -0,0 +1,51 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/alchemistkay/guestguard/internal/storage"
|
||||
)
|
||||
|
||||
type wsTicketHandler struct {
|
||||
tickets *wsTicketStore
|
||||
events *storage.EventRepo
|
||||
}
|
||||
|
||||
type wsTicketResponse struct {
|
||||
Ticket string `json:"ticket"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// POST /auth/ws-ticket — requireAuth-protected; body { "event_id": "<uuid>" }.
|
||||
// Returns a single-use ticket valid for ~60 seconds. The frontend appends it
|
||||
// as `?ticket=…` on the WebSocket URL.
|
||||
func (h *wsTicketHandler) issue(w http.ResponseWriter, r *http.Request) {
|
||||
hostID, ok := hostFromContext(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
EventID string `json:"event_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
eventID, ok := parseRawUUID(w, "event_id", req.EventID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, ok := requireEventOwner(w, r, h.events, eventID, hostID); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tok, exp, err := h.tickets.Mint(hostID, eventID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to mint ticket")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, wsTicketResponse{Ticket: tok, ExpiresAt: exp})
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// wsTicketStore mints short-lived single-use tickets that authorise a
|
||||
// WebSocket handshake. The plan calls this option 3 in Block B: cookies
|
||||
// don't reach the WS handshake on cross-origin setups and a JWT in the URL
|
||||
// would leak to logs; a one-shot ticket sidesteps both.
|
||||
//
|
||||
// Block B keeps this in-process. When the API runs more than one replica
|
||||
// this needs to move to Redis (Block C territory).
|
||||
type wsTicketStore struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]wsTicketEntry
|
||||
ttl time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type wsTicketEntry struct {
|
||||
userID uuid.UUID
|
||||
eventID uuid.UUID
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newWSTicketStore(ttl time.Duration) *wsTicketStore {
|
||||
return &wsTicketStore{
|
||||
entries: make(map[string]wsTicketEntry),
|
||||
ttl: ttl,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// Mint returns a fresh URL-safe ticket bound to userID + eventID.
|
||||
func (s *wsTicketStore) Mint(userID, eventID uuid.UUID) (string, time.Time, error) {
|
||||
buf := make([]byte, 24)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
tok := base64.RawURLEncoding.EncodeToString(buf)
|
||||
exp := s.now().Add(s.ttl)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked()
|
||||
s.entries[tok] = wsTicketEntry{userID: userID, eventID: eventID, expiresAt: exp}
|
||||
return tok, exp, nil
|
||||
}
|
||||
|
||||
// Consume removes the ticket and returns the bound (userID, eventID) if it
|
||||
// was valid. A ticket is single-use — replaying it fails.
|
||||
func (s *wsTicketStore) Consume(token string) (uuid.UUID, uuid.UUID, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
entry, ok := s.entries[token]
|
||||
if !ok {
|
||||
return uuid.Nil, uuid.Nil, false
|
||||
}
|
||||
delete(s.entries, token)
|
||||
if s.now().After(entry.expiresAt) {
|
||||
return uuid.Nil, uuid.Nil, false
|
||||
}
|
||||
return entry.userID, entry.eventID, true
|
||||
}
|
||||
|
||||
// sweepLocked drops expired entries opportunistically. Cheap because we
|
||||
// usually only hold dozens of tickets at a time.
|
||||
func (s *wsTicketStore) sweepLocked() {
|
||||
now := s.now()
|
||||
for k, v := range s.entries {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(s.entries, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
+23
-3
@@ -89,16 +89,36 @@ func (h *Hub) remove(eventID uuid.UUID, s *subscriber) {
|
||||
}
|
||||
|
||||
type wsHandler struct {
|
||||
logger *slog.Logger
|
||||
hub *Hub
|
||||
logger *slog.Logger
|
||||
hub *Hub
|
||||
tickets *wsTicketStore
|
||||
}
|
||||
|
||||
// GET /ws/events/{id} — dashboard live feed for one event.
|
||||
// GET /ws/events/{id}?ticket=... — dashboard live feed for one event.
|
||||
//
|
||||
// The handshake is authorised by a single-use ticket minted via
|
||||
// POST /auth/ws-ticket (option 3 from the Block B plan). The ticket binds
|
||||
// the connecting user to a specific event_id; we reject if either is
|
||||
// missing or doesn't match the URL path.
|
||||
func (h *wsHandler) handle(w http.ResponseWriter, r *http.Request) {
|
||||
eventID, ok := parseIDParam(w, r, "id")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
rawTicket := r.URL.Query().Get("ticket")
|
||||
if rawTicket == "" {
|
||||
writeError(w, http.StatusUnauthorized, "missing ticket")
|
||||
return
|
||||
}
|
||||
_, ticketEventID, valid := h.tickets.Consume(rawTicket)
|
||||
if !valid {
|
||||
writeError(w, http.StatusUnauthorized, "invalid or expired ticket")
|
||||
return
|
||||
}
|
||||
if ticketEventID != eventID {
|
||||
writeError(w, http.StatusForbidden, "ticket does not match event")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
// In dev the frontend runs on a different origin (localhost:3000 → localhost:8080).
|
||||
|
||||
Reference in New Issue
Block a user