1package mtastsdb
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 mathrand "math/rand"
9 "runtime/debug"
10 "time"
11
12 "github.com/mjl-/bstore"
13
14 "github.com/mjl-/mox/dns"
15 "github.com/mjl-/mox/metrics"
16 "github.com/mjl-/mox/mlog"
17 "github.com/mjl-/mox/mox-"
18 "github.com/mjl-/mox/mtasts"
19)
20
21func refresh() int {
22 interval := 24 * time.Hour
23 ticker := time.NewTicker(interval)
24 defer ticker.Stop()
25
26 var refreshed int
27
28 // Pro-actively refresh policies every 24 hours. ../rfc/8461:583
29 for {
30 ticker.Reset(interval)
31
32 log := mlog.New("mtastsdb", nil).WithCid(mox.Cid())
33 n, err := refresh1(mox.Context, log, dns.StrictResolver{Pkg: "mtastsdb"}, time.Sleep)
34 log.Check(err, "periodic refresh of cached mtasts policies")
35 if n > 0 {
36 refreshed += n
37 }
38
39 select {
40 case <-mox.Shutdown.Done():
41 return refreshed
42 case <-ticker.C:
43 }
44 }
45}
46
47// refresh policies that have not been updated in the past 12 hours and remove
48// policies not used for 180 days. We start with the first domain immediately, so
49// an admin can see any (configuration) issues that are logged. We spread the
50// refreshes evenly over the next 3 hours, randomizing the domains, and we add some
51// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
52// refresh doesn't mess up the timing.
53func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
54 db, err := database(ctx)
55 if err != nil {
56 return 0, err
57 }
58
59 now := timeNow()
60 qdel := bstore.QueryDB[PolicyRecord](ctx, db)
61 qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
62 if _, err := qdel.Delete(); err != nil {
63 return 0, fmt.Errorf("deleting old unused policies: %s", err)
64 }
65
66 qup := bstore.QueryDB[PolicyRecord](ctx, db)
67 qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
68 prs, err := qup.List()
69 if err != nil {
70 return 0, fmt.Errorf("querying policies to refresh: %s", err)
71 }
72
73 if len(prs) == 0 {
74 // Nothing to do.
75 return 0, nil
76 }
77
78 // Randomize list.
79 rand := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
80 for i := range prs {
81 if i == 0 {
82 continue
83 }
84 j := rand.Intn(i + 1)
85 prs[i], prs[j] = prs[j], prs[i]
86 }
87
88 // Launch goroutine with the refresh.
89 log.Debug("will refresh mta-sts policies over next 3 hours", slog.Int("count", len(prs)))
90 start := timeNow()
91 for i, pr := range prs {
92 go refreshDomain(ctx, log, db, resolver, pr)
93 if i < len(prs)-1 {
94 interval := 3 * int64(time.Hour) / int64(len(prs)-1)
95 extra := time.Duration(rand.Int63n(interval) - interval/2)
96 next := start.Add(time.Duration(int64(i+1)*interval) + extra)
97 d := next.Sub(timeNow())
98 if d > 0 {
99 sleep(d)
100 }
101 }
102 }
103 return len(prs), nil
104}
105
106func refreshDomain(ctx context.Context, log mlog.Log, db *bstore.DB, resolver dns.Resolver, pr PolicyRecord) {
107 defer func() {
108 x := recover()
109 if x != nil {
110 // Should not happen, but make sure errors don't take down the application.
111 log.Error("refresh1", slog.Any("panic", x))
112 debug.PrintStack()
113 metrics.PanicInc(metrics.Mtastsdb)
114 }
115 }()
116
117 ctx, cancel := context.WithTimeout(ctx, time.Minute)
118 defer cancel()
119
120 d, err := dns.ParseDomain(pr.Domain)
121 if err != nil {
122 log.Errorx("refreshing mta-sts policy: parsing policy domain", err, slog.Any("domain", d))
123 return
124 }
125 log.Debug("refreshing mta-sts policy for domain", slog.Any("domain", d))
126 record, _, err := mtasts.LookupRecord(ctx, log.Logger, resolver, d)
127 if err == nil && record.ID == pr.RecordID {
128 qup := bstore.QueryDB[PolicyRecord](ctx, db)
129 qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
130 now := timeNow()
131 update := PolicyRecord{
132 LastUpdate: now,
133 ValidEnd: now.Add(time.Duration(pr.MaxAgeSeconds) * time.Second),
134 }
135 if n, err := qup.UpdateNonzero(update); err != nil {
136 log.Errorx("updating refreshed, unmodified policy in database", err)
137 } else if n != 1 {
138 log.Info("expected to update 1 policy after refresh", slog.Int("count", n))
139 }
140 return
141 }
142 if err != nil && pr.Mode == mtasts.ModeNone {
143 if errors.Is(err, mtasts.ErrNoRecord) {
144 // Policy was in mode "none". Now it doesn't have a policy anymore. Remove from our
145 // database so we don't keep refreshing it.
146 err := db.Delete(ctx, &pr)
147 log.Check(err, "removing mta-sts policy with mode none, dns record is gone")
148 }
149 // Else, don't bother operator with temporary error about policy none.
150 // ../rfc/8461:587
151 return
152 } else if err != nil {
153 log.Errorx("looking up mta-sts record for domain", err, slog.Any("domain", d))
154 // Try to fetch new policy. It could be just DNS that is down. We don't want to let our policy expire.
155 }
156
157 p, _, err := mtasts.FetchPolicy(ctx, log.Logger, d)
158 if err != nil {
159 if !errors.Is(err, mtasts.ErrNoPolicy) || pr.Mode != mtasts.ModeNone {
160 log.Errorx("refreshing mtasts policy for domain", err, slog.Any("domain", d))
161 }
162 return
163 }
164 now := timeNow()
165 update := map[string]any{
166 "LastUpdate": now,
167 "ValidEnd": now.Add(time.Duration(p.MaxAgeSeconds) * time.Second),
168 "Backoff": false,
169 "Policy": *p,
170 }
171 if record != nil {
172 update["RecordID"] = record.ID
173 }
174 qup := bstore.QueryDB[PolicyRecord](ctx, db)
175 qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
176 if n, err := qup.UpdateFields(update); err != nil {
177 log.Errorx("updating refreshed, modified policy in database", err)
178 } else if n != 1 {
179 log.Info("updating refreshed, did not update 1 policy", slog.Int("count", n))
180 }
181}
182