1package dns
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "slices"
8
9 "github.com/mjl-/adns"
10)
11
12// MockResolver is a Resolver used for testing.
13// Set DNS records in the fields, which map FQDNs (with trailing dot) to values.
14type MockResolver struct {
15 PTR map[string][]string
16 A map[string][]string
17 AAAA map[string][]string
18 TXT map[string][]string
19 MX map[string][]*net.MX
20 TLSA map[string][]adns.TLSA // Keys are e.g. _25._tcp.<host>.
21 CNAME map[string]string
22 Fail []string // Records of the form "type name", e.g. "cname localhost." that will return a servfail.
23 AllAuthentic bool // Default value for authentic in responses. Overridden with Authentic and Inauthentic
24 Authentic []string // Like Fail, but records that cause the response to be authentic.
25 Inauthentic []string // Like Authentic, but making response inauthentic.
26}
27
// mockReq identifies a single mocked lookup: the record type and the name
// that is requested.
type mockReq struct {
	// Type of the request, e.g. "cname", "txt", "mx", "ptr", etc.
	Type string
	// Name of the request. For TLSA, the full requested DNS name, e.g.
	// _25._tcp.<host>.
	Name string
}

// String returns the request as "type name", the form that is matched against
// the Fail, Authentic and Inauthentic lists of a MockResolver.
func (mr mockReq) String() string {
	return fmt.Sprintf("%s %s", mr.Type, mr.Name)
}
36
// Compile-time check that MockResolver implements Resolver.
var _ Resolver = MockResolver{}
38
39func (r MockResolver) result(ctx context.Context, mr mockReq) (string, adns.Result, error) {
40 result := adns.Result{Authentic: r.AllAuthentic}
41
42 if err := ctx.Err(); err != nil {
43 return "", result, err
44 }
45
46 updateAuthentic := func(mock string) {
47 if slices.Contains(r.Authentic, mock) {
48 result.Authentic = true
49 }
50 if slices.Contains(r.Inauthentic, mock) {
51 result.Authentic = false
52 }
53 }
54
55 for {
56 if slices.Contains(r.Fail, mr.String()) {
57 updateAuthentic(mr.String())
58 return mr.Name, adns.Result{}, r.servfail(mr.Name)
59 }
60
61 cname, ok := r.CNAME[mr.Name]
62 if !ok {
63 updateAuthentic(mr.String())
64 break
65 }
66 updateAuthentic("cname " + mr.Name)
67 if mr.Type == "cname" {
68 return mr.Name, result, nil
69 }
70 mr.Name = cname
71 }
72 return mr.Name, result, nil
73}
74
75func (r MockResolver) nxdomain(s string) error {
76 return &adns.DNSError{
77 Err: "no record",
78 Name: s,
79 Server: "mock",
80 IsNotFound: true,
81 }
82}
83
84func (r MockResolver) servfail(s string) error {
85 return &adns.DNSError{
86 Err: "temp error",
87 Name: s,
88 Server: "mock",
89 IsTemporary: true,
90 }
91}
92
93func (r MockResolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
94 if err := ctx.Err(); err != nil {
95 return 0, err
96 }
97 return net.LookupPort(network, service)
98}
99
100func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, adns.Result, error) {
101 mr := mockReq{"cname", name}
102 name, result, err := r.result(ctx, mr)
103 if err != nil {
104 return name, result, err
105 }
106 cname, ok := r.CNAME[name]
107 if !ok {
108 return cname, result, r.nxdomain(name)
109 }
110 return cname, result, nil
111}
112
113func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, adns.Result, error) {
114 mr := mockReq{"ptr", ip}
115 _, result, err := r.result(ctx, mr)
116 if err != nil {
117 return nil, result, err
118 }
119 l, ok := r.PTR[ip]
120 if !ok {
121 return nil, result, r.nxdomain(ip)
122 }
123 return l, result, nil
124}
125
126func (r MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, adns.Result, error) {
127 mr := mockReq{"ns", name}
128 _, result, err := r.result(ctx, mr)
129 if err != nil {
130 return nil, result, err
131 }
132 return nil, result, r.servfail("ns not implemented")
133}
134
135func (r MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, adns.Result, error) {
136 xname := fmt.Sprintf("_%s._%s.%s", service, proto, name)
137 mr := mockReq{"srv", xname}
138 name, result, err := r.result(ctx, mr)
139 if err != nil {
140 return name, nil, result, err
141 }
142 return name, nil, result, r.servfail("srv not implemented")
143}
144
145func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, adns.Result, error) {
146 // todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
147 mr := mockReq{"ipaddr", host}
148 _, result, err := r.result(ctx, mr)
149 if err != nil {
150 return nil, result, err
151 }
152 addrs, result1, err := r.LookupHost(ctx, host)
153 result.Authentic = result.Authentic && result1.Authentic
154 if err != nil {
155 return nil, result, err
156 }
157 ips := make([]net.IPAddr, len(addrs))
158 for i, a := range addrs {
159 ip := net.ParseIP(a)
160 if ip == nil {
161 return nil, result, fmt.Errorf("malformed ip %q", a)
162 }
163 ips[i] = net.IPAddr{IP: ip}
164 }
165 return ips, result, nil
166}
167
168func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, adns.Result, error) {
169 // todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
170 mr := mockReq{"host", host}
171 _, result, err := r.result(ctx, mr)
172 if err != nil {
173 return nil, result, err
174 }
175 var addrs []string
176 addrs = append(addrs, r.A[host]...)
177 addrs = append(addrs, r.AAAA[host]...)
178 if len(addrs) == 0 {
179 return nil, result, r.nxdomain(host)
180 }
181 return addrs, result, nil
182}
183
184func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) {
185 mr := mockReq{"ip", host}
186 name, result, err := r.result(ctx, mr)
187 if err != nil {
188 return nil, result, err
189 }
190 var ips []net.IP
191 switch network {
192 case "ip", "ip4":
193 for _, ip := range r.A[name] {
194 ips = append(ips, net.ParseIP(ip))
195 }
196 }
197 switch network {
198 case "ip", "ip6":
199 for _, ip := range r.AAAA[name] {
200 ips = append(ips, net.ParseIP(ip))
201 }
202 }
203 if len(ips) == 0 {
204 return nil, result, r.nxdomain(host)
205 }
206 return ips, result, nil
207}
208
209func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) {
210 mr := mockReq{"mx", name}
211 name, result, err := r.result(ctx, mr)
212 if err != nil {
213 return nil, result, err
214 }
215 l, ok := r.MX[name]
216 if !ok {
217 return nil, result, r.nxdomain(name)
218 }
219 return l, result, nil
220}
221
222func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) {
223 mr := mockReq{"txt", name}
224 name, result, err := r.result(ctx, mr)
225 if err != nil {
226 return nil, result, err
227 }
228 l, ok := r.TXT[name]
229 if !ok {
230 return nil, result, r.nxdomain(name)
231 }
232 return l, result, nil
233}
234
235func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string, host string) ([]adns.TLSA, adns.Result, error) {
236 var name string
237 if port == 0 && protocol == "" {
238 name = host
239 } else {
240 name = fmt.Sprintf("_%d._%s.%s", port, protocol, host)
241 }
242 mr := mockReq{"tlsa", name}
243 name, result, err := r.result(ctx, mr)
244 if err != nil {
245 return nil, result, err
246 }
247 l, ok := r.TLSA[name]
248 if !ok {
249 return nil, result, r.nxdomain(name)
250 }
251 return l, result, nil
252}
253