1// Package mtastsdb stores MTA-STS policies for later use.
2//
3// An MTA-STS policy can specify how long it may be cached. By storing a
4// policy, it does not have to be fetched again during email delivery, which
5// makes it harder for attackers to intervene.
6package mtastsdb
7
8import (
9 "context"
10 "crypto/tls"
11 "errors"
12 "fmt"
13 "log/slog"
14 "os"
15 "path/filepath"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/prometheus/client_golang/prometheus"
21 "github.com/prometheus/client_golang/prometheus/promauto"
22
23 "github.com/mjl-/bstore"
24
25 "github.com/mjl-/mox/dns"
26 "github.com/mjl-/mox/mlog"
27 "github.com/mjl-/mox/mox-"
28 "github.com/mjl-/mox/mtasts"
29 "github.com/mjl-/mox/tlsrpt"
30)
31
32var (
33 metricGet = promauto.NewCounterVec(
34 prometheus.CounterOpts{
35 Name: "mox_mtastsdb_get_total",
36 Help: "Number of Get by result.",
37 },
38 []string{"result"},
39 )
40)
41
// timeNow returns the current time. It is a variable rather than a direct call
// to time.Now so that it can be replaced.
var timeNow = time.Now // Tests override this.
43
// PolicyRecord is a cached policy or absence of a policy.
//
// A record with Backoff set represents a recent failure to fetch the policy; it
// is stored with a short validity so the fetch is not retried immediately.
type PolicyRecord struct {
	Domain     string    // Domain name, with unicode characters.
	Inserted   time.Time `bstore:"default now"` // When the record was first stored.
	ValidEnd   time.Time // Expiry; lookups only return records with a ValidEnd in the future.
	LastUpdate time.Time // Policies are refreshed on use and periodically.
	LastUse    time.Time `bstore:"index"` // Updated each time the record is returned by a lookup.
	Backoff    bool      // Set when the record was stored without a policy, i.e. fetching failed.
	RecordID   string    // As retrieved from DNS.
	mtasts.Policy        // As retrieved from the well-known HTTPS url.

	// Text that make up the policy, as retrieved. We didn't store this in the past. If
	// empty, policy can be reconstructed from Policy field. Needed by TLSRPT.
	PolicyText string
}
59
// ErrNotFound is returned when there is no valid, non-expired policy in the
// database.
var ErrNotFound = errors.New("mtastsdb: policy not found")

