1package scram
2
3import (
4 "encoding/base64"
5 "errors"
6 "fmt"
7 "strconv"
8 "strings"
9)
10
// parser is a cursor over a SCRAM message. Tokens are matched
// case-insensitively against lower, while values are sliced from s so they keep
// their original casing. lower is produced by toLower, which maps byte for byte,
// so o is a valid offset into both strings.
type parser struct {
	s, lower string // Original input, and the same input with A-Z lowercased.
	o        int    // Current offset into s and lower.
}
16
// parseError is the panic value used by the parser's x* helpers (xerrorf,
// xcheckf) to abort parsing. It wraps the underlying error.
type parseError struct{ err error }

// Error returns the message of the wrapped error.
func (e parseError) Error() string {
	msg := e.err.Error()
	return msg
}

// Unwrap returns the wrapped error so errors.Is and errors.As can inspect it.
func (e parseError) Unwrap() error {
	inner := e.err
	return inner
}
26
// toLower returns s with only the ASCII letters A-Z converted to a-z; every
// other byte, including non-ASCII and invalid UTF-8 bytes, is copied unchanged.
// strings.ToLower is not used because it does too much: it would replace
// invalid bytes with Unicode replacement characters. That would break the
// requirement that an offset into the original string and the same offset into
// the lowercased string point at the same character.
func toLower(s string) string {
	out := make([]byte, len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			c += 'a' - 'A'
		}
		out[i] = c
	}
	return string(out)
}
40
41func newParser(buf []byte) *parser {
42 s := string(buf)
43 return &parser{s, toLower(s), 0}
44}
45
46// Turn panics of parseError into a descriptive ErrInvalidEncoding. Called with
47// defer by functions that parse.
48func (p *parser) recover(rerr *error) {
49 x := recover()
50 if x == nil {
51 return
52 }
53 err, ok := x.(error)
54 if !ok {
55 panic(x)
56 }
57 var xerr Error
58 if errors.As(err, &xerr) {
59 *rerr = err
60 return
61 }
62 *rerr = fmt.Errorf("%w: %s", ErrInvalidEncoding, err)
63}
64
65func (p *parser) xerrorf(format string, args ...any) {
66 panic(parseError{fmt.Errorf(format, args...)})
67}
68
69func (p *parser) xcheckf(err error, format string, args ...any) {
70 if err != nil {
71 panic(parseError{fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), err)})
72 }
73}
74
75func (p *parser) xempty() {
76 if p.o != len(p.s) {
77 p.xerrorf("leftover data")
78 }
79}
80
81func (p *parser) xnonempty() {
82 if p.o >= len(p.s) {
83 p.xerrorf("unexpected end")
84 }
85}
86
87func (p *parser) xbyte() byte {
88 p.xnonempty()
89 c := p.lower[p.o]
90 p.o++
91 return c
92}
93
94func (p *parser) peek(s string) bool {
95 return strings.HasPrefix(p.lower[p.o:], s)
96}
97
98func (p *parser) take(s string) bool {
99 if p.peek(s) {
100 p.o += len(s)
101 return true
102 }
103 return false
104}
105
106func (p *parser) xtake(s string) {
107 if !p.take(s) {
108 p.xerrorf("expected %q", s)
109 }
110}
111
112func (p *parser) xauthzid() string {
113 p.xtake("a=")
114 return p.xsaslname()
115}
116
117func (p *parser) xusername() string {
118 p.xtake("n=")
119 return p.xsaslname()
120}
121
122func (p *parser) xnonce() string {
123 p.xtake("r=")
124 o := p.o
125 for ; o < len(p.s); o++ {
126 c := p.s[o]
127 if c <= ' ' || c >= 0x7f || c == ',' {
128 break
129 }
130 }
131 if o == p.o {
132 p.xerrorf("empty nonce")
133 }
134 r := p.s[p.o:o]
135 p.o = o
136 return r
137}
138
139func (p *parser) xattrval() {
140 c := p.xbyte()
141 if !(c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z') {
142 p.xerrorf("expected alpha for attr-val")
143 }
144 p.xtake("=")
145 p.xvalue()
146}
147
148func (p *parser) xvalue() string {
149 for o, c := range p.s[p.o:] {
150 if c == 0 || c == ',' {
151 if o == 0 {
152 p.xerrorf("invalid empty value")
153 }
154 r := p.s[p.o : p.o+o]
155 p.o = o
156 return r
157 }
158 }
159 p.xnonempty()
160 r := p.s[p.o:]
161 p.o = len(p.s)
162 return r
163}
164
165func (p *parser) xbase64() []byte {
166 o := p.o
167 for ; o < len(p.s); o++ {
168 c := p.s[o]
169 if !(c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '/' || c == '+' || c == '=') {
170 break
171 }
172 }
173 buf, err := base64.StdEncoding.DecodeString(p.s[p.o:o])
174 p.xcheckf(err, "decoding base64")
175 p.o = o
176 return buf
177}
178
179func (p *parser) xsaslname() string {
180 var esc string
181 var is bool
182 var r string
183 for o, c := range p.s[p.o:] {
184 if c == 0 || c == ',' {
185 if is {
186 p.xerrorf("saslname unexpected end")
187 }
188 if o == 0 {
189 p.xerrorf("saslname cannot be empty")
190 }
191 p.o += o
192 return r
193 }
194 if is {
195 esc += string(c)
196 if len(esc) < 2 {
197 continue
198 }
199 switch esc {
200 case "2c", "2C":
201 r += ","
202 case "3d", "3D":
203 r += "="
204 default:
205 p.xerrorf("bad escape %q in saslanem", esc)
206 }
207 is = false
208 esc = ""
209 continue
210 } else if c == '=' {
211 is = true
212 continue
213 }
214 r += string(c)
215 }
216 if is {
217 p.xerrorf("saslname unexpected end")
218 }
219 if r == "" {
220 p.xerrorf("saslname cannot be empty")
221 }
222 p.o = len(p.s)
223 return r
224}
225
226// ../rfc/5802:889
227func (p *parser) xcbname() string {
228 o := p.o
229 for ; o < len(p.s); o++ {
230 c := p.s[o]
231 if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '.' || c == '-' {
232 continue
233 }
234 break
235 }
236 if o == p.o {
237 p.xerrorf("empty channel binding name")
238 }
239 r := p.s[p.o:o]
240 p.o = o
241 return string(r)
242}
243
244func (p *parser) xchannelBinding() []byte {
245 p.xtake("c=")
246 return p.xbase64()
247}
248
249func (p *parser) xproof() []byte {
250 p.xtake("p=")
251 return p.xbase64()
252}
253
254func (p *parser) xsalt() []byte {
255 p.xtake("s=")
256 return p.xbase64()
257}
258
259func (p *parser) xtakefn1(fn func(rune, int) bool) string {
260 for o, c := range p.s[p.o:] {
261 if !fn(c, o) {
262 if o == 0 {
263 p.xerrorf("non-empty match required")
264 }
265 r := p.s[p.o : p.o+o]
266 p.o += o
267 return r
268 }
269 }
270 p.xnonempty()
271 r := p.s[p.o:]
272 p.o = len(p.s)
273 return r
274}
275
276func (p *parser) xiterations() int {
277 p.xtake("i=")
278 digits := p.xtakefn1(func(c rune, i int) bool {
279 return c >= '1' && c <= '9' || i > 0 && c == '0'
280 })
281 v, err := strconv.ParseInt(digits, 10, 32)
282 p.xcheckf(err, "parsing int")
283 return int(v)
284}
285