1package dns
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net"
9 "os"
10 "runtime"
11 "strings"
12 "time"
13
14 "github.com/mjl-/adns"
15
16 "github.com/mjl-/mox/mlog"
17 "github.com/mjl-/mox/stub"
18)
19
20// todo future: replace with a dnssec capable resolver
// todo future: change to an interface that is closer to DNS: 1. expose nxdomain vs. success with zero entries: nxdomain means the name does not exist for any DNS resource record type, while success with zero records means the name exists, but only for types other than the requested type; 2. add the ability to not follow CNAME records when resolving. The net resolver automatically follows CNAMEs for LookupHost, LookupIP and LookupIPAddr, but when resolving names found in MX records we explicitly must not follow CNAMEs, which seems impossible at the moment; 3. when looking up a CNAME, actually look up the record? "net" LookupCNAME returns the requested name with no error if there is no CNAME record, because it returns the canonical name.
22// todo future: add option to not use anything in the cache, for the admin pages where you check the latest DNS settings, ignoring old cached info.
23
// init switches the standard library's default resolver to strict error
// handling. With StrictErrors set, temporary errors (timeouts, socket errors,
// SERVFAIL) in one sub-query abort the whole query instead of producing a
// partial result.
func init() {
	stdResolver := net.DefaultResolver
	stdResolver.StrictErrors = true
}
27
28var (
29 MetricLookup stub.HistogramVec = stub.HistogramVecIgnore{}
30)
31
// Resolver is the set of DNS lookup methods implemented by StrictResolver.
//
// Apart from LookupPort, each method also returns an adns.Result, whose
// Authentic field StrictResolver includes in its debug logging. The
// StrictResolver implementation rejects names that do not end with a dot with
// ErrRelativeDNSName for every method except LookupPort and LookupAddr.
type Resolver interface {
	LookupPort(ctx context.Context, network, service string) (port int, err error)
	LookupAddr(ctx context.Context, addr string) ([]string, adns.Result, error) // Always returns absolute names, with trailing dot.
	LookupCNAME(ctx context.Context, host string) (string, adns.Result, error) // NOTE: returns an error if no CNAME record is present.
	LookupHost(ctx context.Context, host string) ([]string, adns.Result, error)
	LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error)
	LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, adns.Result, error)
	LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error)
	LookupNS(ctx context.Context, name string) ([]*net.NS, adns.Result, error)
	// The first string result presumably mirrors net.Resolver.LookupSRV (the canonical name) — TODO confirm.
	LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, adns.Result, error)
	LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error)
	LookupTLSA(ctx context.Context, port int, protocol, host string) ([]adns.TLSA, adns.Result, error)
}
46
47// WithPackage sets Pkg on resolver if it is a StrictResolve and does not have a package set yet.
48func WithPackage(resolver Resolver, name string) Resolver {
49 r, ok := resolver.(StrictResolver)
50 if ok && r.Pkg == "" {
51 nr := r
52 r.Pkg = name
53 return nr
54 }
55 return resolver
56}
57
// StrictResolver wraps an adns.Resolver and enforces that DNS names end with a
// dot, preventing "search"-relative lookups. Names without a trailing dot are
// rejected with ErrRelativeDNSName before any lookup is done.
//
// Every lookup is also observed in MetricLookup and logged at debug level.
// The zero value is usable: lookups then go to adns.DefaultResolver, and logging
// uses the "dns" name. Note that metrics use Pkg as-is, so an empty Pkg gives an
// empty metric label.
type StrictResolver struct {
	Pkg      string         // Name of subsystem that is making DNS requests, for metrics.
	Resolver *adns.Resolver // Where the actual lookups are done. If nil, adns.DefaultResolver is used for lookups.
	Log      *slog.Logger   // Passed to mlog.New for lookup logging; behavior when nil depends on mlog — NOTE(review): confirm.
}
65
66func (r StrictResolver) log() mlog.Log {
67 pkg := r.Pkg
68 if pkg == "" {
69 pkg = "dns"
70 }
71 return mlog.New(pkg, r.Log)
72}
73
74var _ Resolver = StrictResolver{}
75
76var ErrRelativeDNSName = errors.New("dns: host to lookup must be absolute, ending with a dot")
77
78func metricLookupObserve(pkg, typ string, err error, start time.Time) {
79 var result string
80 var dnsErr *adns.DNSError
81 switch {
82 case err == nil:
83 result = "ok"
84 case errors.As(err, &dnsErr) && dnsErr.IsNotFound:
85 result = "nxdomain"
86 case errors.As(err, &dnsErr) && dnsErr.IsTemporary:
87 result = "temporary"
88 case errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, context.DeadlineExceeded) || errors.As(err, &dnsErr) && dnsErr.IsTimeout:
89 result = "timeout"
90 case errors.Is(err, context.Canceled):
91 result = "canceled"
92 default:
93 result = "error"
94 }
95 MetricLookup.ObserveLabels(float64(time.Since(start))/float64(time.Second), pkg, typ, result)
96}
97
98func (r StrictResolver) WithPackage(name string) Resolver {
99 nr := r
100 nr.Pkg = name
101 return nr
102}
103
104func (r StrictResolver) resolver() Resolver {
105 if r.Resolver == nil {
106 return adns.DefaultResolver
107 }
108 return r.Resolver
109}
110
111func resolveErrorHint(err *error) {
112 e := *err
113 if e == nil {
114 return
115 }
116 dnserr, ok := e.(*adns.DNSError)
117 if !ok {
118 return
119 }
120 // If the dns server is not running, and it is one of the default/fallback IPs,
121 // hint at where to look.
122 if dnserr.IsTemporary && runtime.GOOS == "linux" && (dnserr.Server == "127.0.0.1:53" || dnserr.Server == "[::1]:53") && strings.HasSuffix(dnserr.Err, "connection refused") {
123 *err = fmt.Errorf("%w (hint: does /etc/resolv.conf point to a running nameserver? in case of systemd-resolved, see systemd-resolved.service(8); better yet, install a proper dnssec-verifying recursive resolver like unbound)", *err)
124 }
125}
126
127func (r StrictResolver) LookupPort(ctx context.Context, network, service string) (resp int, err error) {
128 start := time.Now()
129 defer func() {
130 metricLookupObserve(r.Pkg, "port", err, start)
131 r.log().WithContext(ctx).Debugx("dns lookup result", err,
132 slog.String("type", "port"),
133 slog.String("network", network),
134 slog.String("service", service),
135 slog.Int("resp", resp),
136 slog.Duration("duration", time.Since(start)),
137 )
138 }()
139 defer resolveErrorHint(&err)
140
141 resp, err = r.resolver().LookupPort(ctx, network, service)
142 return
143}
144
145func (r StrictResolver) LookupAddr(ctx context.Context, addr string) (resp []string, result adns.Result, err error) {
146 start := time.Now()
147 defer func() {
148 metricLookupObserve(r.Pkg, "addr", err, start)
149 r.log().WithContext(ctx).Debugx("dns lookup result", err,
150 slog.String("type", "addr"),
151 slog.String("addr", addr),
152 slog.Any("resp", resp),
153 slog.Bool("authentic", result.Authentic),
154 slog.Duration("duration", time.Since(start)),
155 )
156 }()
157 defer resolveErrorHint(&err)
158
159 resp, result, err = r.resolver().LookupAddr(ctx, addr)
160 // For addresses from /etc/hosts without dot, we add the missing trailing dot.
161 for i, s := range resp {
162 if !strings.HasSuffix(s, ".") {
163 resp[i] = s + "."
164 }
165 }
166 return
167}
168
169// LookupCNAME looks up a CNAME. Unlike "net" LookupCNAME, it returns a "not found"
170// error if there is no CNAME record.
171func (r StrictResolver) LookupCNAME(ctx context.Context, host string) (resp string, result adns.Result, err error) {
172 start := time.Now()
173 defer func() {
174 metricLookupObserve(r.Pkg, "cname", err, start)
175 r.log().WithContext(ctx).Debugx("dns lookup result", err,
176 slog.String("type", "cname"),
177 slog.String("host", host),
178 slog.String("resp", resp),
179 slog.Bool("authentic", result.Authentic),
180 slog.Duration("duration", time.Since(start)),
181 )
182 }()
183 defer resolveErrorHint(&err)
184
185 if !strings.HasSuffix(host, ".") {
186 return "", result, ErrRelativeDNSName
187 }
188 resp, result, err = r.resolver().LookupCNAME(ctx, host)
189 if err == nil && resp == host {
190 return "", result, &adns.DNSError{
191 Err: "no cname record",
192 Name: host,
193 Server: "",
194 IsNotFound: true,
195 }
196 }
197 return
198}
199
200func (r StrictResolver) LookupHost(ctx context.Context, host string) (resp []string, result adns.Result, err error) {
201 start := time.Now()
202 defer func() {
203 metricLookupObserve(r.Pkg, "host", err, start)
204 r.log().WithContext(ctx).Debugx("dns lookup result", err,
205 slog.String("type", "host"),
206 slog.String("host", host),
207 slog.Any("resp", resp),
208 slog.Bool("authentic", result.Authentic),
209 slog.Duration("duration", time.Since(start)),
210 )
211 }()
212 defer resolveErrorHint(&err)
213
214 if !strings.HasSuffix(host, ".") {
215 return nil, result, ErrRelativeDNSName
216 }
217 resp, result, err = r.resolver().LookupHost(ctx, host)
218 return
219}
220
221func (r StrictResolver) LookupIP(ctx context.Context, network, host string) (resp []net.IP, result adns.Result, err error) {
222 start := time.Now()
223 defer func() {
224 metricLookupObserve(r.Pkg, "ip", err, start)
225 r.log().WithContext(ctx).Debugx("dns lookup result", err,
226 slog.String("type", "ip"),
227 slog.String("network", network),
228 slog.String("host", host),
229 slog.Any("resp", resp),
230 slog.Bool("authentic", result.Authentic),
231 slog.Duration("duration", time.Since(start)),
232 )
233 }()
234 defer resolveErrorHint(&err)
235
236 if !strings.HasSuffix(host, ".") {
237 return nil, result, ErrRelativeDNSName
238 }
239 resp, result, err = r.resolver().LookupIP(ctx, network, host)
240 return
241}
242
243func (r StrictResolver) LookupIPAddr(ctx context.Context, host string) (resp []net.IPAddr, result adns.Result, err error) {
244 start := time.Now()
245 defer func() {
246 metricLookupObserve(r.Pkg, "ipaddr", err, start)
247 r.log().WithContext(ctx).Debugx("dns lookup result", err,
248 slog.String("type", "ipaddr"),
249 slog.String("host", host),
250 slog.Any("resp", resp),
251 slog.Bool("authentic", result.Authentic),
252 slog.Duration("duration", time.Since(start)),
253 )
254 }()
255 defer resolveErrorHint(&err)
256
257 if !strings.HasSuffix(host, ".") {
258 return nil, result, ErrRelativeDNSName
259 }
260 resp, result, err = r.resolver().LookupIPAddr(ctx, host)
261 return
262}
263
264func (r StrictResolver) LookupMX(ctx context.Context, name string) (resp []*net.MX, result adns.Result, err error) {
265 start := time.Now()
266 defer func() {
267 metricLookupObserve(r.Pkg, "mx", err, start)
268 r.log().WithContext(ctx).Debugx("dns lookup result", err,
269 slog.String("type", "mx"),
270 slog.String("name", name),
271 slog.Any("resp", resp),
272 slog.Bool("authentic", result.Authentic),
273 slog.Duration("duration", time.Since(start)),
274 )
275 }()
276 defer resolveErrorHint(&err)
277
278 if !strings.HasSuffix(name, ".") {
279 return nil, result, ErrRelativeDNSName
280 }
281 resp, result, err = r.resolver().LookupMX(ctx, name)
282 return
283}
284
285func (r StrictResolver) LookupNS(ctx context.Context, name string) (resp []*net.NS, result adns.Result, err error) {
286 start := time.Now()
287 defer func() {
288 metricLookupObserve(r.Pkg, "ns", err, start)
289 r.log().WithContext(ctx).Debugx("dns lookup result", err,
290 slog.String("type", "ns"),
291 slog.String("name", name),
292 slog.Any("resp", resp),
293 slog.Bool("authentic", result.Authentic),
294 slog.Duration("duration", time.Since(start)),
295 )
296 }()
297 defer resolveErrorHint(&err)
298
299 if !strings.HasSuffix(name, ".") {
300 return nil, result, ErrRelativeDNSName
301 }
302 resp, result, err = r.resolver().LookupNS(ctx, name)
303 return
304}
305
306func (r StrictResolver) LookupSRV(ctx context.Context, service, proto, name string) (resp0 string, resp1 []*net.SRV, result adns.Result, err error) {
307 start := time.Now()
308 defer func() {
309 metricLookupObserve(r.Pkg, "srv", err, start)
310 r.log().WithContext(ctx).Debugx("dns lookup result", err,
311 slog.String("type", "srv"),
312 slog.String("service", service),
313 slog.String("proto", proto),
314 slog.String("name", name),
315 slog.String("resp0", resp0),
316 slog.Any("resp1", resp1),
317 slog.Bool("authentic", result.Authentic),
318 slog.Duration("duration", time.Since(start)),
319 )
320 }()
321 defer resolveErrorHint(&err)
322
323 if !strings.HasSuffix(name, ".") {
324 return "", nil, result, ErrRelativeDNSName
325 }
326 resp0, resp1, result, err = r.resolver().LookupSRV(ctx, service, proto, name)
327 return
328}
329
330func (r StrictResolver) LookupTXT(ctx context.Context, name string) (resp []string, result adns.Result, err error) {
331 start := time.Now()
332 defer func() {
333 metricLookupObserve(r.Pkg, "txt", err, start)
334 r.log().WithContext(ctx).Debugx("dns lookup result", err,
335 slog.String("type", "txt"),
336 slog.String("name", name),
337 slog.Any("resp", resp),
338 slog.Bool("authentic", result.Authentic),
339 slog.Duration("duration", time.Since(start)),
340 )
341 }()
342 defer resolveErrorHint(&err)
343
344 if !strings.HasSuffix(name, ".") {
345 return nil, result, ErrRelativeDNSName
346 }
347 resp, result, err = r.resolver().LookupTXT(ctx, name)
348 return
349}
350
351func (r StrictResolver) LookupTLSA(ctx context.Context, port int, protocol, host string) (resp []adns.TLSA, result adns.Result, err error) {
352 start := time.Now()
353 defer func() {
354 metricLookupObserve(r.Pkg, "tlsa", err, start)
355 r.log().WithContext(ctx).Debugx("dns lookup result", err,
356 slog.String("type", "tlsa"),
357 slog.Int("port", port),
358 slog.String("protocol", protocol),
359 slog.String("host", host),
360 slog.Any("resp", resp),
361 slog.Bool("authentic", result.Authentic),
362 slog.Duration("duration", time.Since(start)),
363 )
364 }()
365 defer resolveErrorHint(&err)
366
367 if !strings.HasSuffix(host, ".") {
368 return nil, result, ErrRelativeDNSName
369 }
370 resp, result, err = r.resolver().LookupTLSA(ctx, port, protocol, host)
371 return
372}
373