// restore-verify is a post-restore sanity tool. Point it at a freshly // restored Postgres instance and it asserts that the schema and data // are coherent — no orphan rows, expected uniqueness invariants hold, // every table is present. Exits non-zero on any failure so it slots // into a "restore drill" runbook step. // // Usage: // GG_DATABASE_URL=postgres://... ./restore-verify [--verbose] // // The intent is "would I bet my Sunday on this restore being usable?". // Failing fast here keeps a bad restore from being promoted to traffic. package main import ( "context" "errors" "flag" "fmt" "os" "strings" "time" "github.com/jackc/pgx/v5/pgxpool" ) func main() { verbose := flag.Bool("verbose", false, "print every check's result, not just failures") flag.Parse() dsn := os.Getenv("GG_DATABASE_URL") if dsn == "" { fmt.Fprintln(os.Stderr, "restore-verify: GG_DATABASE_URL is required") os.Exit(2) } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() pool, err := pgxpool.New(ctx, dsn) if err != nil { fmt.Fprintf(os.Stderr, "restore-verify: connect: %v\n", err) os.Exit(2) } defer pool.Close() if err := pool.Ping(ctx); err != nil { fmt.Fprintf(os.Stderr, "restore-verify: ping: %v\n", err) os.Exit(2) } fmt.Println("restore-verify: checking", maskDSN(dsn)) fmt.Println() checks := allChecks() var failed []string for _, c := range checks { result, err := c.fn(ctx, pool) if err != nil { failed = append(failed, c.name) fmt.Printf(" ✗ %-50s FAIL: %v\n", c.name, err) continue } if *verbose { fmt.Printf(" ✓ %-50s %s\n", c.name, result) } } fmt.Println() if len(failed) > 0 { fmt.Printf("FAILED: %d check%s — %s\n", len(failed), pluralS(len(failed)), strings.Join(failed, ", ")) os.Exit(1) } fmt.Printf("OK: all %d checks passed\n", len(checks)) } type check struct { name string fn func(ctx context.Context, pool *pgxpool.Pool) (string, error) } func allChecks() []check { return []check{ // --- schema presence --- tableExists("users"), tableExists("events"), tableExists("guests"), tableExists("tokens"), tableExists("rsvps"), tableExists("access_logs"), tableExists("notifications"), tableExists("schema_migrations"), tableExists("email_verification_tokens"), tableExists("password_reset_tokens"), tableExists("refresh_tokens"), tableExists("unsubscribes"), tableExists("subscriptions"), // --- migrations applied --- { name: "schema_migrations: 5+ migrations applied", fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { var n int if err := pool.QueryRow(ctx, `SELECT count(*) FROM schema_migrations`).Scan(&n); err != nil { return "", err } if n < 5 { return "", fmt.Errorf("only %d migrations recorded — incomplete restore", n) } return fmt.Sprintf("%d migrations", n), nil }, }, // --- referential integrity (FK constraints catch most of this, // but a bad logical dump or partial restore can sneak rows in) --- noOrphans("events", "host_id", "users", "id"), noOrphans("guests", "event_id", "events", "id"), noOrphans("tokens", "guest_id", "guests", "id"), noOrphans("rsvps", "guest_id", "guests", "id"), noOrphans("access_logs", "guest_id", "guests", "id"), noOrphans("notifications", "guest_id", "guests", "id"), noOrphans("email_verification_tokens", "user_id", "users", "id"), noOrphans("password_reset_tokens", "user_id", "users", "id"), noOrphans("refresh_tokens", "user_id", "users", "id"), noOrphans("subscriptions", "user_id", "users", "id"), // --- domain invariants --- { name: "users.email is unique (case-insensitive)", fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { var dupes int err := pool.QueryRow(ctx, ` SELECT count(*) FROM ( SELECT lower(email) FROM users GROUP BY lower(email) HAVING count(*) > 1 ) t `).Scan(&dupes) if err != nil { return "", err } if dupes > 0 { return "", fmt.Errorf("%d duplicate email(s) found", dupes) } return "no duplicate emails", nil }, }, { name: "guests with rsvp_response have an existing rsvp row", fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { var n int err := pool.QueryRow(ctx, ` SELECT count(*) FROM rsvps r WHERE NOT EXISTS (SELECT 1 FROM guests g WHERE g.id = r.guest_id) `).Scan(&n) if err != nil { return "", err } if n > 0 { return "", fmt.Errorf("%d rsvp(s) reference a missing guest", n) } return "0 orphan rsvps", nil }, }, { name: "at most one granting subscription per user", fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { var n int err := pool.QueryRow(ctx, ` SELECT count(*) FROM ( SELECT user_id FROM subscriptions WHERE status IN ('active','past_due','trialing') GROUP BY user_id HAVING count(*) > 1 ) t `).Scan(&n) if err != nil { return "", err } if n > 0 { return "", fmt.Errorf("%d user(s) have multiple granting subscriptions", n) } return "all users single-active", nil }, }, // --- soft constraints worth noticing (not failures, but logged) --- { name: "row counts snapshot", fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { parts := []string{} for _, t := range []string{ "users", "events", "guests", "tokens", "rsvps", "access_logs", "notifications", "subscriptions", } { var n int if err := pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", t)).Scan(&n); err != nil { return "", err } parts = append(parts, fmt.Sprintf("%s=%d", t, n)) } return strings.Join(parts, " "), nil }, }, } } func tableExists(name string) check { return check{ name: fmt.Sprintf("table %q exists", name), fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { var exists bool err := pool.QueryRow(ctx, `SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema='public' AND table_name=$1)`, name, ).Scan(&exists) if err != nil { return "", err } if !exists { return "", errors.New("missing") } return "ok", nil }, } } func noOrphans(childTable, childFK, parentTable, parentPK string) check { return check{ name: fmt.Sprintf("no orphans: %s.%s -> %s.%s", childTable, childFK, parentTable, parentPK), fn: func(ctx context.Context, pool *pgxpool.Pool) (string, error) { q := fmt.Sprintf(` SELECT count(*) FROM %s c WHERE c.%s IS NOT NULL AND NOT EXISTS (SELECT 1 FROM %s p WHERE p.%s = c.%s) `, childTable, childFK, parentTable, parentPK, childFK) var n int if err := pool.QueryRow(ctx, q).Scan(&n); err != nil { return "", err } if n > 0 { return "", fmt.Errorf("%d orphan row(s)", n) } return "clean", nil }, } } func maskDSN(dsn string) string { // Crude: redact password between '://user:' and '@'. at := strings.LastIndex(dsn, "@") scheme := strings.Index(dsn, "://") if at < 0 || scheme < 0 || at <= scheme { return dsn } userInfo := dsn[scheme+3 : at] if colon := strings.Index(userInfo, ":"); colon >= 0 { userInfo = userInfo[:colon] + ":****" } return dsn[:scheme+3] + userInfo + dsn[at:] } func pluralS(n int) string { if n == 1 { return "" } return "s" }