1package mtasts
2
3import (
4 "context"
5 "crypto/ed25519"
6 cryptorand "crypto/rand"
7 "crypto/tls"
8 "crypto/x509"
9 "errors"
10 "io"
11 golog "log"
12 "math/big"
13 "net"
14 "net/http"
15 "reflect"
16 "strings"
17 "sync"
18 "testing"
19 "time"
20
21 "github.com/mjl-/adns"
22
23 "github.com/mjl-/mox/dns"
24 "github.com/mjl-/mox/mlog"
25)
26
27func TestLookup(t *testing.T) {
28 log := mlog.New("mtasts", nil)
29
30 resolver := dns.MockResolver{
31 TXT: map[string][]string{
32 "_mta-sts.a.example.": {"v=STSv1; id=1"},
33 "_mta-sts.one.example.": {"v=STSv1; id=1", "bogus"},
34 "_mta-sts.bad.example.": {"v=STSv1; bogus"},
35 "_mta-sts.multiple.example.": {"v=STSv1; id=1", "v=STSv1; id=2"},
36 "_mta-sts.c.cnames.example.": {"v=STSv1; id=1"},
37 "_mta-sts.temperror.example.": {"v=STSv1; id=1"},
38 "_mta-sts.other.example.": {"bogus", "more"},
39 },
40 CNAME: map[string]string{
41 "_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
42 "_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
43 "_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
44 },
45 Fail: []string{
46 "txt _mta-sts.temperror.example.",
47 },
48 }
49
50 test := func(host string, expRecord *Record, expErr error) {
51 t.Helper()
52
53 record, _, err := LookupRecord(context.Background(), log.Logger, resolver, dns.Domain{ASCII: host})
54 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
55 t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
56 }
57 if err != nil {
58 return
59 }
60 if !reflect.DeepEqual(record, expRecord) {
61 t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
62 }
63 }
64
65 test("absent.example", nil, ErrNoRecord)
66 test("other.example", nil, ErrNoRecord)
67 test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
68 test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
69 test("bad.example", nil, ErrRecordSyntax)
70 test("multiple.example", nil, ErrMultipleRecords)
71 test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
72 test("temperror.example", nil, ErrDNS)
73 test("followtemperror.example", nil, ErrDNS)
74}
75
76func TestMatches(t *testing.T) {
77 p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
78 if err != nil {
79 t.Fatalf("parsing policy: %s", err)
80 }
81
82 mustParseDomain := func(s string) dns.Domain {
83 t.Helper()
84 d, err := dns.ParseDomain(s)
85 if err != nil {
86 t.Fatalf("parsing domain %q: %s", s, err)
87 }
88 return d
89 }
90
91 match := func(s string) {
92 t.Helper()
93 if !p.Matches(mustParseDomain(s)) {
94 t.Fatalf("unexpected mismatch for %q", s)
95 }
96 }
97
98 not := func(s string) {
99 t.Helper()
100 if p.Matches(mustParseDomain(s)) {
101 t.Fatalf("unexpected match for %q", s)
102 }
103 }
104
105 match("a.example")
106 match("sub.b.example")
107 not("b.example")
108 not("sub.sub.b.example")
109 not("other")
110}
111
// pipeListener is an in-memory net.Listener backed by net.Pipe. Dial creates
// a pipe, hands the server end to a waiting Accept, and returns the client end.
type pipeListener struct {
	sync.Mutex               // Guards closed.
	closed     bool          // Set by the first Close; later Closes are no-ops.
	C          chan net.Conn // Unbuffered: Dial blocks until Accept or Close.
	done       chan struct{} // Closed by Close to release pending Dial/Accept calls.
}

var _ net.Listener = &pipeListener{}

// newPipeListener returns an open pipeListener, ready for Dial and Accept.
func newPipeListener() *pipeListener {
	return &pipeListener{C: make(chan net.Conn), done: make(chan struct{})}
}

// Dial creates a new pipe, passes its server end to Accept and returns the
// client end. It returns an error if the listener is closed, including when
// Close is called while Dial is still waiting for an Accept.
func (l *pipeListener) Dial() (net.Conn, error) {
	// Fail fast once closed. Otherwise the select below could still pick l.C
	// if an Accept happens to be waiting.
	select {
	case <-l.done:
		return nil, errors.New("closed")
	default:
	}

	c, s := net.Pipe()
	// No lock is held while waiting. Holding it made Close block forever
	// when no Accept was pending.
	select {
	case l.C <- s:
		return c, nil
	case <-l.done:
		c.Close()
		s.Close()
		return nil, errors.New("closed")
	}
}

