1// Package ratelimit provides a simple window-based rate limiter.
10// Limiter is a simple rate limiter with one or more fixed windows, e.g. the
11// last minute/hour/day/week, working on three classes/subnets of an IP.
14 WindowLimits []WindowLimit
18// WindowLimit holds counters for one window, with limits for each IP class/subnet.
19type WindowLimit struct {
21 Limits [3]int64 // For "ipmasked1" through "ipmasked3".
22 Time uint32 // Time/Window.
29// Add attempts to consume n items from the rate limiter for ip at time tm. If the
30// total for any of ip's masked classes in any window would exceed that window's limit,
31// n is not counted and false is returned. Windows whose interval has advanced (per tm) are reset first.
32func (l *Limiter) Add(ip net.IP, tm time.Time, n int64) bool {
33 return l.checkAdd(true, ip, tm, n)
36// CanAdd reports whether n could be added to the limiter for ip at time tm.
37func (l *Limiter) CanAdd(ip net.IP, tm time.Time, n int64) bool {
38 return l.checkAdd(false, ip, tm, n)
41func (l *Limiter) checkAdd(add bool, ip net.IP, tm time.Time, n int64) bool {
46 for i, pl := range l.WindowLimits {
47 t := uint32(tm.UnixNano() / int64(pl.Window))
49 if t > pl.Time || pl.Counts == nil {
50 l.WindowLimits[i].Time = t
51 pl.Counts = map[struct {
54 }]int64{} // Used below.
55 l.WindowLimits[i].Counts = pl.Counts
58 for j := 0; j < 3; j++ {
60 l.ipmasked[j] = l.maskIP(j, ip)
63 v := pl.Counts[struct {
66 }{uint8(j), l.ipmasked[j]}]
67 if v+n > pl.Limits[j] {
76 for _, pl := range l.WindowLimits {
77 for j := 0; j < 3; j++ {
81 }{uint8(j), l.ipmasked[j]}] += n
87// Reset sets the counter to 0 for ip at time tm, and subtracts from the ipmasked counts.
88func (l *Limiter) Reset(ip net.IP, tm time.Time) {
92 // Prepare masked ip's.
93 for i := 0; i < 3; i++ {
94 l.ipmasked[i] = l.maskIP(i, ip)
97 for _, pl := range l.WindowLimits {
98 t := uint32(tm.UnixNano() / int64(pl.Window))
99 if t != pl.Time || pl.Counts == nil {
103 for j := 0; j < 3; j++ {
107 }{uint8(j), l.ipmasked[j]}
111 if pl.Counts != nil {
118func (l *Limiter) maskIP(i int, ip net.IP) [16]byte {
119 isv4 := ip.To4() != nil
127 ipmasked = ip.Mask(net.CIDRMask(26, 32))
129 ipmasked = ip.Mask(net.CIDRMask(21, 32))
131 panic("missing case for maskip ipv4")
136 ipmasked = ip.Mask(net.CIDRMask(64, 128))
138 ipmasked = ip.Mask(net.CIDRMask(48, 128))
140 ipmasked = ip.Mask(net.CIDRMask(32, 128))
142 panic("missing case for masking ipv6")
145 return *(*[16]byte)(ipmasked.To16())