1package scram
2
3import (
4 "crypto/ed25519"
5 cryptorand "crypto/rand"
6 "crypto/sha1"
7 "crypto/sha256"
8 "crypto/tls"
9 "crypto/x509"
10 "encoding/base64"
11 "errors"
12 "hash"
13 "math/big"
14 "net"
15 "testing"
16 "time"
17)
18
// base64Decode decodes s as standard, padded base64. It panics when s is not
// valid base64, so it is only meant for hard-coded test vectors.
func base64Decode(s string) []byte {
	if decoded, decodeErr := base64.StdEncoding.DecodeString(s); decodeErr == nil {
		return decoded
	}
	panic("bad base64")
}
26
// tcheck stops the test immediately when err is non-nil, reporting msg
// followed by the error. It is a no-op for a nil error.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
33
34func TestSCRAMSHA1Server(t *testing.T) {
35 // Test vector from ../rfc/5802:496
36 salt := base64Decode("QSXCR+Q6sek8bf92")
37 saltedPassword := SaltPassword(sha1.New, "pencil", salt, 4096)
38
39 server, err := NewServer(sha1.New, []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL"), nil, false)
40 server.serverNonceOverride = "3rfcNHYJY1ZVvWVs7j"
41 tcheck(t, err, "newserver")
42 resp, err := server.ServerFirst(4096, salt)
43 tcheck(t, err, "server first")
44 if resp != "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096" {
45 t.Fatalf("bad server first")
46 }
47 serverFinal, err := server.Finish([]byte("c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts="), saltedPassword)
48 tcheck(t, err, "finish")
49 if serverFinal != "v=rmF9pqV8S7suAoZWja4dJRkFsKQ=" {
50 t.Fatalf("bad server final")
51 }
52}
53
54func TestSCRAMSHA256Server(t *testing.T) {
55 // Test vector from ../rfc/7677:122
56 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
57 saltedPassword := SaltPassword(sha256.New, "pencil", salt, 4096)
58
59 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
60 server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
61 tcheck(t, err, "newserver")
62 resp, err := server.ServerFirst(4096, salt)
63 tcheck(t, err, "server first")
64 if resp != "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" {
65 t.Fatalf("bad server first")
66 }
67 serverFinal, err := server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
68 tcheck(t, err, "finish")
69 if serverFinal != "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" {
70 t.Fatalf("bad server final")
71 }
72}
73
74// Bad attempt with wrong password.
75func TestScramServerBadPassword(t *testing.T) {
76 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
77 saltedPassword := SaltPassword(sha256.New, "marker", salt, 4096)
78
79 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
80 server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
81 tcheck(t, err, "newserver")
82 _, err = server.ServerFirst(4096, salt)
83 tcheck(t, err, "server first")
84 _, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
85 if !errors.Is(err, ErrInvalidProof) {
86 t.Fatalf("got %v, expected ErrInvalidProof", err)
87 }
88}
89
90// Bad attempt with different number of rounds.
91func TestScramServerBadIterations(t *testing.T) {
92 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
93 saltedPassword := SaltPassword(sha256.New, "pencil", salt, 2048)
94
95 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
96 server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
97 tcheck(t, err, "newserver")
98 _, err = server.ServerFirst(4096, salt)
99 tcheck(t, err, "server first")
100 _, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
101 if !errors.Is(err, ErrInvalidProof) {
102 t.Fatalf("got %v, expected ErrInvalidProof", err)
103 }
104}
105
106// Another attempt but with a randomly different nonce.
107func TestScramServerBad(t *testing.T) {
108 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
109 saltedPassword := SaltPassword(sha256.New, "pencil", salt, 4096)
110
111 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
112 tcheck(t, err, "newserver")
113 _, err = server.ServerFirst(4096, salt)
114 tcheck(t, err, "server first")
115 _, err = server.Finish([]byte("c=biws,r="+server.nonce+",p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
116 if !errors.Is(err, ErrInvalidProof) {
117 t.Fatalf("got %v, expected ErrInvalidProof", err)
118 }
119}
120
121func TestScramClient(t *testing.T) {
122 c := NewClient(sha256.New, "user", "", false, nil)
123 c.clientNonce = "rOprNGfwEbeRWgbNEkqO"
124 clientFirst, err := c.ClientFirst()
125 tcheck(t, err, "ClientFirst")
126 if clientFirst != "n,,n=user,r=rOprNGfwEbeRWgbNEkqO" {
127 t.Fatalf("bad clientFirst")
128 }
129 clientFinal, err := c.ServerFirst([]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"), "pencil")
130 tcheck(t, err, "ServerFirst")
131 if clientFinal != "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=" {
132 t.Fatalf("bad clientFinal")
133 }
134 err = c.ServerFinal([]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="))
135 tcheck(t, err, "ServerFinal")
136}
137
138func TestScram(t *testing.T) {
139 runHash := func(h func() hash.Hash, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string, noServerPlus bool, clientcs, servercs *tls.ConnectionState) {
140 t.Helper()
141
142 defer func() {
143 x := recover()
144 if x == nil || x == "" {
145 return
146 }
147 panic(x)
148 }()
149
150 // check err is either nil or the expected error. if the expected error, panic to abort the authentication session.
151 xerr := func(err error, msg string) {
152 t.Helper()
153 if err != nil && !errors.Is(err, expErr) {
154 t.Fatalf("%s: got %v, expected %v", msg, err, expErr)
155 }
156 if err != nil {
157 panic("") // Abort test.
158 }
159 }
160
161 salt := MakeRandom()
162 saltedPassword := SaltPassword(h, password, salt, iterations)
163
164 client := NewClient(h, username, "", noServerPlus, clientcs)
165 client.clientNonce = clientNonce
166 clientFirst, err := client.ClientFirst()
167 xerr(err, "client.ClientFirst")
168
169 server, err := NewServer(h, []byte(clientFirst), servercs, servercs != nil)
170 xerr(err, "NewServer")
171 server.serverNonceOverride = serverNonce
172
173 serverFirst, err := server.ServerFirst(iterations, salt)
174 xerr(err, "server.ServerFirst")
175
176 clientFinal, err := client.ServerFirst([]byte(serverFirst), password)
177 xerr(err, "client.ServerFirst")
178
179 serverFinal, err := server.Finish([]byte(clientFinal), saltedPassword)
180 xerr(err, "server.Finish")
181
182 err = client.ServerFinal([]byte(serverFinal))
183 xerr(err, "client.ServerFinal")
184
185 if expErr != nil {
186 t.Fatalf("got no error, expected %v", expErr)
187 }
188 }
189
190 makeState := func(maxTLSVersion uint16) (tls.ConnectionState, tls.ConnectionState) {
191 client, server := net.Pipe()
192 defer client.Close()
193 defer server.Close()
194 tlsClient := tls.Client(client, &tls.Config{
195 InsecureSkipVerify: true,
196 MaxVersion: maxTLSVersion,
197 })
198 tlsServer := tls.Server(server, &tls.Config{
199 Certificates: []tls.Certificate{fakeCert(t, "mox.example", false)},
200 MaxVersion: maxTLSVersion,
201 })
202 errc := make(chan error, 1)
203 go func() {
204 errc <- tlsServer.Handshake()
205 }()
206 err := tlsClient.Handshake()
207 tcheck(t, err, "tls handshake")
208 err = <-errc
209 tcheck(t, err, "server tls handshake")
210 clientcs := tlsClient.ConnectionState()
211 servercs := tlsServer.ConnectionState()
212
213 return clientcs, servercs
214 }
215
216 runPlus := func(maxTLSVersion uint16, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
217 t.Helper()
218
219 // PLUS variants.
220 clientcs, servercs := makeState(maxTLSVersion)
221 runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
222 runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
223 }
224
225 run := func(expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
226 t.Helper()
227
228 // Bare variants
229 runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
230 runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
231
232 // Check with both TLS 1.2 for "tls-unique", and latest TLS for "tls-exporter".
233 runPlus(tls.VersionTLS12, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
234 runPlus(0, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
235 }
236
237 run(nil, "user", "", "pencil", 4096, "", "")
238 run(nil, "mjl@mox.example", "", "testtest", 4096, "", "")
239 run(nil, "mjl@mox.example", "", "short", 4096, "", "")
240 run(nil, "mjl@mox.example", "", "short", 2048, "", "")
241 run(nil, "mjl@mox.example", "mjl@mox.example", "testtest", 4096, "", "")
242 run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
243 run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
244 run(ErrUnsafe, "user", "", "pencil", 1, "", "") // Few iterations.
245 run(ErrUnsafe, "user", "", "pencil", 2048, "short", "") // Short client nonce.
246 run(ErrUnsafe, "user", "", "pencil", 2048, "test1234", "test") // Server added too few random data.
247
248 // Test mechanism downgrade attacks are detected.
249 runHash(sha1.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
250 runHash(sha256.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
251
252 // Test channel binding, detecting MitM attacks.
253 runChannelBind := func(maxTLSVersion uint16) {
254 t.Helper()
255
256 clientcs0, _ := makeState(maxTLSVersion)
257 _, servercs1 := makeState(maxTLSVersion)
258 runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
259 runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
260
261 // Client thinks it is on a TLS connection and server is not.
262 runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
263 runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
264 }
265
266 runChannelBind(0)
267 runChannelBind(tls.VersionTLS12)
268}
269
// fakeCert returns a self-signed ed25519 TLS certificate for name, with the
// parsed certificate set as Leaf. The key comes from an all-zero seed, so it is
// only suitable for tests. The validity window starts two hours ago. It ends
// one hour from now, or one hour ago when expired is true. Failures to create
// or parse the certificate fail the test.
func fakeCert(t *testing.T, name string, expired bool) tls.Certificate {
	t.Helper()

	// A single reference time keeps NotBefore strictly before NotAfter, also for
	// expired certificates.
	now := time.Now()
	notAfter := now.Add(time.Hour)
	if expired {
		notAfter = now.Add(-time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field.
		DNSNames:     []string{name},
		NotBefore:    now.Add(-2 * time.Hour),
		NotAfter:     notAfter,
	}
	certDER, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	leaf, err := x509.ParseCertificate(certDER)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{certDER},
		PrivateKey:  privKey,
		Leaf:        leaf,
	}
}
301