package storage import ( "context" "embed" "fmt" "io/fs" "sort" "strings" "time" "github.com/jackc/pgx/v5/pgxpool" ) //go:embed migrations/*.sql var migrationsFS embed.FS type DB struct { Pool *pgxpool.Pool } func NewDB(ctx context.Context, dsn string) (*DB, error) { cfg, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse dsn: %w", err) } cfg.MaxConnLifetime = 30 * time.Minute cfg.MaxConns = 10 pool, err := pgxpool.NewWithConfig(ctx, cfg) if err != nil { return nil, fmt.Errorf("connect pool: %w", err) } pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() if err := pool.Ping(pingCtx); err != nil { pool.Close() return nil, fmt.Errorf("ping db: %w", err) } return &DB{Pool: pool}, nil } func (db *DB) Close() { db.Pool.Close() } func (db *DB) Migrate(ctx context.Context) error { _, err := db.Pool.Exec(ctx, ` CREATE TABLE IF NOT EXISTS schema_migrations ( version VARCHAR(255) PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT now() ) `) if err != nil { return fmt.Errorf("create migrations table: %w", err) } entries, err := fs.ReadDir(migrationsFS, "migrations") if err != nil { return fmt.Errorf("read migrations: %w", err) } type migration struct { version string path string } var ups []migration for _, e := range entries { name := e.Name() if !strings.HasSuffix(name, ".up.sql") { continue } version := strings.TrimSuffix(name, ".up.sql") ups = append(ups, migration{version: version, path: "migrations/" + name}) } sort.Slice(ups, func(i, j int) bool { return ups[i].version < ups[j].version }) for _, m := range ups { var exists bool err := db.Pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version=$1)", m.version, ).Scan(&exists) if err != nil { return fmt.Errorf("check migration %s: %w", m.version, err) } if exists { continue } sqlBytes, err := migrationsFS.ReadFile(m.path) if err != nil { return fmt.Errorf("read %s: %w", m.path, err) } tx, err := db.Pool.Begin(ctx) if err != nil { return fmt.Errorf("begin tx for %s: %w", m.version, err) } if _, err := tx.Exec(ctx, string(sqlBytes)); err != nil { _ = tx.Rollback(ctx) return fmt.Errorf("apply %s: %w", m.version, err) } if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations(version) VALUES($1)", m.version); err != nil { _ = tx.Rollback(ctx) return fmt.Errorf("record %s: %w", m.version, err) } if err := tx.Commit(ctx); err != nil { return fmt.Errorf("commit %s: %w", m.version, err) } } return nil }