1package mtasts
2
3import (
4 "context"
5 "crypto/ed25519"
6 cryptorand "crypto/rand"
7 "crypto/tls"
8 "crypto/x509"
9 "errors"
10 "io"
11 golog "log"
12 "log/slog"
13 "math/big"
14 "net"
15 "net/http"
16 "reflect"
17 "strings"
18 "sync"
19 "testing"
20 "time"
21
22 "github.com/mjl-/adns"
23
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
26)
27
28func TestLookup(t *testing.T) {
29 mlog.SetConfig(map[string]slog.Level{"": mlog.LevelDebug})
30 log := mlog.New("mtasts", nil)
31
32 resolver := dns.MockResolver{
33 TXT: map[string][]string{
34 "_mta-sts.a.example.": {"v=STSv1; id=1"},
35 "_mta-sts.one.example.": {"v=STSv1; id=1", "bogus"},
36 "_mta-sts.bad.example.": {"v=STSv1; bogus"},
37 "_mta-sts.multiple.example.": {"v=STSv1; id=1", "v=STSv1; id=2"},
38 "_mta-sts.c.cnames.example.": {"v=STSv1; id=1"},
39 "_mta-sts.temperror.example.": {"v=STSv1; id=1"},
40 "_mta-sts.other.example.": {"bogus", "more"},
41 },
42 CNAME: map[string]string{
43 "_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
44 "_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
45 "_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
46 },
47 Fail: []string{
48 "txt _mta-sts.temperror.example.",
49 },
50 }
51
52 test := func(host string, expRecord *Record, expErr error) {
53 t.Helper()
54
55 record, _, err := LookupRecord(context.Background(), log.Logger, resolver, dns.Domain{ASCII: host})
56 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
57 t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
58 }
59 if err != nil {
60 return
61 }
62 if !reflect.DeepEqual(record, expRecord) {
63 t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
64 }
65 }
66
67 test("absent.example", nil, ErrNoRecord)
68 test("other.example", nil, ErrNoRecord)
69 test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
70 test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
71 test("bad.example", nil, ErrRecordSyntax)
72 test("multiple.example", nil, ErrMultipleRecords)
73 test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
74 test("temperror.example", nil, ErrDNS)
75 test("followtemperror.example", nil, ErrDNS)
76}
77
78func TestMatches(t *testing.T) {
79 p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
80 if err != nil {
81 t.Fatalf("parsing policy: %s", err)
82 }
83
84 mustParseDomain := func(s string) dns.Domain {
85 t.Helper()
86 d, err := dns.ParseDomain(s)
87 if err != nil {
88 t.Fatalf("parsing domain %q: %s", s, err)
89 }
90 return d
91 }
92
93 match := func(s string) {
94 t.Helper()
95 if !p.Matches(mustParseDomain(s)) {
96 t.Fatalf("unexpected mismatch for %q", s)
97 }
98 }
99
100 not := func(s string) {
101 t.Helper()
102 if p.Matches(mustParseDomain(s)) {
103 t.Fatalf("unexpected match for %q", s)
104 }
105 }
106
107 match("a.example")
108 match("sub.b.example")
109 not("b.example")
110 not("sub.sub.b.example")
111 not("other")
112}
113
// pipeListener is an in-memory net.Listener for tests. Dial creates a
// connection pair with net.Pipe and hands the server side to Accept over the
// unbuffered channel C.
type pipeListener struct {
	sync.Mutex
	closed bool          // Set by Close; guarded by the embedded mutex.
	C      chan net.Conn // Carries server-side conns from Dial to Accept.
	done   chan struct{} // Closed by Close to unblock pending Dial and Accept calls.
}

var _ net.Listener = &pipeListener{}

// newPipeListener returns an open pipeListener.
func newPipeListener() *pipeListener {
	return &pipeListener{C: make(chan net.Conn), done: make(chan struct{})}
}

// Dial returns the client side of a new in-memory connection. It blocks until
// Accept receives the server side. It returns an error if the listener is
// already closed or is closed while waiting.
func (l *pipeListener) Dial() (net.Conn, error) {
	l.Lock()
	closed := l.closed
	l.Unlock()
	if closed {
		return nil, errors.New("closed")
	}
	c, s := net.Pipe()
	// Wait without holding the mutex, so Close can always proceed even if no
	// Accept ever takes this connection (holding it here would deadlock Close).
	select {
	case l.C <- s:
		return c, nil
	case <-l.done:
		c.Close()
		s.Close()
		return nil, errors.New("closed")
	}
}

