1package mtastsdb
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 mathrand "math/rand"
8 "runtime/debug"
9 "time"
10
11 "github.com/mjl-/bstore"
12
13 "github.com/mjl-/mox/dns"
14 "github.com/mjl-/mox/metrics"
15 "github.com/mjl-/mox/mlog"
16 "github.com/mjl-/mox/mox-"
17 "github.com/mjl-/mox/mtasts"
18)
19
20func refresh() int {
21 interval := 24 * time.Hour
22 ticker := time.NewTicker(interval)
23 defer ticker.Stop()
24
25 var refreshed int
26
27 // Pro-actively refresh policies every 24 hours. ../rfc/8461:583
28 for {
29 ticker.Reset(interval)
30
31 ctx := context.WithValue(mox.Context, mlog.CidKey, mox.Cid())
32 n, err := refresh1(ctx, dns.StrictResolver{Pkg: "mtastsdb"}, time.Sleep)
33 if err != nil {
34 xlog.WithContext(ctx).Errorx("periodic refresh of cached mtasts policies", err)
35 }
36 if n > 0 {
37 refreshed += n
38 }
39
40 select {
41 case <-mox.Shutdown.Done():
42 return refreshed
43 case <-ticker.C:
44 }
45 }
46}
47
48// refresh policies that have not been updated in the past 12 hours and remove
49// policies not used for 180 days. We start with the first domain immediately, so
50// an admin can see any (configuration) issues that are logged. We spread the
51// refreshes evenly over the next 3 hours, randomizing the domains, and we add some
52// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
53// refresh doesn't mess up the timing.
54func refresh1(ctx context.Context, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
55 db, err := database(ctx)
56 if err != nil {
57 return 0, err
58 }
59
60 now := timeNow()
61 qdel := bstore.QueryDB[PolicyRecord](ctx, db)
62 qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
63 if _, err := qdel.Delete(); err != nil {
64 return 0, fmt.Errorf("deleting old unused policies: %s", err)
65 }
66
67 qup := bstore.QueryDB[PolicyRecord](ctx, db)
68 qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
69 prs, err := qup.List()
70 if err != nil {
71 return 0, fmt.Errorf("querying policies to refresh: %s", err)
72 }
73
74 if len(prs) == 0 {
75 // Nothing to do.
76 return 0, nil
77 }
78
79 // Randomize list.
80 rand := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
81 for i := range prs {
82 if i == 0 {
83 continue
84 }
85 j := rand.Intn(i + 1)
86 prs[i], prs[j] = prs[j], prs[i]
87 }
88
89 // Launch goroutine with the refresh.
90 xlog.WithContext(ctx).Debug("will refresh mta-sts policies over next 3 hours", mlog.Field("count", len(prs)))
91 start := timeNow()
92 for i, pr := range prs {
93 go refreshDomain(ctx, db, resolver, pr)
94 if i < len(prs)-1 {
95 interval := 3 * int64(time.Hour) / int64(len(prs)-1)
96 extra := time.Duration(rand.Int63n(interval) - interval/2)
97 next := start.Add(time.Duration(int64(i+1)*interval) + extra)
98 d := next.Sub(timeNow())
99 if d > 0 {
100 sleep(d)
101 }
102 }
103 }
104 return len(prs), nil
105}
106
107func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr PolicyRecord) {
108 log := xlog.WithContext(ctx)
109 defer func() {
110 x := recover()
111 if x != nil {
112 // Should not happen, but make sure errors don't take down the application.
113 log.Error("refresh1", mlog.Field("panic", x))
114 debug.PrintStack()
115 metrics.PanicInc(metrics.Mtastsdb)
116 }
117 }()
118
119 ctx, cancel := context.WithTimeout(ctx, time.Minute)
120 defer cancel()
121
122 d, err := dns.ParseDomain(pr.Domain)
123 if err != nil {
124 log.Errorx("refreshing mta-sts policy: parsing policy domain", err, mlog.Field("domain", d))
125 return
126 }
127 log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
128 record, _, err := mtasts.LookupRecord(ctx, resolver, d)
129 if err == nil && record.ID == pr.RecordID {
130 qup := bstore.QueryDB[PolicyRecord](ctx, db)
131 qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
132 now := timeNow()
133 update := PolicyRecord{
134 LastUpdate: now,
135 ValidEnd: now.Add(time.Duration(pr.MaxAgeSeconds) * time.Second),
136 }
137 if n, err := qup.UpdateNonzero(update); err != nil {
138 log.Errorx("updating refreshed, unmodified policy in database", err)
139 } else if n != 1 {
140 log.Info("expected to update 1 policy after refresh", mlog.Field("count", n))
141 }
142 return
143 }
144 if err != nil && pr.Mode == mtasts.ModeNone {
145 if errors.Is(err, mtasts.ErrNoRecord) {
146 // Policy was in mode "none". Now it doesn't have a policy anymore. Remove from our
147 // database so we don't keep refreshing it.
148 err := db.Delete(ctx, &pr)
149 log.Check(err, "removing mta-sts policy with mode none, dns record is gone")
150 }
151 // Else, don't bother operator with temporary error about policy none.
152 // ../rfc/8461:587
153 return
154 } else if err != nil {
155 log.Errorx("looking up mta-sts record for domain", err, mlog.Field("domain", d))
156 // Try to fetch new policy. It could be just DNS that is down. We don't want to let our policy expire.
157 }
158
159 p, _, err := mtasts.FetchPolicy(ctx, d)
160 if err != nil {
161 if !errors.Is(err, mtasts.ErrNoPolicy) || pr.Mode != mtasts.ModeNone {
162 log.Errorx("refreshing mtasts policy for domain", err, mlog.Field("domain", d))
163 }
164 return
165 }
166 now := timeNow()
167 update := map[string]any{
168 "LastUpdate": now,
169 "ValidEnd": now.Add(time.Duration(p.MaxAgeSeconds) * time.Second),
170 "Backoff": false,
171 "Policy": *p,
172 }
173 if record != nil {
174 update["RecordID"] = record.ID
175 }
176 qup := bstore.QueryDB[PolicyRecord](ctx, db)
177 qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
178 if n, err := qup.UpdateFields(update); err != nil {
179 log.Errorx("updating refreshed, modified policy in database", err)
180 } else if n != 1 {
181 log.Info("updating refreshed, did not update 1 policy", mlog.Field("count", n))
182 }
183}
184