6	cryptorand "crypto/rand"
 
21	"github.com/mjl-/adns"
 
23	"github.com/mjl-/mox/dns"
 
24	"github.com/mjl-/mox/mlog"
 
// TestLookup exercises LookupRecord against a dns.MockResolver that holds TXT
// and CNAME entries for several "_mta-sts." names, and checks the returned
// record and error for each host.
27func TestLookup(t *testing.T) {
 
28	log := mlog.New("mtasts", nil)
 
30	resolver := dns.MockResolver{
 
31		TXT: map[string][]string{
 
32			"_mta-sts.a.example.":         {"v=STSv1; id=1"},
 
33			"_mta-sts.one.example.":       {"v=STSv1; id=1", "bogus"},
 
34			"_mta-sts.bad.example.":       {"v=STSv1; bogus"},
 
35			"_mta-sts.multiple.example.":  {"v=STSv1; id=1", "v=STSv1; id=2"},
 
36			"_mta-sts.c.cnames.example.":  {"v=STSv1; id=1"},
 
37			"_mta-sts.temperror.example.": {"v=STSv1; id=1"},
 
38			"_mta-sts.other.example.":     {"bogus", "more"},
 
40		CNAME: map[string]string{
 
41			"_mta-sts.a.cnames.example.":        "_mta-sts.b.cnames.example.",
 
42			"_mta-sts.b.cnames.example.":        "_mta-sts.c.cnames.example.",
 
43			"_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
 
			// NOTE(review): the field this entry belongs to is not visible in this
			// excerpt; presumably it marks the TXT lookup for temperror.example as
			// failing, which the ErrDNS cases below rely on - confirm.
46			"txt _mta-sts.temperror.example.",
 
	// test looks up the record for host and fails the test when the error or the
	// record differs from what is expected.
50	test := func(host string, expRecord *Record, expErr error) {
 
53		record, _, err := LookupRecord(context.Background(), log.Logger, resolver, dns.Domain{ASCII: host})
 
		// The error must be nil exactly when expErr is nil; a non-nil error must
		// match expErr according to errors.Is.
54		if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
 
55			t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
 
60		if !reflect.DeepEqual(record, expRecord) {
 
61			t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
 
	// No TXT entry at all, and TXT entries without any STSv1 record, both count
	// as "no record".
65	test("absent.example", nil, ErrNoRecord)
 
66	test("other.example", nil, ErrNoRecord)
 
67	test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
 
	// An unrelated TXT string next to a single valid record is tolerated.
68	test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
 
69	test("bad.example", nil, ErrRecordSyntax)
 
70	test("multiple.example", nil, ErrMultipleRecords)
 
	// a.cnames.example reaches the TXT record at c.cnames.example through the
	// two CNAME entries configured above.
71	test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
 
72	test("temperror.example", nil, ErrDNS)
 
73	test("followtemperror.example", nil, ErrDNS)
 
// TestMatches parses a policy whose mx entries are "a.example" and
// "*.b.example" and checks the result's Matches method against host names.
76func TestMatches(t *testing.T) {
 
77	p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
 
79		t.Fatalf("parsing policy: %s", err)
 
	// mustParseDomain converts s with dns.ParseDomain and aborts the test when
	// parsing fails.
82	mustParseDomain := func(s string) dns.Domain {
 
84		d, err := dns.ParseDomain(s)
 
86			t.Fatalf("parsing domain %q: %s", s, err)
 
	// match fails the test unless the policy matches s.
91	match := func(s string) {
 
93		if !p.Matches(mustParseDomain(s)) {
 
94			t.Fatalf("unexpected mismatch for %q", s)
 
	// not fails the test if the policy matches s.
98	not := func(s string) {
 
100		if p.Matches(mustParseDomain(s)) {
 
101			t.Fatalf("unexpected match for %q", s)
 
	// The wildcard entry is expected to cover a single extra label but not two.
106	match("sub.b.example")
 
108	not("sub.sub.b.example")
 
// pipeListener is a net.Listener used by the tests in this file. The
// constructor below initialises a field C holding a channel of net.Conn.
// NOTE(review): the struct fields and the bodies of Dial, Accept and Close are
// not visible in this excerpt; presumably Dial hands connections to Accept
// through C - confirm.
112type pipeListener struct {
 
// Compile-time check that *pipeListener satisfies net.Listener.
118var _ net.Listener = &pipeListener{}
 
// newPipeListener returns a pipeListener whose channel C is unbuffered.
120func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
 
// Dial returns a connection for the dialing side, or an error; one visible
// path reports the error "closed".
121func (l *pipeListener) Dial() (net.Conn, error) {
 
125		return nil, errors.New("closed")
 
// Accept is the net.Listener method that returns the next connection.
131func (l *pipeListener) Accept() (net.Conn, error) {
 
// Close is the net.Listener method that shuts the listener down.
138func (l *pipeListener) Close() error {
 
// Addr returns a pipeAddr value as the listener's address.
147func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
 
// pipeAddr is the stateless address value returned by pipeListener.Addr.
type pipeAddr struct{}

// Network reports the fixed network name "pipe".
func (pipeAddr) Network() string {
	return "pipe"
}

// String reports the fixed textual address "pipe".
func (pipeAddr) String() string {
	return "pipe"
}
 
// fakeCert builds a self-signed ed25519 certificate for the DNS name
// "mta-sts.mox.example" and returns it as a tls.Certificate. Failures while
// creating or parsing the certificate abort the test through t.Fatalf.
154func fakeCert(t *testing.T, expired bool) tls.Certificate {
 
155	notAfter := time.Now()
 
	// notAfter is moved one hour into the past or one hour into the future.
	// NOTE(review): the selecting condition is not visible in this excerpt;
	// presumably the expired flag picks the past value - confirm.
157		notAfter = notAfter.Add(-time.Hour)
 
159		notAfter = notAfter.Add(time.Hour)
 
	// The key is derived from an all-zero seed, so it is the same on every run.
162	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
 
164	template := &x509.Certificate{
 
165		SerialNumber: big.NewInt(1), // Required field...
 
166		DNSNames:     []string{"mta-sts.mox.example"},
 
167		NotBefore:    time.Now().Add(-time.Hour),
 
	// template is passed as both certificate and parent, so the result is
	// self-signed with privKey.
170	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
 
172		t.Fatalf("making certificate: %s", err)
 
	// Parse the DER bytes back into an x509.Certificate.
174	cert, err := x509.ParseCertificate(localCertBuf)
 
176		t.Fatalf("parsing generated certificate: %s", err)
 
178	c := tls.Certificate{
 
179		Certificate: [][]byte{localCertBuf},
 
// TestFetch runs FetchPolicy and Get against an in-process HTTPS server that
// serves "/.well-known/mta-sts.txt" with a configurable certificate, status
// code and body, and checks the resulting policy and error.
186func TestFetch(t *testing.T) {
 
187	log := mlog.New("mtasts", nil)
 
	// certok is built with expired=false, certbad with expired=true.
189	certok := fakeCert(t, false)
 
190	certbad := fakeCert(t, true)
 
	// NOTE(review): the enclosing statement is not visible; presumably a
	// deferred function restores the package's HTTPClient transport - confirm.
193		HTTPClient.Transport = nil
 
196	resolver := dns.MockResolver{
 
197		TXT: map[string][]string{
 
198			"_mta-sts.mox.example.":   {"v=STSv1; id=1"},
 
199			"_mta-sts.other.example.": {"v=STSv1; id=1"},
 
	// test serves policyText with the given status using cert, then fetches the
	// policy for domain and compares the outcome with expPolicy and expErr.
203	test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
 
		// pool holds only the leaf of the certificate under test.
206		pool := x509.NewCertPool()
 
207		pool.AddCert(cert.Leaf)
 
209		l := newPipeListener()
 
212			mux := &http.ServeMux{}
 
213			mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
 
214				w.Header().Add("Location", "/other") // Ignored except for redirect.
 
215				w.WriteHeader(status)
 
216				w.Write([]byte(policyText))
 
220				TLSConfig: &tls.Config{
 
221					Certificates: []tls.Certificate{cert},
 
				// Server-side error output is discarded.
223				ErrorLog: golog.New(io.Discard, "", 0),
 
			// The certificate comes from TLSConfig, hence the empty file names.
225			s.ServeTLS(l, "", "")
 
		// Replace the transport of the package's HTTPClient with one using a
		// custom dial function.
228		HTTPClient.Transport = &http.Transport{
 
229			Dial: func(network, addr string) (net.Conn, error) {
 
				// Dialing the policy host of doesnotexist.example yields a
				// not-found DNS error.
230				if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
 
231					return nil, &adns.DNSError{IsNotFound: true}
 
235			TLSClientConfig: &tls.Config{
 
240		p, _, err := FetchPolicy(context.Background(), log.Logger, dns.Domain{ASCII: domain})
 
		// The error must be nil exactly when expErr is nil; a non-nil error must
		// match expErr according to errors.Is.
241		if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
 
242			t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
 
244		if err == nil && !reflect.DeepEqual(p, expPolicy) {
 
245			t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
 
		// NOTE(review): the body of this branch is not visible; presumably it
		// skips the Get check below for this domain, which has no TXT record in
		// resolver - confirm.
248		if domain == "doesnotexist.example" {
 
		// Repeat the same expectations through Get, which also takes the resolver.
252		_, p, _, err = Get(context.Background(), log.Logger, resolver, dns.Domain{ASCII: domain})
 
253		if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
 
254			t.Fatalf("get: got err %#v, expected %#v", err, expErr)
 
256		if err == nil && !reflect.DeepEqual(p, expPolicy) {
 
257			t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
 
261	test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
 
	// NOTE(review): presumably fails because the certificate only names
	// mta-sts.mox.example - confirm.
262	test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
 
263	test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
 
264	test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
 
265	test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
 
	// A redirect response is reported as a fetch error.
266	test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
 
267	test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
 
	// A body of 64 KiB plus two zero bytes. NOTE(review): presumably larger than
	// the size the fetcher accepts - confirm.
268	large := make([]byte, 64*1024+2)
 
269	test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
 
270	validPolicy := "version:STSv1\nmode:none\nmax_age:1"
 
271	test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)