1package mtastsdb
2
3import (
4 "context"
5 "crypto/ed25519"
6 cryptorand "crypto/rand"
7 "crypto/tls"
8 "crypto/x509"
9 "errors"
10 "fmt"
11 "io"
12 golog "log"
13 "math/big"
14 "net"
15 "net/http"
16 "os"
17 "path/filepath"
18 "sync"
19 "testing"
20 "time"
21
22 "github.com/mjl-/bstore"
23
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
26 "github.com/mjl-/mox/mox-"
27 "github.com/mjl-/mox/mtasts"
28)
29
var (
	// ctxbg is the background context shared by the tests in this file.
	ctxbg = context.Background()
)
31
32func TestRefresh(t *testing.T) {
33 mox.Shutdown = ctxbg
34 mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
35 mox.Conf.Static.DataDir = "."
36
37 dbpath := mox.DataDirPath("mtasts.db")
38 os.MkdirAll(filepath.Dir(dbpath), 0770)
39 os.Remove(dbpath)
40 defer os.Remove(dbpath)
41
42 log := mlog.New("mtastsdb", nil)
43
44 err := Init(false)
45 tcheckf(t, err, "init database")
46 defer func() {
47 err := Close()
48 tcheckf(t, err, "close database")
49 }()
50
51 cert := fakeCert(t, false)
52 defer func() {
53 mtasts.HTTPClient.Transport = nil
54 }()
55
56 insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
57 t.Helper()
58
59 mxd, err := dns.ParseDomain(mx)
60 if err != nil {
61 t.Fatalf("parsing mx domain %q: %s", mx, err)
62 }
63 policy := mtasts.Policy{
64 Version: "STSv1",
65 Mode: mode,
66 MX: []mtasts.MX{{Wildcard: false, Domain: mxd}},
67 MaxAgeSeconds: maxAge,
68 Extensions: nil,
69 }
70
71 pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()}
72 if err := DB.Insert(ctxbg, &pr); err != nil {
73 t.Fatalf("insert policy: %s", err)
74 }
75 }
76
77 now := time.Now()
78 // Updated just now.
79 insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
80 // To be removed.
81 insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
82 // To be refreshed, same id.
83 insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
84 // To be refreshed and succeed.
85 insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
86 // To be refreshed and fail to fetch.
87 insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
88
89 resolver := dns.MockResolver{
90 TXT: map[string][]string{
91 "_mta-sts.refresh.mox.example.": {"v=STSv1; id=1"},
92 "_mta-sts.policyok.mox.example.": {"v=STSv1; id=2"},
93 "_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
94 },
95 }
96
97 pool := x509.NewCertPool()
98 pool.AddCert(cert.Leaf)
99
100 l := newPipeListener()
101 defer l.Close()
102 go func() {
103 mux := &http.ServeMux{}
104 mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
105 if r.Host == "mta-sts.policybad.mox.example" {
106 w.WriteHeader(500)
107 return
108 }
109 fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
110 })
111 s := &http.Server{
112 Handler: mux,
113 TLSConfig: &tls.Config{
114 Certificates: []tls.Certificate{cert},
115 },
116 ErrorLog: golog.New(io.Discard, "", 0),
117 }
118 s.ServeTLS(l, "", "")
119 }()
120
121 mtasts.HTTPClient.Transport = &http.Transport{
122 Dial: func(network, addr string) (net.Conn, error) {
123 return l.Dial()
124 },
125 TLSClientConfig: &tls.Config{
126 RootCAs: pool,
127 },
128 }
129
130 slept := 0
131 sleep := func(d time.Duration) {
132 slept++
133 interval := 3 * time.Hour / 2
134 if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
135 t.Fatalf("bad sleep duration %v", d)
136 }
137 }
138 if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 {
139 t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
140 }
141 if slept != 2 {
142 t.Fatalf("bad sleeps, %d instead of 2", slept)
143 }
144 time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
145
146 // Should not do any more refreshes and return immediately.
147 q := bstore.QueryDB[PolicyRecord](ctxbg, DB)
148 q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
149 if _, err := q.Delete(); err != nil {
150 t.Fatalf("delete record that would be refreshed: %v", err)
151 }
152 mox.Context = ctxbg
153 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
154 mox.ShutdownCancel()
155 n := refresh()
156 if n != 0 {
157 t.Fatalf("refresh found unexpected work, n %d", n)
158 }
159 mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
160}
161
// pipeListener is an in-memory net.Listener. Dial creates a net.Pipe and hands
// the server end to a pending Accept, so tests can run an HTTP(S) server without
// opening network sockets. Create one with newPipeListener.
type pipeListener struct {
	sync.Mutex
	closed bool          // Set once Close has been called. Guarded by the mutex.
	C      chan net.Conn // Server ends of dialed pipes, handed from Dial to Accept.
	done   chan struct{} // Closed by Close to wake pending Dial and Accept calls.
}

var _ net.Listener = &pipeListener{}

// newPipeListener returns an open pipeListener that is ready for use.
func newPipeListener() *pipeListener {
	return &pipeListener{C: make(chan net.Conn), done: make(chan struct{})}
}

// Dial connects to the listener. It blocks until an Accept call takes the
// connection. It returns an error if the listener is, or becomes, closed first.
//
// No lock is held while waiting for Accept. Holding one would make Close block
// for as long as a Dial is pending. C is never closed, so the send cannot panic.
func (l *pipeListener) Dial() (net.Conn, error) {
	select {
	case <-l.done:
		return nil, errors.New("closed")
	default:
	}
	client, server := net.Pipe()
	select {
	case l.C <- server:
		return client, nil
	case <-l.done:
		client.Close()
		server.Close()
		return nil, errors.New("closed")
	}
}

// Accept waits for the next Dial and returns the server end of its pipe. It
// returns io.EOF once the listener has been closed.
func (l *pipeListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.C:
		if conn == nil {
			return nil, io.EOF
		}
		return conn, nil
	case <-l.done:
		return nil, io.EOF
	}
}

// Close closes the listener. Pending and future Accept calls then return
// io.EOF, and Dial calls fail. Calling Close more than once is harmless.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.done)
	}
	return nil
}

// Addr returns a placeholder address; pipe connections have no real address.
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }

// pipeAddr is the net.Addr reported by pipeListener.
type pipeAddr struct{}

func (a pipeAddr) Network() string { return "pipe" }
func (a pipeAddr) String() string { return "pipe" }
203
// fakeCert returns a self-signed TLS certificate for the hosts
// mta-sts.policybad.mox.example and mta-sts.policyok.mox.example.
//
// NotBefore is one hour in the past. NotAfter is one hour in the future, or
// one hour in the past when expired is set. The key is derived from an
// all-zero ed25519 seed and is only suitable for tests.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	validFor := time.Hour
	if expired {
		validFor = -time.Hour
	}
	notAfter := time.Now().Add(validFor)

	// Fixed all-zero seed: fake key, don't use this for real!
	seed := make([]byte, ed25519.SeedSize)
	key := ed25519.NewKeyFromSeed(seed)

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required by x509.CreateCertificate.
		DNSNames:     []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     notAfter,
	}
	// Self-signed: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(cryptorand.Reader, tmpl, tmpl, key.Public(), key)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  key,
		Leaf:        leaf,
	}
}
235