1// Package ratelimit provides a simple window-based rate limiter.
2package ratelimit
3
4import (
5 "net"
6 "sync"
7 "time"
8)
9
10// Limiter is a simple rate limiter with one or more fixed windows, e.g. the
11// last minute/hour/day/week, working on three classes/subnets of an IP.
12type Limiter struct {
13 sync.Mutex
14 WindowLimits []WindowLimit
15 ipmasked [3][16]byte
16}
17
// WindowLimit describes one fixed window: its length, the limit for each of
// the three IP classes, and the counters of the window currently in use.
type WindowLimit struct {
	// Window is the length of the window. The timestamp in nanoseconds is
	// divided by it to get the window number, so it must not be zero.
	Window time.Duration

	// Limits holds the maximum count per window for IP class 0 (narrowest
	// network) through class 2 (widest network).
	Limits [3]int64

	// Time is the number of the window the counters belong to: the
	// timestamp in nanoseconds divided by Window, truncated to 32 bits.
	Time uint32

	// Counts maps an IP class index plus the IP masked for that class to the
	// amount recorded in the current window. Add and CanAdd allocate it when
	// it is nil.
	Counts map[struct {
		Index    uint8
		IPMasked [16]byte
	}]int64
}
28
29// Add attempts to consume "n" items from the rate limiter. If the total for this
30// key and this interval would exceed limit, "n" is not counted and false is
31// returned. If now represents a different time interval, all counts are reset.
32func (l *Limiter) Add(ip net.IP, tm time.Time, n int64) bool {
33 return l.checkAdd(true, ip, tm, n)
34}
35
36// CanAdd returns if n could be added to the limiter.
37func (l *Limiter) CanAdd(ip net.IP, tm time.Time, n int64) bool {
38 return l.checkAdd(false, ip, tm, n)
39}
40
41func (l *Limiter) checkAdd(add bool, ip net.IP, tm time.Time, n int64) bool {
42 l.Lock()
43 defer l.Unlock()
44
45 // First check.
46 for i, pl := range l.WindowLimits {
47 t := uint32(tm.UnixNano() / int64(pl.Window))
48
49 if t > pl.Time || pl.Counts == nil {
50 l.WindowLimits[i].Time = t
51 pl.Counts = map[struct {
52 Index uint8
53 IPMasked [16]byte
54 }]int64{} // Used below.
55 l.WindowLimits[i].Counts = pl.Counts
56 }
57
58 for j := 0; j < 3; j++ {
59 if i == 0 {
60 l.ipmasked[j] = l.maskIP(j, ip)
61 }
62
63 v := pl.Counts[struct {
64 Index uint8
65 IPMasked [16]byte
66 }{uint8(j), l.ipmasked[j]}]
67 if v+n > pl.Limits[j] {
68 return false
69 }
70 }
71 }
72 if !add {
73 return true
74 }
75 // Finally record.
76 for _, pl := range l.WindowLimits {
77 for j := 0; j < 3; j++ {
78 pl.Counts[struct {
79 Index uint8
80 IPMasked [16]byte
81 }{uint8(j), l.ipmasked[j]}] += n
82 }
83 }
84 return true
85}
86
87// Reset sets the counter to 0 for key and ip, and subtracts from the ipmasked counts.
88func (l *Limiter) Reset(ip net.IP, tm time.Time) {
89 l.Lock()
90 defer l.Unlock()
91
92 // Prepare masked ip's.
93 for i := 0; i < 3; i++ {
94 l.ipmasked[i] = l.maskIP(i, ip)
95 }
96
97 for _, pl := range l.WindowLimits {
98 t := uint32(tm.UnixNano() / int64(pl.Window))
99 if t != pl.Time || pl.Counts == nil {
100 continue
101 }
102 var n int64
103 for j := 0; j < 3; j++ {
104 k := struct {
105 Index uint8
106 IPMasked [16]byte
107 }{uint8(j), l.ipmasked[j]}
108 if j == 0 {
109 n = pl.Counts[k]
110 }
111 if pl.Counts != nil {
112 pl.Counts[k] -= n
113 }
114 }
115 }
116}
117
118func (l *Limiter) maskIP(i int, ip net.IP) [16]byte {
119 isv4 := ip.To4() != nil
120
121 var ipmasked net.IP
122 if isv4 {
123 switch i {
124 case 0:
125 ipmasked = ip
126 case 1:
127 ipmasked = ip.Mask(net.CIDRMask(26, 32))
128 case 2:
129 ipmasked = ip.Mask(net.CIDRMask(21, 32))
130 default:
131 panic("missing case for maskip ipv4")
132 }
133 } else {
134 switch i {
135 case 0:
136 ipmasked = ip.Mask(net.CIDRMask(64, 128))
137 case 1:
138 ipmasked = ip.Mask(net.CIDRMask(48, 128))
139 case 2:
140 ipmasked = ip.Mask(net.CIDRMask(32, 128))
141 default:
142 panic("missing case for masking ipv6")
143 }
144 }
145 return *(*[16]byte)(ipmasked.To16())
146}
147