package storage

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/alchemistkay/guestguard/internal/domain"
)

// RSVPRepo reads and writes RSVP rows through a pgx connection pool.
type RSVPRepo struct {
	pool *pgxpool.Pool
}

// NewRSVPRepo returns an RSVPRepo that uses db's connection pool.
func NewRSVPRepo(db *DB) *RSVPRepo {
	return &RSVPRepo{pool: db.Pool}
}

// CreateRSVPParams carries the column values for a new rsvps row.
type CreateRSVPParams struct {
	GuestID      uuid.UUID
	Response     domain.RSVPResponse
	PlusOnes     int
	DietaryNotes *string
	// DeviceFingerprint is stored as its JSON encoding; a nil map is sent
	// to the database as a nil value rather than as JSON.
	DeviceFingerprint map[string]any
	// IPAddress is cast to inet by the INSERT; the empty string is sent as
	// a nil value instead of being cast.
	IPAddress string
	RiskScore *int
}

// Create inserts a new RSVP for p.GuestID and returns the stored row,
// including the columns the INSERT does not supply (id, submitted_at,
// edit_count).
//
// A unique-constraint violation (SQLSTATE 23505) is reported as
// domain.ErrRSVPAlreadySubmitted. Every other failure is returned wrapped
// with %w, so errors.Is and errors.As still reach the underlying error.
func (r *RSVPRepo) Create(ctx context.Context, p CreateRSVPParams) (*domain.RSVP, error) {
	// SQLSTATE code PostgreSQL uses for unique_violation.
	const uniqueViolation = "23505"

	// Leave fpJSON nil when there is no fingerprint so the driver receives
	// a nil parameter instead of the JSON literal "null".
	var fpJSON []byte
	if p.DeviceFingerprint != nil {
		b, err := json.Marshal(p.DeviceFingerprint)
		if err != nil {
			return nil, fmt.Errorf("marshal fingerprint: %w", err)
		}
		fpJSON = b
	}

	// An empty string is not a valid inet literal, so "no address" is
	// passed as a nil pointer rather than as "".
	var ip *string
	if p.IPAddress != "" {
		ip = &p.IPAddress
	}

	const q = `
		INSERT INTO rsvps (guest_id, response, plus_ones, dietary_notes, device_fingerprint, ip_address, risk_score)
		VALUES ($1, $2, $3, $4, $5, $6::inet, $7)
		RETURNING id, guest_id, response, plus_ones, dietary_notes, submitted_at,
		          device_fingerprint, ip_address::text, risk_score, edit_count
	`
	row := r.pool.QueryRow(ctx, q,
		p.GuestID, p.Response, p.PlusOnes, p.DietaryNotes, fpJSON, ip, p.RiskScore,
	)

	rs, err := scanRSVP(row)
	if err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == uniqueViolation {
			return nil, domain.ErrRSVPAlreadySubmitted
		}
		return nil, fmt.Errorf("insert rsvp: %w", err)
	}
	return rs, nil
}

