// Package ratelimit implements a sliding-window rate limiter backed by // Redis sorted sets. Each call is atomic via a Lua script: it sweeps // entries older than the window, returns the current count, and (when // under the limit) records the new hit. Block C's "INCR + EXPIRE or a Lua // script for atomicity" requirement. package ratelimit import ( "context" "crypto/rand" "encoding/hex" "errors" "fmt" "time" "github.com/redis/go-redis/v9" ) // Result is the outcome of one Allow check. type Result struct { Allowed bool Count int // current count within the window (post-increment when allowed) Limit int RetryAfter time.Duration // populated when Allowed=false } // Limiter checks rate-limit budgets against Redis. type Limiter struct { client *redis.Client script *redis.Script prefix string now func() time.Time } // New builds a Limiter against the given Redis client. The prefix namespaces // all keys (defaults to "rl"). func New(client *redis.Client, prefix string) *Limiter { if prefix == "" { prefix = "rl" } return &Limiter{ client: client, script: redis.NewScript(slidingWindowScript), prefix: prefix, now: time.Now, } } // Allow consumes one unit of budget under (name, key) against `limit` events // per `window`. Returns Allowed=true and the new count, or Allowed=false // with RetryAfter set to roughly the duration until the oldest hit ages out. func (l *Limiter) Allow(ctx context.Context, name, key string, limit int, window time.Duration) (Result, error) { if limit <= 0 { return Result{}, errors.New("ratelimit: limit must be positive") } if window <= 0 { return Result{}, errors.New("ratelimit: window must be positive") } member, err := randomMember() if err != nil { return Result{}, err } now := l.now().UnixMilli() windowMS := window.Milliseconds() redisKey := fmt.Sprintf("%s:%s:%s", l.prefix, name, key) out, err := l.script.Run(ctx, l.client, []string{redisKey}, now, windowMS, limit, member, ).Int64Slice() if err != nil { return Result{}, fmt.Errorf("ratelimit: redis: %w", err) } if len(out) != 3 { return Result{}, fmt.Errorf("ratelimit: bad lua reply: %v", out) } r := Result{ Count: int(out[1]), Limit: limit, } if out[0] == 0 { r.Allowed = true } else { r.RetryAfter = time.Duration(out[2]) * time.Millisecond if r.RetryAfter <= 0 { r.RetryAfter = time.Second } } return r, nil } func randomMember() (string, error) { var buf [12]byte if _, err := rand.Read(buf[:]); err != nil { return "", err } return hex.EncodeToString(buf[:]), nil } // Sliding-window check + record, atomic in Redis. // // KEYS[1] = bucket key // ARGV[1] = now (unix ms) // ARGV[2] = window (ms) // ARGV[3] = limit // ARGV[4] = unique member to insert when allowed // returns { blocked, count, retryAfterMs } const slidingWindowScript = ` local key = KEYS[1] local now = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local limit = tonumber(ARGV[3]) local member = ARGV[4] local cutoff = now - window redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff) local count = redis.call('ZCARD', key) if count >= limit then local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') local retry = window if oldest[2] then retry = (tonumber(oldest[2]) + window) - now if retry < 1 then retry = 1 end end return {1, count, retry} end redis.call('ZADD', key, now, member) redis.call('PEXPIRE', key, window) return {0, count + 1, 0} `