6 cryptorand "crypto/rand"
22 "github.com/mjl-/bstore"
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
26 "github.com/mjl-/mox/mox-"
27 "github.com/mjl-/mox/mtasts"
// ctxbg is the root background context shared by the tests in this file: it is
// passed to database inserts and queries, to refresh1, and is the parent of the
// shutdown context installed into mox.Shutdown.
var ctxbg context.Context = context.Background()
32func TestRefresh(t *testing.T) {
34 mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
35 mox.Conf.Static.DataDir = "."
37 dbpath := mox.DataDirPath("mtasts.db")
38 os.MkdirAll(filepath.Dir(dbpath), 0770)
40 defer os.Remove(dbpath)
42 log := mlog.New("mtastsdb", nil)
45 tcheckf(t, err, "init database")
48 tcheckf(t, err, "close database")
51 cert := fakeCert(t, false)
53 mtasts.HTTPClient.Transport = nil
56 insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
59 mxd, err := dns.ParseDomain(mx)
61 t.Fatalf("parsing mx domain %q: %s", mx, err)
63 policy := mtasts.Policy{
66 MX: []mtasts.MX{{Wildcard: false, Domain: mxd}},
67 MaxAgeSeconds: maxAge,
71 pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()}
72 if err := DB.Insert(ctxbg, &pr); err != nil {
73 t.Fatalf("insert policy: %s", err)
79 insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
81 insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
82 // To be refreshed, same id.
83 insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
84 // To be refreshed and succeed.
85 insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
86 // To be refreshed and fail to fetch.
87 insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
89 resolver := dns.MockResolver{
90 TXT: map[string][]string{
91 "_mta-sts.refresh.mox.example.": {"v=STSv1; id=1"},
92 "_mta-sts.policyok.mox.example.": {"v=STSv1; id=2"},
93 "_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
97 pool := x509.NewCertPool()
98 pool.AddCert(cert.Leaf)
100 l := newPipeListener()
103 mux := &http.ServeMux{}
104 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
105 if r.Host == "mta-sts.policybad.mox.example" {
109 fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
113 TLSConfig: &tls.Config{
114 Certificates: []tls.Certificate{cert},
116 ErrorLog: golog.New(io.Discard, "", 0),
118 s.ServeTLS(l, "", "")
121 mtasts.HTTPClient.Transport = &http.Transport{
122 Dial: func(network, addr string) (net.Conn, error) {
125 TLSClientConfig: &tls.Config{
131 sleep := func(d time.Duration) {
133 interval := 3 * time.Hour / 2
134 if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
135 t.Fatalf("bad sleep duration %v", d)
138 if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 {
139 t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
142 t.Fatalf("bad sleeps, %d instead of 2", slept)
144 time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
146 // Should not do any more refreshes and return immediately.
147 q := bstore.QueryDB[PolicyRecord](ctxbg, DB)
148 q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
149 if _, err := q.Delete(); err != nil {
150 t.Fatalf("delete record that would be refreshed: %v", err)
153 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
157 t.Fatalf("refresh found unexpected work, n %d", n)
159 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
162type pipeListener struct {
168var _ net.Listener = &pipeListener{}
170func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
171func (l *pipeListener) Dial() (net.Conn, error) {
175 return nil, errors.New("closed")
181func (l *pipeListener) Accept() (net.Conn, error) {
188func (l *pipeListener) Close() error {
197func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
// pipeAddr is the placeholder address returned by pipeListener.Addr. It carries
// no data; both its network name and its string form are the constant "pipe".
type pipeAddr struct{}

// Network returns the network name, always "pipe".
func (pipeAddr) Network() string {
	return "pipe"
}

// String returns the address in string form, always "pipe".
func (pipeAddr) String() string {
	return "pipe"
}
204func fakeCert(t *testing.T, expired bool) tls.Certificate {
205 notAfter := time.Now()
207 notAfter = notAfter.Add(-time.Hour)
209 notAfter = notAfter.Add(time.Hour)
212 privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
214 template := &x509.Certificate{
215 SerialNumber: big.NewInt(1), // Required field...
216 DNSNames: []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
217 NotBefore: time.Now().Add(-time.Hour),
220 localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
222 t.Fatalf("making certificate: %s", err)
224 cert, err := x509.ParseCertificate(localCertBuf)
226 t.Fatalf("parsing generated certificate: %s", err)
228 c := tls.Certificate{
229 Certificate: [][]byte{localCertBuf},