1package mtastsdb
2
3import (
4 "context"
5 "crypto/ed25519"
6 cryptorand "crypto/rand"
7 "crypto/tls"
8 "crypto/x509"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "math/big"
14 "net"
15 "net/http"
16 "os"
17 "path/filepath"
18 "sync"
19 "testing"
20 "time"
21
22 "github.com/mjl-/bstore"
23
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
26 "github.com/mjl-/mox/mox-"
27 "github.com/mjl-/mox/mtasts"
28)
29
// ctxbg is a plain background context, passed wherever the test code in this
// file needs a context.Context.
var ctxbg = context.Background()
31
32func TestRefresh(t *testing.T) {
33 mox.Shutdown = ctxbg
34 mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
35 mox.Conf.Static.DataDir = "."
36
37 dbpath := mox.DataDirPath("mtasts.db")
38 os.MkdirAll(filepath.Dir(dbpath), 0770)
39 os.Remove(dbpath)
40 defer os.Remove(dbpath)
41
42 if err := Init(false); err != nil {
43 t.Fatalf("init database: %s", err)
44 }
45 defer Close()
46
47 db, err := database(ctxbg)
48 if err != nil {
49 t.Fatalf("database: %s", err)
50 }
51
52 cert := fakeCert(t, false)
53 defer func() {
54 mtasts.HTTPClient.Transport = nil
55 }()
56
57 insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
58 t.Helper()
59
60 mxd, err := dns.ParseDomain(mx)
61 if err != nil {
62 t.Fatalf("parsing mx domain %q: %s", mx, err)
63 }
64 policy := mtasts.Policy{
65 Version: "STSv1",
66 Mode: mode,
67 MX: []mtasts.MX{{Wildcard: false, Domain: mxd}},
68 MaxAgeSeconds: maxAge,
69 Extensions: nil,
70 }
71
72 pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()}
73 if err := db.Insert(ctxbg, &pr); err != nil {
74 t.Fatalf("insert policy: %s", err)
75 }
76 }
77
78 now := time.Now()
79 // Updated just now.
80 insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
81 // To be removed.
82 insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
83 // To be refreshed, same id.
84 insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
85 // To be refreshed and succeed.
86 insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
87 // To be refreshed and fail to fetch.
88 insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
89
90 resolver := dns.MockResolver{
91 TXT: map[string][]string{
92 "_mta-sts.refresh.mox.example.": {"v=STSv1; id=1"},
93 "_mta-sts.policyok.mox.example.": {"v=STSv1; id=2"},
94 "_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
95 },
96 }
97
98 pool := x509.NewCertPool()
99 pool.AddCert(cert.Leaf)
100
101 l := newPipeListener()
102 defer l.Close()
103 go func() {
104 mux := &http.ServeMux{}
105 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
106 if r.Host == "mta-sts.policybad.mox.example" {
107 w.WriteHeader(500)
108 return
109 }
110 fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
111 })
112 s := &http.Server{
113 Handler: mux,
114 TLSConfig: &tls.Config{
115 Certificates: []tls.Certificate{cert},
116 },
117 ErrorLog: log.New(io.Discard, "", 0),
118 }
119 s.ServeTLS(l, "", "")
120 }()
121
122 mtasts.HTTPClient.Transport = &http.Transport{
123 Dial: func(network, addr string) (net.Conn, error) {
124 return l.Dial()
125 },
126 TLSClientConfig: &tls.Config{
127 RootCAs: pool,
128 },
129 }
130
131 slept := 0
132 sleep := func(d time.Duration) {
133 slept++
134 interval := 3 * time.Hour / 2
135 if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
136 t.Fatalf("bad sleep duration %v", d)
137 }
138 }
139 log := mlog.New("mtastsdb", nil)
140 if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 {
141 t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
142 }
143 if slept != 2 {
144 t.Fatalf("bad sleeps, %d instead of 2", slept)
145 }
146 time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
147
148 // Should not do any more refreshes and return immediately.
149 q := bstore.QueryDB[PolicyRecord](ctxbg, db)
150 q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
151 if _, err := q.Delete(); err != nil {
152 t.Fatalf("delete record that would be refreshed: %v", err)
153 }
154 mox.Context = ctxbg
155 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
156 mox.ShutdownCancel()
157 n := refresh()
158 if n != 0 {
159 t.Fatalf("refresh found unexpected work, n %d", n)
160 }
161 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
162}
163
// pipeListener is an in-memory net.Listener. Dial creates a net.Pipe and hands
// the server end to a pending Accept, so clients and servers can talk without
// touching the network.
type pipeListener struct {
	sync.Mutex               // Guards closed and done.
	closed     bool          // Set by Close; Dial fails from then on.
	C          chan net.Conn // Unbuffered; carries server ends from Dial to Accept.
	done       chan struct{} // Closed by Close to release blocked Dial and Accept calls; created lazily.
}