// GetByGuest returns the RSVP submitted by `guestID`, or ErrRSVPNotFound when
// none exists yet. Used by /access/{token} to surface the current submission
// so the frontend can show an edit form, and by PATCH /rsvp to load the row
// being revised.
func (r *RSVPRepo) GetByGuest(ctx context.Context, guestID uuid.UUID) (*domain.RSVP, error) { const q = ` SELECT id, guest_id, response, plus_ones, dietary_notes, submitted_at, device_fingerprint, ip_address::text, risk_score, edit_count FROM rsvps WHERE guest_id = $1 ` rs, err := scanRSVP(r.pool.QueryRow(ctx, q, guestID)) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, domain.ErrRSVPNotFound } return nil, err } return rs, nil } type UpdateRSVPParams struct { GuestID uuid.UUID Response domain.RSVPResponse PlusOnes int DietaryNotes *string DeviceFingerprint map[string]any IPAddress string RiskScore *int } // Update applies a revision to the guest's RSVP. The previous values are // snapshotted into rsvp_revisions inside the same transaction so the history // is consistent — either both the snapshot and the new state land, or neither // does. Returns ErrRSVPEditLimitReached if the guest has already hit // MaxRSVPEdits; the row itself is left untouched. func (r *RSVPRepo) Update(ctx context.Context, p UpdateRSVPParams) (*domain.RSVP, error) { var fpJSON []byte if p.DeviceFingerprint != nil { b, err := json.Marshal(p.DeviceFingerprint) if err != nil { return nil, fmt.Errorf("marshal fingerprint: %w", err) } fpJSON = b } var ip *string if p.IPAddress != "" { ip = &p.IPAddress } tx, err := r.pool.Begin(ctx) if err != nil { return nil, err } defer tx.Rollback(ctx) // SELECT ... FOR UPDATE locks the row so two concurrent edits can't both // snapshot the same prior state and both increment edit_count past the cap. 
var ( rsvpID uuid.UUID prevResp domain.RSVPResponse prevPlusOnes int prevDietary *string editCount int ) err = tx.QueryRow(ctx, ` SELECT id, response, plus_ones, dietary_notes, edit_count FROM rsvps WHERE guest_id = $1 FOR UPDATE `, p.GuestID).Scan(&rsvpID, &prevResp, &prevPlusOnes, &prevDietary, &editCount) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, domain.ErrRSVPNotFound } return nil, err } if editCount >= domain.MaxRSVPEdits { return nil, domain.ErrRSVPEditLimitReached } if _, err := tx.Exec(ctx, ` INSERT INTO rsvp_revisions (rsvp_id, prev_response, prev_plus_ones, prev_dietary) VALUES ($1, $2, $3, $4) `, rsvpID, prevResp, prevPlusOnes, prevDietary); err != nil { return nil, fmt.Errorf("snapshot revision: %w", err) } const upd = ` UPDATE rsvps SET response = $2, plus_ones = $3, dietary_notes = $4, device_fingerprint = COALESCE($5, device_fingerprint), ip_address = COALESCE($6::inet, ip_address), risk_score = COALESCE($7, risk_score), submitted_at = now(), edit_count = edit_count + 1 WHERE guest_id = $1 RETURNING id, guest_id, response, plus_ones, dietary_notes, submitted_at, device_fingerprint, ip_address::text, risk_score, edit_count ` row := tx.QueryRow(ctx, upd, p.GuestID, p.Response, p.PlusOnes, p.DietaryNotes, fpJSON, ip, p.RiskScore, ) rs, err := scanRSVP(row) if err != nil { return nil, err } if err := tx.Commit(ctx); err != nil { return nil, err } return rs, nil } // ListRevisions returns every prior state of an RSVP, newest first. Empty // slice (not nil) when there are no revisions, so the JSON encodes as `[]`. 
func (r *RSVPRepo) ListRevisions(ctx context.Context, rsvpID uuid.UUID) ([]domain.RSVPRevision, error) { const q = ` SELECT id, rsvp_id, prev_response, prev_plus_ones, prev_dietary, changed_at FROM rsvp_revisions WHERE rsvp_id = $1 ORDER BY changed_at DESC ` rows, err := r.pool.Query(ctx, q, rsvpID) if err != nil { return nil, err } defer rows.Close() out := []domain.RSVPRevision{} for rows.Next() { var rev domain.RSVPRevision if err := rows.Scan( &rev.ID, &rev.RSVPID, &rev.PrevResponse, &rev.PrevPlusOnes, &rev.PrevDietary, &rev.ChangedAt, ); err != nil { return nil, err } out = append(out, rev) } return out, rows.Err() } // RSVPActivity is a denormalised RSVP entry for the activity feed — // includes the guest's name so the API can hand it to the frontend // without a separate lookup. type RSVPActivity struct { GuestID uuid.UUID GuestName string Response string PlusOnes int SubmittedAt time.Time } // ListRecentByEvent returns the most recent RSVPs for an event, newest first. func (r *RSVPRepo) ListRecentByEvent(ctx context.Context, eventID uuid.UUID, limit int) ([]RSVPActivity, error) { if limit <= 0 || limit > 200 { limit = 50 } const q = ` SELECT r.guest_id, g.name, r.response, r.plus_ones, r.submitted_at FROM rsvps r JOIN guests g ON g.id = r.guest_id WHERE g.event_id = $1 ORDER BY r.submitted_at DESC LIMIT $2 ` rows, err := r.pool.Query(ctx, q, eventID, limit) if err != nil { return nil, err } defer rows.Close() var out []RSVPActivity for rows.Next() { var a RSVPActivity if err := rows.Scan(&a.GuestID, &a.GuestName, &a.Response, &a.PlusOnes, &a.SubmittedAt); err != nil { return nil, err } out = append(out, a) } return out, rows.Err() } func scanRSVP(s rowScanner) (*domain.RSVP, error) { var ( rs domain.RSVP fpJSON []byte ip *string ) err := s.Scan( &rs.ID, &rs.GuestID, &rs.Response, &rs.PlusOnes, &rs.DietaryNotes, &rs.SubmittedAt, &fpJSON, &ip, &rs.RiskScore, &rs.EditCount, ) if err != nil { return nil, err } if len(fpJSON) > 0 { _ = json.Unmarshal(fpJSON, 
&rs.DeviceFingerprint) } if ip != nil { rs.IPAddress = ip } return &rs, nil }