6 cryptorand "crypto/rand"
22 "github.com/mjl-/bstore"
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
26 "github.com/mjl-/mox/mox-"
27 "github.com/mjl-/mox/mtasts"
// ctxbg is the background context passed to the database and refresh calls
// made by the tests in this file.
var ctxbg context.Context = context.Background()
32func TestRefresh(t *testing.T) {
34 mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
35 mox.Conf.Static.DataDir = "."
37 dbpath := mox.DataDirPath("mtasts.db")
38 os.MkdirAll(filepath.Dir(dbpath), 0770)
40 defer os.Remove(dbpath)
42 if err := Init(false); err != nil {
43 t.Fatalf("init database: %s", err)
47 db, err := database(ctxbg)
49 t.Fatalf("database: %s", err)
52 cert := fakeCert(t, false)
54 mtasts.HTTPClient.Transport = nil
57 insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
60 mxd, err := dns.ParseDomain(mx)
62 t.Fatalf("parsing mx domain %q: %s", mx, err)
64 policy := mtasts.Policy{
67 MX: []mtasts.MX{{Wildcard: false, Domain: mxd}},
68 MaxAgeSeconds: maxAge,
72 pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()}
73 if err := db.Insert(ctxbg, &pr); err != nil {
74 t.Fatalf("insert policy: %s", err)
80 insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
82 insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
83 // To be refreshed, same id.
84 insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
85 // To be refreshed and succeed.
86 insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
87 // To be refreshed and fail to fetch.
88 insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
90 resolver := dns.MockResolver{
91 TXT: map[string][]string{
92 "_mta-sts.refresh.mox.example.": {"v=STSv1; id=1"},
93 "_mta-sts.policyok.mox.example.": {"v=STSv1; id=2"},
94 "_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
98 pool := x509.NewCertPool()
99 pool.AddCert(cert.Leaf)
101 l := newPipeListener()
104 mux := &http.ServeMux{}
105 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
106 if r.Host == "mta-sts.policybad.mox.example" {
110 fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
114 TLSConfig: &tls.Config{
115 Certificates: []tls.Certificate{cert},
117 ErrorLog: log.New(io.Discard, "", 0),
119 s.ServeTLS(l, "", "")
122 mtasts.HTTPClient.Transport = &http.Transport{
123 Dial: func(network, addr string) (net.Conn, error) {
126 TLSClientConfig: &tls.Config{
132 sleep := func(d time.Duration) {
134 interval := 3 * time.Hour / 2
135 if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
136 t.Fatalf("bad sleep duration %v", d)
139 log := mlog.New("mtastsdb", nil)
140 if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 {
141 t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
144 t.Fatalf("bad sleeps, %d instead of 2", slept)
146 time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
148 // Should not do any more refreshes and return immediately.
149 q := bstore.QueryDB[PolicyRecord](ctxbg, db)
150 q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
151 if _, err := q.Delete(); err != nil {
152 t.Fatalf("delete record that would be refreshed: %v", err)
155 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
159 t.Fatalf("refresh found unexpected work, n %d", n)
161 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
164type pipeListener struct {
170var _ net.Listener = &pipeListener{}
172func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
173func (l *pipeListener) Dial() (net.Conn, error) {
177 return nil, errors.New("closed")
183func (l *pipeListener) Accept() (net.Conn, error) {
190func (l *pipeListener) Close() error {
199func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
// pipeAddr is the address value handed out by pipeListener.Addr. It carries
// no data; both its network name and its string form are the fixed "pipe".
type pipeAddr struct{}

// Network implements net.Addr and always reports "pipe".
func (a pipeAddr) Network() string {
	return "pipe"
}

// String implements net.Addr; it is identical to Network.
func (a pipeAddr) String() string {
	return a.Network()
}
206func fakeCert(t *testing.T, expired bool) tls.Certificate {
207 notAfter := time.Now()
209 notAfter = notAfter.Add(-time.Hour)
211 notAfter = notAfter.Add(time.Hour)
214 privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
216 template := &x509.Certificate{
217 SerialNumber: big.NewInt(1), // Required field...
218 DNSNames: []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
219 NotBefore: time.Now().Add(-time.Hour),
222 localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
224 t.Fatalf("making certificate: %s", err)
226 cert, err := x509.ParseCertificate(localCertBuf)
228 t.Fatalf("parsing generated certificate: %s", err)
230 c := tls.Certificate{
231 Certificate: [][]byte{localCertBuf},