14 "github.com/mjl-/mox/dns"
15 "github.com/mjl-/mox/mlog"
16 "github.com/mjl-/mox/mox-"
17 "github.com/mjl-/mox/mtasts"
// tcheckf aborts the test when err is non-nil. The failure message is the
// description built from format and args (as for fmt.Sprintf), followed by a
// colon and the error. For a nil err it does nothing, so callers can invoke it
// unconditionally after every fallible call.
func tcheckf(t *testing.T, err error, format string, args ...any) {
	// Mark this function as a helper so the failure is attributed to the
	// caller's line rather than to the t.Fatalf call below.
	t.Helper()
	if err != nil {
		t.Fatalf("%s: %s", fmt.Sprintf(format, args...), err)
	}
}
// TestDB exercises the policy database end to end: Init, lookup, Upsert,
// PolicyRecords and Get, using a database file located through
// mox.DataDirPath("mtasts.db") and a mock DNS resolver.
//
// NOTE(review): this view of the file is missing lines (closing braces, some
// struct field names, a few statements). The comments below describe only the
// statements that are visible.
26func TestDB(t *testing.T) {
// Point the static configuration at the test fixture and use the current
// directory as data directory.
28 mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
29 mox.Conf.Static.DataDir = "."
// Ensure the directory for the database file exists (the MkdirAll error is
// not checked) and remove the database file when the test returns.
31 dbpath := mox.DataDirPath("mtasts.db")
32 os.MkdirAll(filepath.Dir(dbpath), 0770)
34 defer os.Remove(dbpath)
36 log := mlog.New("mtastsdb", nil)
38 if err := Init(false); err != nil {
39 t.Fatalf("init database: %s", err)
// Freeze the clock: timeNow is replaced by a function returning one fixed
// instant and is restored to time.Now on return. Round(0) strips the monotonic
// clock reading, presumably so timestamps read back from the database compare
// equal under reflect.DeepEqual — confirm.
44 now := time.Now().Round(0)
45 timeNow = func() time.Time { return now }
46 defer func() { timeNow = time.Now }()
// Nothing has been stored yet, so lookup must report ErrNotFound.
48 if p, err := lookup(ctxbg, log, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
49 t.Fatalf("expected not found, got %v, %#v", err, p)
// First policy: testing mode, three MX domains, max age 1296000 seconds
// (15 days).
52 policy1 := mtasts.Policy{
54 Mode: mtasts.ModeTesting,
56 {Domain: dns.Domain{ASCII: "mx1.example.com"}},
57 {Domain: dns.Domain{ASCII: "mx2.example.com"}},
58 {Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
60 MaxAgeSeconds: 1296000,
// Store policy1 for example.com with "123" as third argument (presumably the
// record ID; the DNS TXT fixtures further down use "id=..." values), then
// verify lookup returns the same policy.
62 if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "123", &policy1, policy1.String()); err != nil {
63 t.Fatalf("upsert record: %s", err)
65 if got, err := lookup(ctxbg, log, dns.Domain{ASCII: "example.com"}); err != nil {
66 t.Fatalf("lookup after insert: %s", err)
67 } else if !reflect.DeepEqual(got.Policy, policy1) {
68 t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
// Second policy: enforce mode, a single MX domain, max age 360000 seconds.
71 policy2 := mtasts.Policy{
73 Mode: mtasts.ModeEnforce,
75 {Domain: dns.Domain{ASCII: "mx1.example.com"}},
77 MaxAgeSeconds: 360000,
// Upsert again for the same domain, now with "124" and policy2; the following
// lookup must return policy2, i.e. the earlier policy has been replaced.
79 if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "124", &policy2, policy2.String()); err != nil {
80 t.Fatalf("upsert record: %s", err)
82 if got, err := lookup(ctxbg, log, dns.Domain{ASCII: "example.com"}); err != nil {
83 t.Fatalf("lookup after insert: %s", err)
84 } else if !reflect.DeepEqual(got.Policy, policy2) {
85 t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
88 // Check if database holds expected record.
89 records, err := PolicyRecords(ctxbg)
90 tcheckf(t, err, "policyrecords")
// The expected record uses the frozen "now" for all four timestamps except the
// second one, which is now plus policy2's max age.
91 expRecords := []PolicyRecord{
92 {"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2, policy2.String()},
// The Policy field is zeroed on both sides, so it is excluded from this
// comparison.
94 records[0].Policy = mtasts.Policy{}
95 expRecords[0].Policy = mtasts.Policy{}
96 if !reflect.DeepEqual(records, expRecords) {
97 t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
// Upsert with an empty third argument, a nil policy and empty policy text for
// a second domain.
100 if err := Upsert(ctxbg, dns.Domain{ASCII: "other.example.com"}, "", nil, ""); err != nil {
101 t.Fatalf("upsert record: %s", err)
103 records, err = PolicyRecords(ctxbg)
104 tcheckf(t, err, "policyrecords")
// For the nil-policy upsert the test expects a stored policy with mode "none"
// and a 5 minute max age, the sixth positional field set to true (presumably
// the backoff flag — confirm against PolicyRecord), and that record listed
// before the example.com record.
105 policyNone := mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}
106 expRecords = []PolicyRecord{
107 {"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", policyNone, ""},
108 {"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2, policy2.String()},
110 if !reflect.DeepEqual(records, expRecords) {
111 t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
// lookup for the domain stored with a nil policy must return ErrBackoff.
114 if _, err := lookup(ctxbg, log, dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
115 t.Fatalf("got %#v, expected ErrBackoff", err)
// Mock DNS data: TXT records for three _mta-sts names. The string
// "txt _mta-sts.temperror.example.com." below belongs to a field whose name is
// not visible here; presumably it lists lookups the mock resolver should fail
// — confirm.
118 resolver := dns.MockResolver{
119 TXT: map[string][]string{
120 "_mta-sts.example.com.": {"v=STSv1; id=124"},
121 "_mta-sts.other.example.com.": {"v=STSv1; id=1"},
122 "_mta-sts.temperror.example.com.": {""},
125 "txt _mta-sts.temperror.example.com.",
// testGet calls Get for domain and fails the test when the error does not
// match expErr (exactly one of them is nil, or err is non-nil and not
// errors.Is(err, expErr)), or when the returned policy or fresh flag differ
// from the expected values.
129 testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
131 p, _, fresh, err := Get(ctxbg, log.Logger, resolver, dns.Domain{ASCII: domain})
132 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
133 t.Fatalf("got err %v, expected %v", err, expErr)
135 if !reflect.DeepEqual(p, expPolicy) || fresh != expFresh {
136 t.Fatalf("got policy %#v, fresh %v, expected %#v, %v", p, fresh, expPolicy, expFresh)
140 testGet("example.com", &policy2, true, nil)
141 testGet("other.example.com", nil, false, nil) // Back off, already in database.
142 testGet("absent.example.com", nil, true, nil) // No MTA-STS.
143 testGet("temperror.example.com", nil, false, mtasts.ErrDNS)
145 // Force refetch of policy, that will fail.
// The HTTP transport is replaced by one whose Dial always returns an error.
146 mtasts.HTTPClient.Transport = &http.Transport{
147 Dial: func(network, addr string) (net.Conn, error) {
148 return nil, fmt.Errorf("bad")
// NOTE(review): the statement enclosing this reset is not visible in this
// view (likely a deferred cleanup restoring the default transport) — confirm.
152 mtasts.HTTPClient.Transport = nil
// DNS now advertises id=125 while the database holds "124"; the test expects
// Get to still return policy2, with fresh false.
154 resolver.TXT["_mta-sts.example.com."] = []string{"v=STSv1; id=125"}
155 testGet("example.com", &policy2, false, nil)
157 // Cached policy but no longer a DNS record.
158 delete(resolver.TXT, "_mta-sts.example.com.")
159 testGet("example.com", &policy2, false, nil)