// Accept waits for the next connection from Dial. After Close it returns io.EOF.
func (l *pipeListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.C:
		return conn, nil
	case <-l.done:
		return nil, io.EOF
	}
}

// Close marks the listener closed and unblocks pending Dial and Accept calls.
// Calling it more than once is harmless.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.done)
	}
	return nil
}

// Addr returns a placeholder address; pipe connections have no real address.
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }

// pipeAddr is the net.Addr reported by pipeListener.
type pipeAddr struct{}

func (a pipeAddr) Network() string { return "pipe" }
func (a pipeAddr) String() string  { return "pipe" }
155
// fakeCert returns a self-signed TLS certificate for "mta-sts.mox.example".
// Its key is derived from an all-zero ed25519 seed, so it is for tests only.
// If expired is true, the certificate's NotAfter lies an hour in the past;
// otherwise it is valid for the next hour. Leaf is populated so callers can add
// it directly to a CertPool. Failures abort the calling test.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	t.Helper()

	now := time.Now()
	notAfter := now.Add(time.Hour)
	if expired {
		notAfter = now.Add(-time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{"mta-sts.mox.example"},
		NotBefore:    now.Add(-time.Hour),
		NotAfter:     notAfter,
	}
	// Self-signed: the template is both subject and issuer.
	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	cert, err := x509.ParseCertificate(localCertBuf)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{localCertBuf},
		PrivateKey:  privKey,
		Leaf:        cert,
	}
}
187
188func TestFetch(t *testing.T) {
189 log := mlog.New("mtasts", nil)
190
191 certok := fakeCert(t, false)
192 certbad := fakeCert(t, true)
193
194 defer func() {
195 HTTPClient.Transport = nil
196 }()
197
198 resolver := dns.MockResolver{
199 TXT: map[string][]string{
200 "_mta-sts.mox.example.": {"v=STSv1; id=1"},
201 "_mta-sts.other.example.": {"v=STSv1; id=1"},
202 },
203 }
204
205 test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
206 t.Helper()
207
208 pool := x509.NewCertPool()
209 pool.AddCert(cert.Leaf)
210
211 l := newPipeListener()
212 defer l.Close()
213 go func() {
214 mux := &http.ServeMux{}
215 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
216 w.Header().Add("Location", "/other") // Ignored except for redirect.
217 w.WriteHeader(status)
218 w.Write([]byte(policyText))
219 })
220 s := &http.Server{
221 Handler: mux,
222 TLSConfig: &tls.Config{
223 Certificates: []tls.Certificate{cert},
224 },
225 ErrorLog: golog.New(io.Discard, "", 0),
226 }
227 s.ServeTLS(l, "", "")
228 }()
229
230 HTTPClient.Transport = &http.Transport{
231 Dial: func(network, addr string) (net.Conn, error) {
232 if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
233 return nil, &adns.DNSError{IsNotFound: true}
234 }
235 return l.Dial()
236 },
237 TLSClientConfig: &tls.Config{
238 RootCAs: pool,
239 },
240 }
241
242 p, _, err := FetchPolicy(context.Background(), log.Logger, dns.Domain{ASCII: domain})
243 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
244 t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
245 }
246 if err == nil && !reflect.DeepEqual(p, expPolicy) {
247 t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
248 }
249
250 if domain == "doesnotexist.example" {
251 expErr = ErrNoRecord
252 }
253
254 _, p, _, err = Get(context.Background(), log.Logger, resolver, dns.Domain{ASCII: domain})
255 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
256 t.Fatalf("get: got err %#v, expected %#v", err, expErr)
257 }
258 if err == nil && !reflect.DeepEqual(p, expPolicy) {
259 t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
260 }
261 }
262
263 test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
264 test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
265 test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
266 test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
267 test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
268 test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
269 test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
270 large := make([]byte, 64*1024+2)
271 test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
272 validPolicy := "version:STSv1\nmode:none\nmax_age:1"
273 test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)
274}
275