package ratelimit

import (
	"context"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)

// newTestLimiter starts an in-process miniredis server, connects a go-redis
// client to it and returns the Limiter built by New(cli, "test") together with
// the server handle, so tests can manipulate the server (for example its
// clock). The server and the client are both closed via t.Cleanup.
func newTestLimiter(t *testing.T) (*Limiter, *miniredis.Miniredis) {
	t.Helper()

	mr, err := miniredis.Run()
	if err != nil {
		t.Fatalf("miniredis: %v", err)
	}
	t.Cleanup(mr.Close)

	cli := redis.NewClient(&redis.Options{Addr: mr.Addr()})
	// A Close error after the test has finished carries no useful signal.
	t.Cleanup(func() { _ = cli.Close() })

	return New(cli, "test"), mr
}

// TestLimiterAllowsBelowLimit checks that every hit up to the limit is allowed
// and that the returned Count is the running number of hits (1, 2, 3, ...).
func TestLimiterAllowsBelowLimit(t *testing.T) {
	l, _ := newTestLimiter(t)
	ctx := context.Background()

	for i := 1; i <= 3; i++ {
		r, err := l.Allow(ctx, "signup", "1.2.3.4", 3, time.Minute)
		if err != nil {
			t.Fatalf("allow #%d: %v", i, err)
		}
		if !r.Allowed {
			t.Fatalf("hit %d should be allowed, got %+v", i, r)
		}
		if r.Count != i {
			t.Fatalf("hit %d count: got %d want %d", i, r.Count, i)
		}
	}
}

// TestLimiterBlocksAtLimit checks that the first hit beyond the limit is
// rejected and reports a RetryAfter in the range (0, window].
func TestLimiterBlocksAtLimit(t *testing.T) {
	l, _ := newTestLimiter(t)
	ctx := context.Background()

	// Use up the whole budget of 3 hits.
	for i := 0; i < 3; i++ {
		if _, err := l.Allow(ctx, "signup", "ip", 3, time.Minute); err != nil {
			t.Fatal(err)
		}
	}

	r, err := l.Allow(ctx, "signup", "ip", 3, time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	if r.Allowed {
		t.Fatalf("4th hit should be blocked: %+v", r)
	}
	if r.RetryAfter <= 0 || r.RetryAfter > time.Minute {
		t.Fatalf("retry-after out of range: %v", r.RetryAfter)
	}
}

// TestLimiterWindowSlides checks that a key which is blocked inside its window
// becomes allowed again once the window has passed. It drives time through the
// limiter's injectable clock (l.now) and keeps miniredis in step with it.
func TestLimiterWindowSlides(t *testing.T) {
	l, mr := newTestLimiter(t)
	ctx := context.Background()

	// Inject a controllable clock. Mirror it into miniredis with SetTime (the
	// clock miniredis uses for EXPIREAT and TIME) so the limiter and the
	// server agree on "now" even though the base instant is not wall-clock
	// time.
	now := time.Unix(1_700_000_000, 0)
	l.now = func() time.Time { return now }
	mr.SetTime(now)

	for i := 1; i <= 3; i++ {
		// Step the clock slightly so each hit carries a distinct timestamp,
		// rather than all hits sharing one frozen instant.
		now = now.Add(time.Millisecond)
		mr.SetTime(now)

		r, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute)
		if err != nil {
			t.Fatal(err)
		}
		if !r.Allowed {
			t.Fatalf("hit %d should be allowed: %+v", i, r)
		}
	}

	// Prove the budget really is exhausted inside the window. Without this the
	// final assertion would also pass for a limiter that never blocks or that
	// ignores the injected clock.
	r, err := l.Allow(ctx, "rsvp", "tok", 3, time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	if r.Allowed {
		t.Fatalf("4th hit inside the window should be blocked: %+v", r)
	}

	// Slide past the window. SetTime only moves miniredis's clock; TTLs that
	// are already stored are expired by FastForward, so do both.
	now = now.Add(2 * time.Minute)
	mr.SetTime(now)
	mr.FastForward(2 * time.Minute)

	r, err = l.Allow(ctx, "rsvp", "tok", 3, time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	if !r.Allowed {
		t.Fatalf("expected allow after window: %+v", r)
	}
}

// TestLimiterIsolatesKeys checks that exhausting the budget for one key under
// a given action does not affect a different key under the same action.
func TestLimiterIsolatesKeys(t *testing.T) {
	l, _ := newTestLimiter(t)
	ctx := context.Background()

	// Exhaust budget for one key — others should not be affected.
	for i := 0; i < 2; i++ {
		if _, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute); err != nil {
			t.Fatal(err)
		}
	}

	blocked, err := l.Allow(ctx, "login", "a@x.com", 2, time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	if blocked.Allowed {
		t.Fatal("key a should be blocked")
	}

	other, err := l.Allow(ctx, "login", "b@x.com", 2, time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	if !other.Allowed {
		t.Fatalf("unrelated key b should still be allowed: %+v", other)
	}
}