// Accept waits for the next Dial and returns the server end of its pipe.
// After Close it returns net.ErrClosed.
func (l *pipeListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.C:
		return conn, nil
	case <-l.done:
		return nil, net.ErrClosed
	}
}

// Close stops the listener and releases any pending Dial and Accept calls.
// It may be called more than once.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.done)
	}
	return nil
}

// Addr returns a placeholder address; pipes have no network address.
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }

// pipeAddr is the net.Addr reported by pipeListener.
type pipeAddr struct{}

func (a pipeAddr) Network() string { return "pipe" }
func (a pipeAddr) String() string  { return "pipe" }
153
// fakeCert returns a self-signed Ed25519 certificate for "mta-sts.mox.example"
// with Leaf populated, so callers can add it to a certificate pool. If expired
// is true, the validity period ended an hour ago. Otherwise it runs from an
// hour ago until an hour from now.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	t.Helper()

	// Use one reference time so NotBefore is always strictly before NotAfter.
	// Two separate time.Now calls left the expired certificate with
	// NotBefore >= NotAfter, i.e. an empty validity window.
	now := time.Now()
	notBefore := now.Add(-time.Hour)
	notAfter := now.Add(time.Hour)
	if expired {
		notBefore = now.Add(-2 * time.Hour)
		notAfter = now.Add(-time.Hour)
	}

	// Deterministic all-zero seed. Fake key, don't use this for real!
	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize))

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field.
		DNSNames:     []string{"mta-sts.mox.example"},
		NotBefore:    notBefore,
		NotAfter:     notAfter,
	}
	// Self-signed: the template is both the subject and the issuer.
	der, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  privKey,
		Leaf:        leaf,
	}
}
185
186func TestFetch(t *testing.T) {
187 log := mlog.New("mtasts", nil)
188
189 certok := fakeCert(t, false)
190 certbad := fakeCert(t, true)
191
192 defer func() {
193 HTTPClient.Transport = nil
194 }()
195
196 resolver := dns.MockResolver{
197 TXT: map[string][]string{
198 "_mta-sts.mox.example.": {"v=STSv1; id=1"},
199 "_mta-sts.other.example.": {"v=STSv1; id=1"},
200 },
201 }
202
203 test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
204 t.Helper()
205
206 pool := x509.NewCertPool()
207 pool.AddCert(cert.Leaf)
208
209 l := newPipeListener()
210 defer l.Close()
211 go func() {
212 mux := &http.ServeMux{}
213 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
214 w.Header().Add("Location", "/other") // Ignored except for redirect.
215 w.WriteHeader(status)
216 w.Write([]byte(policyText))
217 })
218 s := &http.Server{
219 Handler: mux,
220 TLSConfig: &tls.Config{
221 Certificates: []tls.Certificate{cert},
222 },
223 ErrorLog: golog.New(io.Discard, "", 0),
224 }
225 s.ServeTLS(l, "", "")
226 }()
227
228 HTTPClient.Transport = &http.Transport{
229 Dial: func(network, addr string) (net.Conn, error) {
230 if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
231 return nil, &adns.DNSError{IsNotFound: true}
232 }
233 return l.Dial()
234 },
235 TLSClientConfig: &tls.Config{
236 RootCAs: pool,
237 },
238 }
239
240 p, _, err := FetchPolicy(context.Background(), log.Logger, dns.Domain{ASCII: domain})
241 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
242 t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
243 }
244 if err == nil && !reflect.DeepEqual(p, expPolicy) {
245 t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
246 }
247
248 if domain == "doesnotexist.example" {
249 expErr = ErrNoRecord
250 }
251
252 _, p, _, err = Get(context.Background(), log.Logger, resolver, dns.Domain{ASCII: domain})
253 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
254 t.Fatalf("get: got err %#v, expected %#v", err, expErr)
255 }
256 if err == nil && !reflect.DeepEqual(p, expPolicy) {
257 t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
258 }
259 }
260
261 test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
262 test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
263 test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
264 test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
265 test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
266 test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
267 test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
268 large := make([]byte, 64*1024+2)
269 test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
270 validPolicy := "version:STSv1\nmode:none\nmax_age:1"
271 test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)
272}
273