6 cryptorand "crypto/rand"
22 "github.com/mjl-/adns"
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
// TestLookup checks LookupRecord against a mock DNS resolver. It covers:
// - a missing record, and non-STS TXT records;
// - a record syntax error and multiple STS records;
// - a CNAME chain;
// - DNS failures, both direct and reached through a CNAME.
28func TestLookup(t *testing.T) {
	// Debug-level logging; the "" key presumably sets the default level for all packages.
29 mlog.SetConfig(map[string]slog.Level{"": mlog.LevelDebug})
30 log := mlog.New("mtasts", nil)
	// TXT records are keyed by the "_mta-sts.<host>." name that LookupRecord queries.
	// one.example mixes an STS record with a non-STS one; the test below expects the extra record to be ignored.
	// other.example has only non-STS records, so it is expected to yield ErrNoRecord.
32 resolver := dns.MockResolver{
33 TXT: map[string][]string{
34 "_mta-sts.a.example.": {"v=STSv1; id=1"},
35 "_mta-sts.one.example.": {"v=STSv1; id=1", "bogus"},
36 "_mta-sts.bad.example.": {"v=STSv1; bogus"},
37 "_mta-sts.multiple.example.": {"v=STSv1; id=1", "v=STSv1; id=2"},
38 "_mta-sts.c.cnames.example.": {"v=STSv1; id=1"},
39 "_mta-sts.temperror.example.": {"v=STSv1; id=1"},
40 "_mta-sts.other.example.": {"bogus", "more"},
	// CNAME chain a -> b -> c, where only c has a TXT record.
	// followtemperror points at temperror, whose lookup is expected to fail.
42 CNAME: map[string]string{
43 "_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
44 "_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
45 "_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
	// NOTE(review): the field holding this entry is not visible in this excerpt.
	// temperror.example does have a TXT record, yet ErrDNS is expected for it below.
	// So presumably this entry makes the mock fail the TXT lookup for that name.
48 "txt _mta-sts.temperror.example.",
	// test looks up host with LookupRecord and requires the error to match expErr via errors.Is.
	// It compares the returned record with expRecord using reflect.DeepEqual.
52 test := func(host string, expRecord *Record, expErr error) {
55 record, _, err := LookupRecord(context.Background(), log.Logger, resolver, dns.Domain{ASCII: host})
56 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
57 t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
62 if !reflect.DeepEqual(record, expRecord) {
63 t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
	// absent.example has no records in the mock resolver at all.
67 test("absent.example", nil, ErrNoRecord)
68 test("other.example", nil, ErrNoRecord)
69 test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
70 test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
71 test("bad.example", nil, ErrRecordSyntax)
72 test("multiple.example", nil, ErrMultipleRecords)
	// Resolved through the two-step CNAME chain to c.cnames.example.
73 test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
74 test("temperror.example", nil, ErrDNS)
	// The failure must also surface when it is reached through a CNAME.
75 test("followtemperror.example", nil, ErrDNS)
// TestMatches checks Policy.Matches against a policy with two mx patterns:
// the exact name a.example and the wildcard *.b.example.
78func TestMatches(t *testing.T) {
79 p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
81 t.Fatalf("parsing policy: %s", err)
	// mustParseDomain parses s with dns.ParseDomain, failing the test on error.
84 mustParseDomain := func(s string) dns.Domain {
86 d, err := dns.ParseDomain(s)
88 t.Fatalf("parsing domain %q: %s", s, err)
	// match fails the test if p does not match s.
93 match := func(s string) {
95 if !p.Matches(mustParseDomain(s)) {
96 t.Fatalf("unexpected mismatch for %q", s)
	// not fails the test if p matches s.
100 not := func(s string) {
102 if p.Matches(mustParseDomain(s)) {
103 t.Fatalf("unexpected match for %q", s)
	// The wildcard covers one extra leftmost label...
108 match("sub.b.example")
	// ...but not two.
110 not("sub.sub.b.example")
// pipeListener is a net.Listener used by the tests in place of a real network listener.
// Its C field is a channel of net.Conn (see newPipeListener).
// NOTE(review): the remaining fields are not visible in this excerpt.
114type pipeListener struct {
120var _ net.Listener = &pipeListener{}
122func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
// Dial returns a client-side connection to the listener.
// NOTE(review): most of the body is not visible in this excerpt.
// The visible path returns an error reading "closed", presumably once Close has been called.
123func (l *pipeListener) Dial() (net.Conn, error) {
127 return nil, errors.New("closed")
// Accept implements net.Listener.
// NOTE(review): the body is not visible in this excerpt; presumably it receives connections from l.C.
133func (l *pipeListener) Accept() (net.Conn, error) {
// Close implements net.Listener.
// NOTE(review): the body is not visible in this excerpt.
140func (l *pipeListener) Close() error {
149func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
// pipeAddr is the fixed net.Addr reported by pipeListener.Addr.
type pipeAddr struct{}

// Network reports the pseudo network name, always "pipe".
func (pipeAddr) Network() string {
	return "pipe"
}

// String returns the textual form of the address, also "pipe".
func (pipeAddr) String() string {
	return "pipe"
}
// fakeCert returns a self-signed ed25519 TLS certificate for mta-sts.mox.example.
// Creation or parse failures abort the test via t.Fatalf.
// The expired parameter selects a NotAfter an hour in the past rather than an hour ahead.
156func fakeCert(t *testing.T, expired bool) tls.Certificate {
157 notAfter := time.Now()
	// The branch condition is on a line not visible here (presumably `if expired`).
159 notAfter = notAfter.Add(-time.Hour)
161 notAfter = notAfter.Add(time.Hour)
	// All-zero seed: a fixed, publicly known key, suitable only for tests.
164 privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
166 template := &x509.Certificate{
167 SerialNumber: big.NewInt(1), // Required field...
168 DNSNames: []string{"mta-sts.mox.example"},
	// NOTE(review): in the expired case this uses a later time.Now() than notAfter did.
	// Assuming notAfter becomes the template's NotAfter (line not visible), NotBefore ends up
	// at or slightly after NotAfter, an inverted validity window. Verification still fails
	// as the test intends, but consider taking time.Now() once and backdating NotBefore further.
169 NotBefore: time.Now().Add(-time.Hour),
	// Self-signed: template is both the certificate and its parent, signed with privKey.
172 localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
174 t.Fatalf("making certificate: %s", err)
	// Parse the DER bytes back into an *x509.Certificate (presumably stored as the tls.Certificate Leaf).
176 cert, err := x509.ParseCertificate(localCertBuf)
178 t.Fatalf("parsing generated certificate: %s", err)
180 c := tls.Certificate{
181 Certificate: [][]byte{localCertBuf},
// TestFetch runs FetchPolicy, and Get through a mock resolver, against an HTTPS server served over a pipeListener.
// It covers:
// - policy syntax errors;
// - certificate name and expiry failures;
// - HTTP status handling;
// - a nonexistent mta-sts host;
// - an oversized policy;
// - a valid policy.
188func TestFetch(t *testing.T) {
189 log := mlog.New("mtasts", nil)
	// certbad is created already expired.
191 certok := fakeCert(t, false)
192 certbad := fakeCert(t, true)
	// Resets the package-level HTTPClient transport.
	// NOTE(review): the surrounding lines are not visible here.
	// Presumably this runs deferred, undoing the override installed by test below.
195 HTTPClient.Transport = nil
	// Only mox.example and other.example publish an MTA-STS TXT record.
198 resolver := dns.MockResolver{
199 TXT: map[string][]string{
200 "_mta-sts.mox.example.": {"v=STSv1; id=1"},
201 "_mta-sts.other.example.": {"v=STSv1; id=1"},
	// test serves policyText with the given status and certificate.
	// It then checks both FetchPolicy and Get against expPolicy/expErr.
205 test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
	// Trust only the leaf of the certificate under test.
	// Presumably the pool is used as RootCAs in the TLSClientConfig below (its fields are not visible).
208 pool := x509.NewCertPool()
209 pool.AddCert(cert.Leaf)
	// Each test case gets its own listener and HTTPS server.
211 l := newPipeListener()
214 mux := &http.ServeMux{}
215 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
216 w.Header().Add("Location", "/other") // Ignored except for redirect.
217 w.WriteHeader(status)
218 w.Write([]byte(policyText))
222 TLSConfig: &tls.Config{
223 Certificates: []tls.Certificate{cert},
224 ErrorLog: golog.New(io.Discard, "", 0),
227 s.ServeTLS(l, "", "")
	// Override how the package-level HTTPClient connects.
	// Presumably the remaining dial path connects through l (lines not visible).
230 HTTPClient.Transport = &http.Transport{
	// NOTE(review): http.Transport.Dial is deprecated in favor of DialContext.
231 Dial: func(network, addr string) (net.Conn, error) {
	// Simulate NXDOMAIN for the mta-sts host of doesnotexist.example.
232 if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
233 return nil, &adns.DNSError{IsNotFound: true}
237 TLSClientConfig: &tls.Config{
	// FetchPolicy is called without a resolver; the policy is fetched directly over HTTPS.
242 p, _, err := FetchPolicy(context.Background(), log.Logger, dns.Domain{ASCII: domain})
243 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
244 t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
246 if err == nil && !reflect.DeepEqual(p, expPolicy) {
247 t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
	// doesnotexist.example has no TXT record in the mock resolver.
	// The lines after this check are not visible; presumably they skip the Get check for it.
250 if domain == "doesnotexist.example" {
	// Get is the same fetch preceded by a TXT lookup through resolver.
254 _, p, _, err = Get(context.Background(), log.Logger, resolver, dns.Domain{ASCII: domain})
255 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
256 t.Fatalf("get: got err %#v, expected %#v", err, expErr)
258 if err == nil && !reflect.DeepEqual(p, expPolicy) {
259 t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
263 test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
	// other.example is not among the certificate's DNS names (only mta-sts.mox.example).
264 test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
	// Expired certificate.
265 test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
266 test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
267 test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
	// Redirects are not followed, so the 301 response is a fetch error.
268 test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
269 test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
	// A policy body of 64 KiB + 2 bytes, expected to be rejected as a syntax error.
	// Presumably it exceeds the size limit FetchPolicy enforces.
270 large := make([]byte, 64*1024+2)
271 test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
272 validPolicy := "version:STSv1\nmode:none\nmax_age:1"
273 test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)