var _ net.Listener = &pipeListener{}

// newPipeListener returns a listener that is ready for Dial and Accept.
func newPipeListener() *pipeListener {
	return &pipeListener{C: make(chan net.Conn), done: make(chan struct{})}
}

// doneLocked returns the channel that Close closes, creating it first if the
// listener was not made with newPipeListener. The caller must hold the lock.
func (l *pipeListener) doneLocked() chan struct{} {
	if l.done == nil {
		l.done = make(chan struct{})
	}
	return l.done
}

// Dial returns the client end of a new pipe once Accept has taken the server
// end. It returns an error if the listener is closed, or gets closed while
// waiting for Accept.
func (l *pipeListener) Dial() (net.Conn, error) {
	// Only hold the lock for reading the state. Holding it across the blocking
	// send below would make Close hang until some Accept call shows up.
	l.Lock()
	closed, done := l.closed, l.doneLocked()
	l.Unlock()
	if closed {
		return nil, errors.New("closed")
	}

	c, s := net.Pipe()
	select {
	case l.C <- s:
		return c, nil
	case <-done:
		// Nobody will ever use this pipe.
		c.Close()
		s.Close()
		return nil, errors.New("closed")
	}
}

// Accept waits for the server end of a pipe created by Dial. It returns io.EOF
// once the listener is closed.
func (l *pipeListener) Accept() (net.Conn, error) {
	l.Lock()
	done := l.doneLocked()
	l.Unlock()

	select {
	case conn := <-l.C:
		if conn == nil {
			return nil, io.EOF
		}
		return conn, nil
	case <-done:
		return nil, io.EOF
	}
}

// Close marks the listener as closed and wakes up blocked Dial and Accept
// calls. Closing more than once is fine. It always returns nil.
//
// C itself is left open: closing it could race with a Dial that is sending on
// it. Shutdown is signalled through the done channel instead.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.doneLocked())
	}
	return nil
}

// Addr returns a placeholder address.
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }

// pipeAddr is the net.Addr reported by pipeListener.
type pipeAddr struct{}

// Network returns "pipe".
func (a pipeAddr) Network() string { return "pipe" }

// String returns "pipe".
func (a pipeAddr) String() string { return "pipe" }
205
// fakeCert returns a self-signed certificate, with parsed Leaf, for the DNS
// names mta-sts.policybad.mox.example and mta-sts.policyok.mox.example.
//
// If expired is false, the certificate is valid from an hour ago until an hour
// from now. If expired is true, its validity window lies entirely in the past:
// from two hours ago until one hour ago.
//
// The key is derived from an all-zero seed, so the certificate must only be
// used in tests. Any failure aborts the test through t.Fatalf.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	t.Helper()

	// Take a single timestamp, so NotBefore is always before NotAfter.
	now := time.Now()
	notBefore, notAfter := now.Add(-time.Hour), now.Add(time.Hour)
	if expired {
		notBefore, notAfter = now.Add(-2*time.Hour), now.Add(-time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
		NotBefore:    notBefore,
		NotAfter:     notAfter,
	}
	// The template is also passed as parent, making the certificate self-signed.
	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	cert, err := x509.ParseCertificate(localCertBuf)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{localCertBuf},
		PrivateKey:  privKey,
		Leaf:        cert,
	}
}
237