package storage import ( "context" "errors" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) // Subscription mirrors the subscriptions table row. Stored as a thin // projection of the Stripe state — we don't try to mirror every field, // just what middleware + handlers need to decide access. type Subscription struct { ID uuid.UUID UserID uuid.UUID StripeCustomerID string StripeSubscriptionID *string Tier string Status string CurrentPeriodEnd *time.Time CancelAtPeriodEnd bool CreatedAt time.Time UpdatedAt time.Time } // ErrSubscriptionNotFound is returned when no row matches the lookup. var ErrSubscriptionNotFound = errors.New("subscription not found") type SubscriptionRepo struct { pool *pgxpool.Pool } func NewSubscriptionRepo(db *DB) *SubscriptionRepo { return &SubscriptionRepo{pool: db.Pool} } const subscriptionColumns = ` id, user_id, stripe_customer_id, stripe_subscription_id, tier, status, current_period_end, cancel_at_period_end, created_at, updated_at ` // GetActiveByUser returns the user's currently-granting subscription // (active / trialing / past_due). Returns ErrSubscriptionNotFound when // the user has no row at all — caller treats that as free tier. func (r *SubscriptionRepo) GetActiveByUser(ctx context.Context, userID uuid.UUID) (*Subscription, error) { const q = ` SELECT ` + subscriptionColumns + ` FROM subscriptions WHERE user_id = $1 AND status IN ('active','past_due','trialing') ORDER BY updated_at DESC LIMIT 1 ` return r.scanOne(ctx, q, userID) } // GetByCustomer fetches by Stripe customer id — webhooks use this since // the event payload identifies the customer, not the user. func (r *SubscriptionRepo) GetByCustomer(ctx context.Context, customerID string) (*Subscription, error) { const q = ` SELECT ` + subscriptionColumns + ` FROM subscriptions WHERE stripe_customer_id = $1 ORDER BY updated_at DESC LIMIT 1 ` return r.scanOne(ctx, q, customerID) } // FindCustomerID returns the Stripe customer id we've already created // for this user, or "" if none exists yet. Avoids creating duplicate // Stripe customers across checkout sessions. func (r *SubscriptionRepo) FindCustomerID(ctx context.Context, userID uuid.UUID) (string, error) { const q = ` SELECT stripe_customer_id FROM subscriptions WHERE user_id = $1 ORDER BY created_at ASC LIMIT 1 ` var id string if err := r.pool.QueryRow(ctx, q, userID).Scan(&id); err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", nil } return "", err } return id, nil } // UpsertParams collects everything an upsert needs. Pointer types denote // "skip writing this column" (used when a webhook only carries partial // data — we never want to clobber tier or period info we don't have). type UpsertParams struct { UserID uuid.UUID StripeCustomerID string StripeSubscriptionID *string Tier *string Status *string CurrentPeriodEnd *time.Time CancelAtPeriodEnd *bool } // Upsert inserts a new row or updates an existing one keyed by // stripe_customer_id. Used by both the checkout-success handler and the // webhook subscription-lifecycle handler. func (r *SubscriptionRepo) Upsert(ctx context.Context, p UpsertParams) (*Subscription, error) { const q = ` INSERT INTO subscriptions ( user_id, stripe_customer_id, stripe_subscription_id, tier, status, current_period_end, cancel_at_period_end ) VALUES ( $1, $2, $3, COALESCE($4, 'free'), COALESCE($5, 'incomplete'), $6, COALESCE($7, FALSE) ) ON CONFLICT (id) DO NOTHING RETURNING ` + subscriptionColumns + ` ` row := r.pool.QueryRow(ctx, q, p.UserID, p.StripeCustomerID, p.StripeSubscriptionID, p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd, ) sub, err := scanSubscription(row) if err == nil { return sub, nil } if !errors.Is(err, pgx.ErrNoRows) { return nil, err } // Race or duplicate insert — fall back to an explicit update on the // stripe_customer_id (the FK to Stripe's source of truth). const upd = ` UPDATE subscriptions SET stripe_subscription_id = COALESCE($3, stripe_subscription_id), tier = COALESCE($4, tier), status = COALESCE($5, status), current_period_end = COALESCE($6, current_period_end), cancel_at_period_end = COALESCE($7, cancel_at_period_end), updated_at = now() WHERE user_id = $1 AND stripe_customer_id = $2 RETURNING ` + subscriptionColumns + ` ` row = r.pool.QueryRow(ctx, upd, p.UserID, p.StripeCustomerID, p.StripeSubscriptionID, p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd, ) return scanSubscription(row) } // UpdateByCustomer patches the subscription row keyed by Stripe customer // id. Used by webhooks where we have the customer reference but not // always the user id. func (r *SubscriptionRepo) UpdateByCustomer(ctx context.Context, customerID string, p UpsertParams) error { const q = ` UPDATE subscriptions SET stripe_subscription_id = COALESCE($2, stripe_subscription_id), tier = COALESCE($3, tier), status = COALESCE($4, status), current_period_end = COALESCE($5, current_period_end), cancel_at_period_end = COALESCE($6, cancel_at_period_end), updated_at = now() WHERE stripe_customer_id = $1 ` _, err := r.pool.Exec(ctx, q, customerID, p.StripeSubscriptionID, p.Tier, p.Status, p.CurrentPeriodEnd, p.CancelAtPeriodEnd, ) return err } // CountEventsInCurrentMonth returns how many events the user has created // since the 1st of the current UTC month. Used for free-tier "1 event / // month" and Pro-tier "10 events / month" enforcement. func (r *SubscriptionRepo) CountEventsInCurrentMonth(ctx context.Context, userID uuid.UUID) (int, error) { const q = ` SELECT count(*) FROM events WHERE host_id = $1 AND created_at >= date_trunc('month', now() AT TIME ZONE 'UTC') ` var n int if err := r.pool.QueryRow(ctx, q, userID).Scan(&n); err != nil { return 0, err } return n, nil } // CountGuestsByEvent returns the current guest count for an event. Used // for per-event guest cap enforcement. func (r *SubscriptionRepo) CountGuestsByEvent(ctx context.Context, eventID uuid.UUID) (int, error) { var n int if err := r.pool.QueryRow(ctx, `SELECT count(*) FROM guests WHERE event_id = $1`, eventID, ).Scan(&n); err != nil { return 0, err } return n, nil } // CountCollaboratorsByEvent counts non-owner collaborators on an event // (accepted + still-pending invites). Used for the Tier 2 collaborator // quota — owners aren't a shared seat, they pay separately, so they're // excluded. func (r *SubscriptionRepo) CountCollaboratorsByEvent(ctx context.Context, eventID uuid.UUID) (int, error) { var n int if err := r.pool.QueryRow(ctx, ` SELECT (SELECT count(*) FROM event_collaborators WHERE event_id = $1 AND role <> 'owner') + (SELECT count(*) FROM collaborator_invites WHERE event_id = $1 AND consumed_at IS NULL AND expires_at > now() AND role <> 'owner') `, eventID).Scan(&n); err != nil { return 0, err } return n, nil } func (r *SubscriptionRepo) scanOne(ctx context.Context, q string, args ...any) (*Subscription, error) { sub, err := scanSubscription(r.pool.QueryRow(ctx, q, args...)) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrSubscriptionNotFound } return nil, err } return sub, nil } func scanSubscription(s rowScanner) (*Subscription, error) { var sub Subscription if err := s.Scan( &sub.ID, &sub.UserID, &sub.StripeCustomerID, &sub.StripeSubscriptionID, &sub.Tier, &sub.Status, &sub.CurrentPeriodEnd, &sub.CancelAtPeriodEnd, &sub.CreatedAt, &sub.UpdatedAt, ); err != nil { return nil, err } return &sub, nil }