1// Package mtastsdb stores MTA-STS policies for later use.
2//
// An MTA-STS policy can specify how long it may be cached. Because a stored
// policy does not have to be fetched again during email delivery, it is harder
// for attackers to intervene.
6package mtastsdb
7
8import (
9 "context"
10 "crypto/tls"
11 "errors"
12 "fmt"
13 "os"
14 "path/filepath"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/prometheus/client_golang/prometheus"
20 "github.com/prometheus/client_golang/prometheus/promauto"
21
22 "github.com/mjl-/bstore"
23
24 "github.com/mjl-/mox/dns"
25 "github.com/mjl-/mox/mlog"
26 "github.com/mjl-/mox/mox-"
27 "github.com/mjl-/mox/mtasts"
28 "github.com/mjl-/mox/tlsrpt"
29)
30
31var xlog = mlog.New("mtastsdb")
32
33var (
34 metricGet = promauto.NewCounterVec(
35 prometheus.CounterOpts{
36 Name: "mox_mtastsdb_get_total",
37 Help: "Number of Get by result.",
38 },
39 []string{"result"},
40 )
41)
42
43var timeNow = time.Now // Tests override this.
44
// PolicyRecord is a cached policy or absence of a policy.
//
// A record with Backoff set represents a recent failure to fetch the policy,
// rather than an actual policy.
type PolicyRecord struct {
	// Domain name, with unicode characters. Upsert fetches existing records by
	// this field alone, so it acts as the record's key.
	Domain string

	// Time the record was first inserted. Updates by Upsert keep this value.
	Inserted time.Time `bstore:"default now"`

	// lookup ignores the record once this time has passed. Upsert sets it to
	// the time of insert/update plus the policy's max age.
	ValidEnd time.Time

	// Policies are refreshed on use and periodically.
	LastUpdate time.Time

	// Updated when lookup finds the record, and by Upsert. Indexed.
	LastUse time.Time `bstore:"index"`

	// Set when fetching the policy failed. lookup then returns ErrBackoff
	// instead of the record.
	Backoff bool

	// As retrieved from DNS.
	RecordID string

	// As retrieved from the well-known HTTPS url. For backoff records, Upsert
	// stores a policy with mode mtasts.ModeNone and a max age of 5 minutes.
	mtasts.Policy

	// Text that makes up the policy, as retrieved. We didn't store this in the
	// past. If empty, policy can be reconstructed from Policy field. Needed by
	// TLSRPT.
	PolicyText string
}
60
61var (
62 // No valid non-expired policy in database.
63 ErrNotFound = errors.New("mtastsdb: policy not found")
64
65 // Indicates an MTA-STS TXT record was fetched recently, but fetching the policy
66 // failed and should not yet be retried.
67 ErrBackoff = errors.New("mtastsdb: policy fetch failed recently")
68)
69
70var DBTypes = []any{PolicyRecord{}} // Types stored in DB.
71var DB *bstore.DB // Exported for backups.
72var mutex sync.Mutex
73
74func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
75 mutex.Lock()
76 defer mutex.Unlock()
77 if DB == nil {
78 p := mox.DataDirPath("mtasts.db")
79 os.MkdirAll(filepath.Dir(p), 0770)
80 db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
81 if err != nil {
82 return nil, err
83 }
84 DB = db
85 }
86 return DB, nil
87}
88
89// Init opens the database and starts a goroutine that refreshes policies in
90// the database, and keeps doing so periodically.
91func Init(refresher bool) error {
92 _, err := database(mox.Shutdown)
93 if err != nil {
94 return err
95 }
96
97 if refresher {
98 // todo: allow us to shut down cleanly?
99 go refresh()
100 }
101
102 return nil
103}
104
105// Close closes the database.
106func Close() {
107 mutex.Lock()
108 defer mutex.Unlock()
109 if DB != nil {
110 err := DB.Close()
111 xlog.Check(err, "closing database")
112 DB = nil
113 }
114}
115
116// lookup looks up a policy for the domain in the database.
117//
118// Only non-expired records are returned.
119//
120// Returns ErrNotFound if record is not present.
121// Returns ErrBackoff if a recent attempt to fetch a record failed.
122func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
123 log := xlog.WithContext(ctx)
124 db, err := database(ctx)
125 if err != nil {
126 return nil, err
127 }
128
129 if domain.IsZero() {
130 return nil, fmt.Errorf("empty domain")
131 }
132 now := timeNow()
133 q := bstore.QueryDB[PolicyRecord](ctx, db)
134 q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
135 q.FilterGreater("ValidEnd", now)
136 pr, err := q.Get()
137 if err == bstore.ErrAbsent {
138 return nil, ErrNotFound
139 } else if err != nil {
140 return nil, err
141 }
142
143 pr.LastUse = now
144 if err := db.Update(ctx, &pr); err != nil {
145 log.Errorx("marking cached mta-sts policy as used in database", err)
146 }
147 if pr.Backoff {
148 return nil, ErrBackoff
149 }
150 return &pr, nil
151}
152
153// Upsert adds the policy to the database, overwriting an existing policy for the domain.
154// Policy can be nil, indicating a failure to fetch the policy.
155func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy, policyText string) error {
156 db, err := database(ctx)
157 if err != nil {
158 return err
159 }
160
161 return db.Write(ctx, func(tx *bstore.Tx) error {
162 pr := PolicyRecord{Domain: domain.Name()}
163 err := tx.Get(&pr)
164 if err != nil && err != bstore.ErrAbsent {
165 return err
166 }
167
168 now := timeNow()
169
170 var p mtasts.Policy
171 if policy != nil {
172 p = *policy
173 } else {
174 // ../rfc/8461:552
175 p.Mode = mtasts.ModeNone
176 p.MaxAgeSeconds = 5 * 60
177 }
178 backoff := policy == nil
179 validEnd := now.Add(time.Duration(p.MaxAgeSeconds) * time.Second)
180
181 if err == bstore.ErrAbsent {
182 pr = PolicyRecord{domain.Name(), now, validEnd, now, now, backoff, recordID, p, policyText}
183 return tx.Insert(&pr)
184 }
185
186 pr.ValidEnd = validEnd
187 pr.LastUpdate = now
188 pr.LastUse = now
189 pr.Backoff = backoff
190 pr.RecordID = recordID
191 pr.Policy = p
192 pr.PolicyText = policyText
193 return tx.Update(&pr)
194 })
195}
196
197// PolicyRecords returns all policies in the database, sorted descending by last
198// use, domain.
199func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
200 db, err := database(ctx)
201 if err != nil {
202 return nil, err
203 }
204 return bstore.QueryDB[PolicyRecord](ctx, db).SortDesc("LastUse", "Domain").List()
205}
206
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
//
// If an error is returned, it should be considered a transient error, e.g. a
// temporary DNS lookup failure.
//
// The returned policy can be nil also when there is no error. In this case, the
// domain does not implement MTA-STS.
//
// If a policy is present in the local database, it is refreshed if needed. If no
// policy is present for the domain, an attempt is made to fetch the policy and
// store it in the local database.
//
// Some errors are logged but not otherwise returned, e.g. if a new policy is
// supposedly published but could not be retrieved.
//
// Get returns an "sts" or "no-policy-found" in reportResult in most cases (when
// not a local/internal error). It may add an "sts" result without policy contents
// ("policy-string") in case of errors while fetching the policy.
//
// fresh is true when the policy was just looked up (including when the domain
// turned out not to do MTA-STS), or when the cached policy's record ID still
// matches the one in DNS. It is false when backing off after a recent fetch
// failure, when falling back to a cached policy because the DNS lookup or policy
// fetch failed, and when an error is returned.
func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy *mtasts.Policy, reportResult tlsrpt.Result, fresh bool, err error) {
	log := xlog.WithContext(ctx)
	// Count and log the outcome, based on the named results as finally returned.
	// NOTE(review): no return path below returns ErrBackoff or ErrNotFound (the
	// backoff case returns a nil error), so the "backoff" and "notfound" labels
	// appear unreachable and such calls are counted as "ok". Confirm intent.
	defer func() {
		result := "ok"
		if err != nil && errors.Is(err, ErrBackoff) {
			result = "backoff"
		} else if err != nil && errors.Is(err, ErrNotFound) {
			result = "notfound"
		} else if err != nil {
			result = "error"
		}
		metricGet.WithLabelValues(result).Inc()
		log.Debugx("mtastsdb get result", err, mlog.Field("domain", domain), mlog.Field("fresh", fresh))
	}()

	cachedPolicy, err := lookup(ctx, domain)
	if err != nil && errors.Is(err, ErrNotFound) {
		// We don't have a policy for this domain, not even a record that we tried recently
		// and should backoff. So attempt to fetch policy.
		nctx, cancel := context.WithTimeout(ctx, time.Minute)
		defer cancel()
		// Note: within this block, err is a new variable that shadows the named
		// result err. The returns below set the named result explicitly.
		record, p, ptext, err := mtasts.Get(nctx, resolver, domain)
		if err != nil {
			switch {
			case errors.Is(err, mtasts.ErrNoRecord) || errors.Is(err, mtasts.ErrMultipleRecords) || errors.Is(err, mtasts.ErrRecordSyntax) || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax):
				// Remote is not doing MTA-STS, continue below. ../rfc/8461:333 ../rfc/8461:574
				log.Debugx("interpreting mtasts error to mean remote is not doing mta-sts", err)

				if errors.Is(err, mtasts.ErrNoRecord) {
					reportResult = tlsrpt.MakeResult(tlsrpt.NoPolicyFound, domain)
				} else {
					fd := policyFetchFailureDetails(err)
					reportResult = tlsrpt.MakeResult(tlsrpt.STS, domain, fd)
				}

			default:
				// Interpret as temporary error, e.g. mtasts.ErrDNS, try again later.

				// Temporary DNS error could be an operational issue on our side, but we can still
				// report it.
				// Result: ../rfc/8460:594
				fd := tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, mtasts.TLSReportFailureReason(err))
				reportResult = tlsrpt.MakeResult(tlsrpt.STS, domain, fd)

				return nil, reportResult, false, fmt.Errorf("lookup up mta-sts policy: %w", err)
			}
		} else if p.Mode == mtasts.ModeNone {
			reportResult = tlsrpt.MakeResult(tlsrpt.NoPolicyFound, domain)
		} else {
			reportResult = tlsrpt.Result{Policy: tlsrptPolicy(p, ptext, domain)}
		}

		// Insert policy into database. If we could not fetch the policy itself, we back
		// off for 5 minutes. ../rfc/8461:555
		// DNS record errors (no record, multiple records, record syntax) are not
		// stored, so such domains are looked up again on the next call.
		if err == nil || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax) {
			var recordID string
			if record != nil {
				recordID = record.ID
			}
			if err := Upsert(ctx, domain, recordID, p, ptext); err != nil {
				log.Errorx("inserting policy into cache, continuing", err)
			}
		}

		return p, reportResult, true, nil
	} else if err != nil && errors.Is(err, ErrBackoff) {
		// ../rfc/8461:552
		// We recently failed to fetch a policy, act as if MTA-STS is not implemented.
		// Result: ../rfc/8460:594
		fd := tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, "back-off-after-recent-fetch-error")
		reportResult = tlsrpt.MakeResult(tlsrpt.STS, domain, fd)
		return nil, reportResult, false, nil
	} else if err != nil {
		// We don't add the result to the report, this is an internal error.
		return nil, reportResult, false, fmt.Errorf("looking up mta-sts policy in cache: %w", err)
	}

	// Policy was found in database. Check in DNS it is still fresh.
	policy = &cachedPolicy.Policy
	nctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	record, _, err := mtasts.LookupRecord(nctx, resolver, domain)
	if err != nil {
		if errors.Is(err, mtasts.ErrNoRecord) {
			if policy.Mode != mtasts.ModeNone {
				log.Errorx("no mtasts dns record while checking non-none policy for freshness, either domain owner removed mta-sts without phasing out policy with a none-policy for period of previous max-age, or this could be an attempt to downgrade to connection without mtasts, continuing with previous policy", err)
			}
			// else, policy will be removed by periodic refresher in the near future.
		} else {
			// Could be a temporary DNS or configuration error.
			log.Errorx("checking for freshness of cached mta-sts dns txt record for domain, continuing with previously cached policy", err)
		}

		// Result: ../rfc/8460:594
		fd := tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, mtasts.TLSReportFailureReason(err))
		if policy.Mode != mtasts.ModeNone {
			fd.FailureReasonCode += "+fallback-to-cached-policy"
		}
		reportResult = tlsrpt.Result{
			Policy:         tlsrptPolicy(policy, cachedPolicy.PolicyText, domain),
			FailureDetails: []tlsrpt.FailureDetails{fd},
		}
		return policy, reportResult, false, nil
	} else if record.ID == cachedPolicy.RecordID && cachedPolicy.PolicyText != "" {
		// In the past, we didn't store the raw policy text in cachedPolicy.PolicyText.
		// We only stop now if we do have the policy text in the cache.
		reportResult = tlsrpt.Result{Policy: tlsrptPolicy(policy, cachedPolicy.PolicyText, domain)}
		return policy, reportResult, true, nil
	}

	// New policy should be available, or we are fetching the policy again because we
	// didn't store the raw policy lines in the past.
	// The earlier deferred cancel was evaluated when deferred, so both contexts
	// are canceled on return.
	nctx, cancel = context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	p, ptext, err := mtasts.FetchPolicy(nctx, domain)
	if err != nil {
		log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)

		fd := policyFetchFailureDetails(err)
		fd.FailureReasonCode += "+fallback-to-cached-policy"
		reportResult = tlsrpt.Result{
			Policy:         tlsrptPolicy(policy, cachedPolicy.PolicyText, domain),
			FailureDetails: []tlsrpt.FailureDetails{fd},
		}
		return policy, reportResult, false, nil
	}
	if err := Upsert(ctx, domain, record.ID, p, ptext); err != nil {
		log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
	}
	reportResult = tlsrpt.Result{Policy: tlsrptPolicy(p, ptext, domain)}
	return p, reportResult, true, nil
}
357
358func policyFetchFailureDetails(err error) tlsrpt.FailureDetails {
359 var verificationErr *tls.CertificateVerificationError
360 if errors.As(err, &verificationErr) {
361 resultType, reasonCode := tlsrpt.TLSFailureDetails(verificationErr)
362 // Result: ../rfc/8460:601
363 reason := string(resultType)
364 if reasonCode != "" {
365 reason += "+" + reasonCode
366 }
367 return tlsrpt.Details(tlsrpt.ResultSTSWebPKIInvalid, reason)
368 } else if errors.Is(err, mtasts.ErrPolicySyntax) {
369 // Result: ../rfc/8460:598
370 return tlsrpt.Details(tlsrpt.ResultSTSPolicyInvalid, mtasts.TLSReportFailureReason(err))
371 }
372 // Result: ../rfc/8460:594
373 return tlsrpt.Details(tlsrpt.ResultSTSPolicyFetch, mtasts.TLSReportFailureReason(err))
374}
375
376func tlsrptPolicy(p *mtasts.Policy, policyText string, domain dns.Domain) tlsrpt.ResultPolicy {
377 if policyText == "" {
378 // We didn't always store original policy lines. Reconstruct.
379 policyText = p.String()
380 }
381 lines := strings.Split(strings.TrimSuffix(policyText, "\n"), "\n")
382 for i, line := range lines {
383 lines[i] = strings.TrimSuffix(line, "\r")
384 }
385
386 rp := tlsrpt.ResultPolicy{
387 Type: tlsrpt.STS,
388 Domain: domain.ASCII,
389 String: lines,
390 }
391 rp.MXHost = make([]string, len(p.MX))
392 for i, mx := range p.MX {
393 s := mx.Domain.ASCII
394 if mx.Wildcard {
395 s = "*." + s
396 }
397 rp.MXHost[i] = s
398 }
399 return rp
400}
401