1package mtastsdb
2
3import (
4 "errors"
5 "fmt"
6 "net"
7 "net/http"
8 "os"
9 "path/filepath"
10 "reflect"
11 "testing"
12 "time"
13
14 "github.com/mjl-/mox/dns"
15 "github.com/mjl-/mox/mlog"
16 "github.com/mjl-/mox/mox-"
17 "github.com/mjl-/mox/mtasts"
18)
19
// tcheckf fails the test immediately when err is non-nil. The failure message
// is the context formatted from format and args, followed by ": " and the
// error text. It is marked as a test helper so failures are attributed to the
// caller's line rather than to this function.
func tcheckf(t *testing.T, err error, format string, args ...any) {
	t.Helper()
	if err != nil {
		t.Fatalf("%s: %s", fmt.Sprintf(format, args...), err)
	}
}
25
26func TestDB(t *testing.T) {
27 mox.Shutdown = ctxbg
28 mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
29 mox.Conf.Static.DataDir = "."
30
31 dbpath := mox.DataDirPath("mtasts.db")
32 os.MkdirAll(filepath.Dir(dbpath), 0770)
33 os.Remove(dbpath)
34 defer os.Remove(dbpath)
35
36 log := mlog.New("mtastsdb", nil)
37
38 if err := Init(false); err != nil {
39 t.Fatalf("init database: %s", err)
40 }
41 defer Close()
42
43 // Mock time.
44 now := time.Now().Round(0)
45 timeNow = func() time.Time { return now }
46 defer func() { timeNow = time.Now }()
47
48 if p, err := lookup(ctxbg, log, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
49 t.Fatalf("expected not found, got %v, %#v", err, p)
50 }
51
52 policy1 := mtasts.Policy{
53 Version: "STSv1",
54 Mode: mtasts.ModeTesting,
55 MX: []mtasts.MX{
56 {Domain: dns.Domain{ASCII: "mx1.example.com"}},
57 {Domain: dns.Domain{ASCII: "mx2.example.com"}},
58 {Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
59 },
60 MaxAgeSeconds: 1296000,
61 }
62 if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "123", &policy1, policy1.String()); err != nil {
63 t.Fatalf("upsert record: %s", err)
64 }
65 if got, err := lookup(ctxbg, log, dns.Domain{ASCII: "example.com"}); err != nil {
66 t.Fatalf("lookup after insert: %s", err)
67 } else if !reflect.DeepEqual(got.Policy, policy1) {
68 t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
69 }
70
71 policy2 := mtasts.Policy{
72 Version: "STSv1",
73 Mode: mtasts.ModeEnforce,
74 MX: []mtasts.MX{
75 {Domain: dns.Domain{ASCII: "mx1.example.com"}},
76 },
77 MaxAgeSeconds: 360000,
78 }
79 if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "124", &policy2, policy2.String()); err != nil {
80 t.Fatalf("upsert record: %s", err)
81 }
82 if got, err := lookup(ctxbg, log, dns.Domain{ASCII: "example.com"}); err != nil {
83 t.Fatalf("lookup after insert: %s", err)
84 } else if !reflect.DeepEqual(got.Policy, policy2) {
85 t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
86 }
87
88 // Check if database holds expected record.
89 records, err := PolicyRecords(ctxbg)
90 tcheckf(t, err, "policyrecords")
91 expRecords := []PolicyRecord{
92 {"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2, policy2.String()},
93 }
94 records[0].Policy = mtasts.Policy{}
95 expRecords[0].Policy = mtasts.Policy{}
96 if !reflect.DeepEqual(records, expRecords) {
97 t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
98 }
99
100 if err := Upsert(ctxbg, dns.Domain{ASCII: "other.example.com"}, "", nil, ""); err != nil {
101 t.Fatalf("upsert record: %s", err)
102 }
103 records, err = PolicyRecords(ctxbg)
104 tcheckf(t, err, "policyrecords")
105 policyNone := mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}
106 expRecords = []PolicyRecord{
107 {"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", policyNone, ""},
108 {"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2, policy2.String()},
109 }
110 if !reflect.DeepEqual(records, expRecords) {
111 t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
112 }
113
114 if _, err := lookup(ctxbg, log, dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
115 t.Fatalf("got %#v, expected ErrBackoff", err)
116 }
117
118 resolver := dns.MockResolver{
119 TXT: map[string][]string{
120 "_mta-sts.example.com.": {"v=STSv1; id=124"},
121 "_mta-sts.other.example.com.": {"v=STSv1; id=1"},
122 "_mta-sts.temperror.example.com.": {""},
123 },
124 Fail: []string{
125 "txt _mta-sts.temperror.example.com.",
126 },
127 }
128
129 testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
130 t.Helper()
131 p, _, fresh, err := Get(ctxbg, log.Logger, resolver, dns.Domain{ASCII: domain})
132 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
133 t.Fatalf("got err %v, expected %v", err, expErr)
134 }
135 if !reflect.DeepEqual(p, expPolicy) || fresh != expFresh {
136 t.Fatalf("got policy %#v, fresh %v, expected %#v, %v", p, fresh, expPolicy, expFresh)
137 }
138 }
139
140 testGet("example.com", &policy2, true, nil)
141 testGet("other.example.com", nil, false, nil) // Back off, already in database.
142 testGet("absent.example.com", nil, true, nil) // No MTA-STS.
143 testGet("temperror.example.com", nil, false, mtasts.ErrDNS)
144
145 // Force refetch of policy, that will fail.
146 mtasts.HTTPClient.Transport = &http.Transport{
147 Dial: func(network, addr string) (net.Conn, error) {
148 return nil, fmt.Errorf("bad")
149 },
150 }
151 defer func() {
152 mtasts.HTTPClient.Transport = nil
153 }()
154 resolver.TXT["_mta-sts.example.com."] = []string{"v=STSv1; id=125"}
155 testGet("example.com", &policy2, false, nil)
156
157 // Cached policy but no longer a DNS record.
158 delete(resolver.TXT, "_mta-sts.example.com.")
159 testGet("example.com", &policy2, false, nil)
160}
161