6 cryptorand "crypto/rand"
21 "github.com/mjl-/adns"
23 "github.com/mjl-/mox/dns"
24 "github.com/mjl-/mox/mlog"
// TestLookup exercises LookupRecord against a dns.MockResolver. For each host
// it checks the returned *Record and the error sentinel.
// NOTE(review): this excerpt omits some original lines (closing braces, blank
// lines, a few statements), so it is not compilable as shown.
27func TestLookup(t *testing.T) {
28 log := mlog.New("mtasts", nil)
// Mock DNS data: TXT records for the "_mta-sts." names under test, plus CNAMEs.
30 resolver := dns.MockResolver{
31 TXT: map[string][]string{
32 "_mta-sts.a.example.": {"v=STSv1; id=1"},
// One STS record next to a non-STS TXT record; the test below expects it to parse.
33 "_mta-sts.one.example.": {"v=STSv1; id=1", "bogus"},
34 "_mta-sts.bad.example.": {"v=STSv1; bogus"},
35 "_mta-sts.multiple.example.": {"v=STSv1; id=1", "v=STSv1; id=2"},
36 "_mta-sts.c.cnames.example.": {"v=STSv1; id=1"},
// Has a record, yet the tests below expect ErrDNS for it. See the
// "txt _mta-sts.temperror.example." entry further down.
37 "_mta-sts.temperror.example.": {"v=STSv1; id=1"},
// No TXT value starts with "v=STSv1"; the test below expects ErrNoRecord.
38 "_mta-sts.other.example.": {"bogus", "more"},
// CNAME chain a -> b -> c; only c has a TXT record.
40 CNAME: map[string]string{
41 "_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
42 "_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
// Points at the temperror name, so the failure must surface through the CNAME too.
43 "_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
// NOTE(review): the enclosing field for this entry is not visible in this
// excerpt. It presumably marks the TXT lookup of this name as failing, which
// the temperror tests expect as ErrDNS. Confirm.
46 "txt _mta-sts.temperror.example.",
// test looks up host and fails the test unless the error matches expErr via
// errors.Is (a nil expErr means success is expected) and the record is deeply
// equal to expRecord.
50 test := func(host string, expRecord *Record, expErr error) {
53 record, _, err := LookupRecord(context.Background(), log.Logger, resolver, dns.Domain{ASCII: host})
54 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
55 t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
// NOTE(review): lines hidden here presumably return early when err != nil. Confirm.
60 if !reflect.DeepEqual(record, expRecord) {
61 t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
65 test("absent.example", nil, ErrNoRecord)
66 test("other.example", nil, ErrNoRecord)
67 test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
68 test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
69 test("bad.example", nil, ErrRecordSyntax)
70 test("multiple.example", nil, ErrMultipleRecords)
71 test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
72 test("temperror.example", nil, ErrDNS)
73 test("followtemperror.example", nil, ErrDNS)
// TestMatches parses a policy whose MX patterns are "a.example" and
// "*.b.example", then checks Policy.Matches against individual hostnames.
// NOTE(review): some original lines (braces, helper calls, further match/not
// cases) are omitted from this excerpt.
76func TestMatches(t *testing.T) {
77 p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
79 t.Fatalf("parsing policy: %s", err)
// mustParseDomain parses s with dns.ParseDomain and fails the test on error.
82 mustParseDomain := func(s string) dns.Domain {
84 d, err := dns.ParseDomain(s)
86 t.Fatalf("parsing domain %q: %s", s, err)
// match fails the test unless p matches s.
91 match := func(s string) {
93 if !p.Matches(mustParseDomain(s)) {
94 t.Fatalf("unexpected mismatch for %q", s)
// not fails the test if p matches s.
98 not := func(s string) {
100 if p.Matches(mustParseDomain(s)) {
101 t.Fatalf("unexpected match for %q", s)
// Per these two cases, "*.b.example" matches a direct subdomain but not a
// deeper one.
106 match("sub.b.example")
108 not("sub.sub.b.example")
// pipeListener is an in-memory net.Listener. TestFetch uses it as the
// listener for its HTTPS test server.
// NOTE(review): the struct fields are not visible in this excerpt. From
// newPipeListener there is at least a field C of type chan net.Conn.
112type pipeListener struct {
// Compile-time check that *pipeListener implements net.Listener.
118var _ net.Listener = &pipeListener{}
// newPipeListener returns a pipeListener whose connection channel C is unbuffered.
120func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
// Dial is the client-side counterpart of Accept. In at least one case it
// returns the error "closed", presumably after Close was called.
// NOTE(review): the rest of the body is not visible in this excerpt. It
// presumably creates a connected pair and hands one end to Accept via C.
// Confirm.
121func (l *pipeListener) Dial() (net.Conn, error) {
125 return nil, errors.New("closed")
// Accept implements net.Listener.
// NOTE(review): body not visible in this excerpt. It presumably receives the
// next connection from C. Confirm.
131func (l *pipeListener) Accept() (net.Conn, error) {
// Close implements net.Listener.
// NOTE(review): body not visible in this excerpt. It presumably marks the
// listener closed so that later Dial calls fail with "closed". Confirm.
138func (l *pipeListener) Close() error {
// Addr implements net.Listener and always returns pipeAddr{}.
147func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
// pipeAddr is the fixed net.Addr that pipeListener.Addr reports.
149type pipeAddr struct{}
// Network and String both return the constant "pipe".
151func (a pipeAddr) Network() string { return "pipe" }
152func (a pipeAddr) String() string { return "pipe" }
// fakeCert returns a self-signed TLS certificate for "mta-sts.mox.example"
// whose NotBefore is one hour ago. Generation or parse failures abort the
// test via t.Fatalf.
// NOTE(review): the branch on expired and the NotAfter field assignment are
// hidden in this excerpt. Presumably expired=true gives a NotAfter one hour
// in the past and false gives one hour in the future. Confirm.
154func fakeCert(t *testing.T, expired bool) tls.Certificate {
155 notAfter := time.Now()
157 notAfter = notAfter.Add(-time.Hour)
159 notAfter = notAfter.Add(time.Hour)
// All-zero seed gives the same key on every run. Test-only.
162 privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
164 template := &x509.Certificate{
165 SerialNumber: big.NewInt(1), // Required field...
166 DNSNames: []string{"mta-sts.mox.example"},
167 NotBefore: time.Now().Add(-time.Hour),
// Self-signed: template serves as both the certificate and its issuer.
170 localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
172 t.Fatalf("making certificate: %s", err)
174 cert, err := x509.ParseCertificate(localCertBuf)
176 t.Fatalf("parsing generated certificate: %s", err)
// NOTE(review): the remaining fields are hidden. TestFetch reads cert.Leaf,
// so Leaf is presumably set to the parsed cert (and PrivateKey to privKey).
// Confirm.
178 c := tls.Certificate{
179 Certificate: [][]byte{localCertBuf},
// TestFetch serves policy responses from an in-process HTTPS server. It
// checks FetchPolicy, and Get for every domain except doesnotexist.example,
// against the expected policy or error sentinel.
// NOTE(review): some original lines (braces, server construction, goroutine
// and defer wrappers, TLS client fields) are omitted from this excerpt.
186func TestFetch(t *testing.T) {
187 log := mlog.New("mtasts", nil)
// certok is generated with expired=false, certbad with expired=true.
189 certok := fakeCert(t, false)
190 certbad := fakeCert(t, true)
// NOTE(review): surrounding lines are hidden. This presumably restores the
// package-level HTTPClient transport when the test ends (e.g. in a defer).
// Confirm.
193 HTTPClient.Transport = nil
// TXT records so that Get finds an MTA-STS record for both test domains.
196 resolver := dns.MockResolver{
197 TXT: map[string][]string{
198 "_mta-sts.mox.example.": {"v=STSv1; id=1"},
199 "_mta-sts.other.example.": {"v=STSv1; id=1"},
// test does the following:
//   - starts a server presenting cert that answers the policy path with
//     status and policyText,
//   - points HTTPClient at that server,
//   - requires FetchPolicy (and usually Get) to return expErr (matched with
//     errors.Is) or, on success, a policy deeply equal to expPolicy.
203 test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
// Pool holding only cert's leaf. Presumably used as the client's RootCAs
// (those fields are hidden). Confirm.
206 pool := x509.NewCertPool()
207 pool.AddCert(cert.Leaf)
209 l := newPipeListener()
212 mux := &http.ServeMux{}
213 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
214 w.Header().Add("Location", "/other") // Ignored except for redirect.
215 w.WriteHeader(status)
216 w.Write([]byte(policyText))
220 TLSConfig: &tls.Config{
221 Certificates: []tls.Certificate{cert},
// Discard server-side logging, e.g. expected TLS handshake failures.
223 ErrorLog: golog.New(io.Discard, "", 0),
// NOTE(review): presumably run in a goroutine. Check that the server and
// listener are shut down after each test call so goroutines don't leak.
225 s.ServeTLS(l, "", "")
// Custom dialer for the client. Addresses starting with
// "mta-sts.doesnotexist.example" get a not-found DNS error.
228 HTTPClient.Transport = &http.Transport{
229 Dial: func(network, addr string) (net.Conn, error) {
230 if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
231 return nil, &adns.DNSError{IsNotFound: true}
235 TLSClientConfig: &tls.Config{
240 p, _, err := FetchPolicy(context.Background(), log.Logger, dns.Domain{ASCII: domain})
241 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
242 t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
244 if err == nil && !reflect.DeepEqual(p, expPolicy) {
245 t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
// NOTE(review): this branch's body is hidden. It presumably skips the Get
// check, since the mock resolver has no TXT record for this domain. Confirm.
248 if domain == "doesnotexist.example" {
252 _, p, _, err = Get(context.Background(), log.Logger, resolver, dns.Domain{ASCII: domain})
253 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
254 t.Fatalf("get: got err %#v, expected %#v", err, expErr)
256 if err == nil && !reflect.DeepEqual(p, expPolicy) {
257 t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
261 test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
// The certificate names only mta-sts.mox.example. Presumably TLS
// verification fails for other.example.
262 test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
// Certificate generated with expired=true.
263 test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
264 test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
// The dialer reports DNS not-found for this host.
265 test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
// A 301 with a Location header is expected to be a fetch error, not a
// followed redirect.
266 test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
267 test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
// A body of 64 KiB + 2 zero bytes is expected to yield ErrPolicySyntax.
268 large := make([]byte, 64*1024+2)
269 test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
270 validPolicy := "version:STSv1\nmode:none\nmax_age:1"
271 test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)