// ErrBackoff indicates an MTA-STS TXT record was fetched recently, but fetching
// the policy failed and should not yet be retried.
var ErrBackoff = errors.New("mtastsdb: policy fetch failed recently")
68
69var DBTypes = []any{PolicyRecord{}} // Types stored in DB.
70var DB *bstore.DB // Exported for backups.
71var mutex sync.Mutex
72
73func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
74 mutex.Lock()
75 defer mutex.Unlock()
76 if DB == nil {
77 p := mox.DataDirPath("mtasts.db")
78 os.MkdirAll(filepath.Dir(p), 0770)
79 db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
80 if err != nil {
81 return nil, err
82 }
83 DB = db
84 }
85 return DB, nil
86}
87
88// Init opens the database and starts a goroutine that refreshes policies in
89// the database, and keeps doing so periodically.
90func Init(refresher bool) error {
91 _, err := database(mox.Shutdown)
92 if err != nil {
93 return err
94 }
95
96 if refresher {
97 // todo: allow us to shut down cleanly?
98 go refresh()
99 }
100
101 return nil
102}
103
104// Close closes the database.
105func Close() {
106 mutex.Lock()
107 defer mutex.Unlock()
108 if DB != nil {
109 err := DB.Close()
110 mlog.New("mtastsdb", nil).Check(err, "closing database")
111 DB = nil
112 }
113}
114
115// lookup looks up a policy for the domain in the database.
116//
117// Only non-expired records are returned.
118//
119// Returns ErrNotFound if record is not present.
120// Returns ErrBackoff if a recent attempt to fetch a record failed.
121func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord, error) {
122 db, err := database(ctx)
123 if err != nil {
124 return nil, err
125 }
126
127 if domain.IsZero() {
128 return nil, fmt.Errorf("empty domain")
129 }
130 now := timeNow()
131 q := bstore.QueryDB[PolicyRecord](ctx, db)
132 q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
133 q.FilterGreater("ValidEnd", now)
134 pr, err := q.Get()
135 if err == bstore.ErrAbsent {
136 return nil, ErrNotFound
137 } else if err != nil {
138 return nil, err
139 }
140
141 pr.LastUse = now
142 if err := db.Update(ctx, &pr); err != nil {
143 log.Errorx("marking cached mta-sts policy as used in database", err)
144 }
145 if pr.Backoff {
146 return nil, ErrBackoff
147 }
148 return &pr, nil
149}
150
151// Upsert adds the policy to the database, overwriting an existing policy for the domain.
152// Policy can be nil, indicating a failure to fetch the policy.
153func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy, policyText string) error {
154 db, err := database(ctx)
155 if err != nil {
156 return err
157 }
158
159 return db.Write(ctx, func(tx *bstore.Tx) error {
160 pr := PolicyRecord{Domain: domain.Name()}
161 err := tx.Get(&pr)
162 if err != nil && err != bstore.ErrAbsent {
163 return err
164 }
165
166 now := timeNow()
167
168 var p mtasts.Policy
169 if policy != nil {
170 p = *policy
171 } else {
172 // ../rfc/8461:552
173 p.Mode = mtasts.ModeNone
174 p.MaxAgeSeconds = 5 * 60
175 }
176 backoff := policy == nil
177 validEnd := now.Add(time.Duration(p.MaxAgeSeconds) * time.Second)
178
179 if err == bstore.ErrAbsent {
180 pr = PolicyRecord{domain.Name(), now, validEnd, now, now, backoff, recordID, p, policyText}
181 return tx.Insert(&pr)
182 }
183
184 pr.ValidEnd = validEnd
185 pr.LastUpdate = now
186 pr.LastUse = now
187 pr.Backoff = backoff
188 pr.RecordID = recordID
189 pr.Policy = p
190 pr.PolicyText = policyText
191 return tx.Update(&pr)
192 })
193}
194
195// PolicyRecords returns all policies in the database, sorted descending by last
196// use, domain.
197func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
198 db, err := database(ctx)
199 if err != nil {
200 return nil, err
201 }
202 return bstore.QueryDB[PolicyRecord](ctx, db).SortDesc("LastUse", "Domain").List()
203}
204
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
//
// If an error is returned, it should be considered a transient error, e.g. a
// temporary DNS lookup failure.
//
// The returned policy can be nil also when there is no error. In this case, the
// domain does not implement MTA-STS.
//
// If a policy is present in the local database, it is refreshed if needed. If no
// policy is present for the domain, an attempt is made to fetch the policy and
// store it in the local database.
//
// Some errors are logged but not otherwise returned, e.g. if a new policy is
// supposedly published but could not be retrieved.
//
// Get returns an "sts" or "no-policy-found" in reportResult in most cases (when
// not a local/internal error). It may add an "sts" result without policy contents
// ("policy-string") in case of errors while fetching the policy.
//
// The fresh result is true when the policy was fetched during this call, or when
// the ID of the DNS record was found to still match the cached policy. It is
// false when falling back to a cached policy, when backing off, and on error.
//
// Every call increments metricGet and writes a debug log line with the outcome.
func Get(ctx context.Context, elog *slog.Logger, resolver dns.Resolver, domain dns.Domain) (policy *mtasts.Policy, reportResult tlsrpt.Result, fresh bool, err error) {
	log := mlog.New("mtastsdb", elog)
	// The deferred function reads the named results err and fresh, so it sees the
	// values of whichever return statement ended the call.
	defer func() {
		result := "ok"
		// NOTE(review): in the code below, the backoff and not-found cases return a nil
		// error, so the "backoff" and "notfound" labels look unreachable — confirm.
		if err != nil && errors.Is(err, ErrBackoff) {
			result = "backoff"
		} else if err != nil && errors.Is(err, ErrNotFound) {
			result = "notfound"
		} else if err != nil {
			result = "error"
		}
		metricGet.WithLabelValues(result).Inc()
		log.Debugx("mtastsdb get result", err, slog.Any("domain", domain), slog.Bool("fresh", fresh))
	}()

	cachedPolicy, err := lookup(ctx, log, domain)
	if err != nil && errors.Is(err, ErrNotFound) {
		// We don't have a policy for this domain, not even a record that we tried recently
		// and should backoff. So attempt to fetch policy.
		nctx, cancel := context.WithTimeout(ctx, time.Minute)
		defer cancel()
		// Note: this err is a new variable local to this branch, it shadows the named
		// result. The result is only set through the return statements below.
		record, p, ptext, err := mtasts.Get(nctx, log.Logger, resolver, domain)
		if err != nil {
			switch {
			case errors.Is(err, mtasts.ErrNoRecord) || errors.Is(err, mtasts.ErrMultipleRecords) || errors.Is(err, mtasts.ErrRecordSyntax) || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax):
				// Remote is not doing MTA-STS, continue below. ../rfc/8461:333 ../rfc/8461:574
				log.Debugx("interpreting mtasts error to mean remote is not doing mta-sts", err)

				// Only a missing DNS record is reported as "no policy found". The other errors
				// are reported as an STS result with failure details.
				if errors.Is(err, mtasts.ErrNoRecord) {
					reportResult = tlsrpt.MakeResult(tlsrpt.NoPolicyFound, domain)
				} else {
					fd := policyFetchFailureDetails(err)
					reportResult = tlsrpt.MakeResult(tlsrpt.STS, domain, fd)
				}

			default:
				// Interpret as temporary error, e.g. mtasts.ErrDNS, try again later.

				// Temporary DNS error could be an operational issue on our side, but we can still
				// report it.
				// Result: ../rfc/8460:594
				fd := tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, mtasts.TLSReportFailureReason(err))
				reportResult = tlsrpt.MakeResult(tlsrpt.STS, domain, fd)

				return nil, reportResult, false, fmt.Errorf("lookup up mta-sts policy: %w", err)
			}
		} else if p.Mode == mtasts.ModeNone {
			// A policy with mode "none" is reported the same as having no policy.
			reportResult = tlsrpt.MakeResult(tlsrpt.NoPolicyFound, domain)
		} else {
			reportResult = tlsrpt.Result{Policy: tlsrptPolicy(p, ptext, domain)}
		}

		// Insert policy into database. If we could not fetch the policy itself, we back
		// off for 5 minutes. ../rfc/8461:555
		// Errors about the DNS record itself (missing, multiple, syntax) are not cached.
		if err == nil || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax) {
			var recordID string
			if record != nil {
				recordID = record.ID
			}
			if err := Upsert(ctx, domain, recordID, p, ptext); err != nil {
				log.Errorx("inserting policy into cache, continuing", err)
			}
		}

		return p, reportResult, true, nil
	} else if err != nil && errors.Is(err, ErrBackoff) {
		// ../rfc/8461:552
		// We recently failed to fetch a policy, act as if MTA-STS is not implemented.
		// Result: ../rfc/8460:594
		fd := tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, "back-off-after-recent-fetch-error")
		reportResult = tlsrpt.MakeResult(tlsrpt.STS, domain, fd)
		return nil, reportResult, false, nil
	} else if err != nil {
		// We don't add the result to the report, this is an internal error.
		return nil, reportResult, false, fmt.Errorf("looking up mta-sts policy in cache: %w", err)
	}

	// Policy was found in database. Check in DNS it is still fresh.
	policy = &cachedPolicy.Policy
	nctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	record, _, err := mtasts.LookupRecord(nctx, log.Logger, resolver, domain)
	if err != nil {
		// In all cases below we keep using the cached policy, and report it as not fresh.
		if errors.Is(err, mtasts.ErrNoRecord) {
			if policy.Mode != mtasts.ModeNone {
				log.Errorx("no mtasts dns record while checking non-none policy for freshness, either domain owner removed mta-sts without phasing out policy with a none-policy for period of previous max-age, or this could be an attempt to downgrade to connection without mtasts, continuing with previous policy", err)
			}
			// else, policy will be removed by periodic refresher in the near future.
		} else {
			// Could be a temporary DNS or configuration error.
			log.Errorx("checking for freshness of cached mta-sts dns txt record for domain, continuing with previously cached policy", err)
		}

		// Result: ../rfc/8460:594
		fd := tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, mtasts.TLSReportFailureReason(err))
		if policy.Mode != mtasts.ModeNone {
			fd.FailureReasonCode += "+fallback-to-cached-policy"
		}
		reportResult = tlsrpt.Result{
			Policy:         tlsrptPolicy(policy, cachedPolicy.PolicyText, domain),
			FailureDetails: []tlsrpt.FailureDetails{fd},
		}
		return policy, reportResult, false, nil
	} else if record.ID == cachedPolicy.RecordID && cachedPolicy.PolicyText != "" {
		// In the past, we didn't store the raw policy text in cachedPolicy.PolicyText. We
		// only stop now if we do have the policy text in the cache.
		reportResult = tlsrpt.Result{Policy: tlsrptPolicy(policy, cachedPolicy.PolicyText, domain)}
		return policy, reportResult, true, nil
	}

	// New policy should be available, or we are fetching the policy again because we
	// didn't store the raw policy lines in the past.
	nctx, cancel = context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	p, ptext, err := mtasts.FetchPolicy(nctx, log.Logger, domain)
	if err != nil {
		log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)

		fd := policyFetchFailureDetails(err)
		fd.FailureReasonCode += "+fallback-to-cached-policy"
		reportResult = tlsrpt.Result{
			Policy:         tlsrptPolicy(policy, cachedPolicy.PolicyText, domain),
			FailureDetails: []tlsrpt.FailureDetails{fd},
		}
		return policy, reportResult, false, nil
	}
	// A failure to store the refreshed policy is only logged, the fresh policy is
	// still returned.
	if err := Upsert(ctx, domain, record.ID, p, ptext); err != nil {
		log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
	}
	reportResult = tlsrpt.Result{Policy: tlsrptPolicy(p, ptext, domain)}
	return p, reportResult, true, nil
}
355
356func policyFetchFailureDetails(err error) tlsrpt.FailureDetails {
357 var verificationErr *tls.CertificateVerificationError
358 if errors.As(err, &verificationErr) {
359 resultType, reasonCode := tlsrpt.TLSFailureDetails(verificationErr)
360 // Result: ../rfc/8460:601
361 reason := string(resultType)
362 if reasonCode != "" {
363 reason += "+" + reasonCode
364 }
365 return tlsrpt.Details(tlsrpt.ResultSTSWebPKIInvalid, reason)
366 } else if errors.Is(err, mtasts.ErrPolicySyntax) {
367 // Result: ../rfc/8460:598
368 return tlsrpt.Details(tlsrpt.ResultSTSPolicyInvalid, mtasts.TLSReportFailureReason(err))
369 }
370 // Result: ../rfc/8460:594
371 return tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, mtasts.TLSReportFailureReason(err))
372}
373
374func tlsrptPolicy(p *mtasts.Policy, policyText string, domain dns.Domain) tlsrpt.ResultPolicy {
375 if policyText == "" {
376 // We didn't always store original policy lines. Reconstruct.
377 policyText = p.String()
378 }
379 lines := strings.Split(strings.TrimSuffix(policyText, "\n"), "\n")
380 for i, line := range lines {
381 lines[i] = strings.TrimSuffix(line, "\r")
382 }
383
384 rp := tlsrpt.ResultPolicy{
385 Type: tlsrpt.STS,
386 Domain: domain.ASCII,
387 String: lines,
388 }
389 rp.MXHost = make([]string, len(p.MX))
390 for i, mx := range p.MX {
391 s := mx.Domain.ASCII
392 if mx.Wildcard {
393 s = "*." + s
394 }
395 rp.MXHost[i] = s
396 }
397 return rp
398}
399