12// MockResolver is a Resolver used for testing.
13// Set DNS records in the fields, which map FQDNs (with trailing dot) to values.
14type MockResolver struct {
15 PTR map[string][]string
17 AAAA map[string][]string
18 TXT map[string][]string
19 MX map[string][]*net.MX
20 TLSA map[string][]adns.TLSA // Keys are e.g. _25._tcp.<host>.
21 CNAME map[string]string
22 Fail []string // Records of the form "type name", e.g. "cname localhost." that will return a servfail.
23 AllAuthentic bool // Default value for authentic in responses. Overridden with Authentic and Inauthentic
24 Authentic []string // Like Fail, but records that cause the response to be authentic.
25 Inauthentic []string // Like Authentic, but making response inauthentic.
// mockReq identifies a single lookup made against a MockResolver. Its String
// form, "type name", is what the Fail, Authentic and Inauthentic lists are
// matched against.
type mockReq struct {
	// Type is the kind of lookup, e.g. "cname", "txt", "mx" or "ptr".
	Type string

	// Name is the queried name. For TLSA it is the full DNS name, such as
	// _25._tcp.<host>.
	Name string
}

// String renders the request as its type and name separated by one space.
func (req mockReq) String() string {
	const sep = " "
	return req.Type + sep + req.Name
}
37var _ Resolver = MockResolver{}
39func (r MockResolver) result(ctx context.Context, mr mockReq) (string, adns.Result, error) {
40 result := adns.Result{Authentic: r.AllAuthentic}
42 if err := ctx.Err(); err != nil {
43 return "", result, err
46 updateAuthentic := func(mock string) {
47 if slices.Contains(r.Authentic, mock) {
48 result.Authentic = true
50 if slices.Contains(r.Inauthentic, mock) {
51 result.Authentic = false
56 if slices.Contains(r.Fail, mr.String()) {
57 updateAuthentic(mr.String())
58 return mr.Name, adns.Result{}, r.servfail(mr.Name)
61 cname, ok := r.CNAME[mr.Name]
63 updateAuthentic(mr.String())
66 updateAuthentic("cname " + mr.Name)
67 if mr.Type == "cname" {
68 return mr.Name, result, nil
72 return mr.Name, result, nil
// nxdomain returns the *adns.DNSError that the Lookup* methods use to report
// that no record exists for s.
// NOTE(review): the error's fields besides the name are not shown here;
// presumably they mark it as a not-found error — confirm.
75func (r MockResolver) nxdomain(s string) error {
76 return &adns.DNSError{
// servfail returns the *adns.DNSError used for requests listed in
// MockResolver.Fail and for lookup types the mock does not implement (NS,
// SRV). s is either the failing name or a short description.
// NOTE(review): the remaining fields are not shown here; presumably they mark
// the error as temporary — confirm.
84func (r MockResolver) servfail(s string) error {
85 return &adns.DNSError{
// LookupPort resolves a service name such as "smtp" to a port number for the
// given network. Unlike the other lookups it does not consult the mock
// records; it delegates to the system resolver. A ctx that is already done
// returns ctx.Err() immediately.
func (r MockResolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
	if err := ctx.Err(); err != nil {
		return 0, err
	}
	// Use the context-aware method so cancellation during the lookup is
	// honored too; net.LookupPort always uses context.Background internally.
	return net.DefaultResolver.LookupPort(ctx, network, service)
}
// LookupCNAME returns the alias target configured in r.CNAME for name. If name
// has no CNAME entry, the target is empty and the error is nxdomain. Errors
// from the shared request handling (ctx, Fail list) are passed through with
// the name that request handling returned.
func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, adns.Result, error) {
	resolved, result, err := r.result(ctx, mockReq{Type: "cname", Name: name})
	if err != nil {
		return resolved, result, err
	}
	if target, found := r.CNAME[resolved]; found {
		return target, result, nil
	}
	return "", result, r.nxdomain(resolved)
}
// LookupAddr returns the PTR names configured for ip. The name returned by the
// shared request handling is not used; r.PTR is indexed by ip directly. A
// missing entry yields an nxdomain error for ip.
func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, adns.Result, error) {
	_, result, err := r.result(ctx, mockReq{Type: "ptr", Name: ip})
	switch names, found := r.PTR[ip]; {
	case err != nil:
		return nil, result, err
	case !found:
		return nil, result, r.nxdomain(ip)
	default:
		return names, result, nil
	}
}
// LookupNS is not supported by the mock. If the shared request handling
// succeeds, it always returns a servfail error and no records.
func (r MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, adns.Result, error) {
	_, result, err := r.result(ctx, mockReq{Type: "ns", Name: name})
	if err == nil {
		err = r.servfail("ns not implemented")
	}
	return nil, result, err
}
// LookupSRV is not supported by the mock. The request name is built as
// _<service>._<proto>.<name> and run through the shared request handling. The
// resulting name is returned together with a servfail error (or the error
// from the shared handling) and no records.
func (r MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, adns.Result, error) {
	srvName := fmt.Sprintf("_%s._%s.%s", service, proto, name)
	resolved, result, err := r.result(ctx, mockReq{Type: "srv", Name: srvName})
	if err == nil {
		err = r.servfail("srv not implemented")
	}
	return resolved, nil, result, err
}
// LookupIPAddr returns host's addresses, obtained via LookupHost, parsed into
// net.IPAddr values. The result is authentic only if both the "ipaddr" request
// and the underlying LookupHost result are authentic. An address that does not
// parse yields a "malformed ip" error.
func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, adns.Result, error) {
	// todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
	_, result, err := r.result(ctx, mockReq{Type: "ipaddr", Name: host})
	if err != nil {
		return nil, result, err
	}

	hostAddrs, hostResult, err := r.LookupHost(ctx, host)
	// Combine authenticity before the error check so it is reflected on the
	// error return as well.
	result.Authentic = result.Authentic && hostResult.Authentic
	if err != nil {
		return nil, result, err
	}

	ipAddrs := make([]net.IPAddr, 0, len(hostAddrs))
	for _, text := range hostAddrs {
		parsed := net.ParseIP(text)
		if parsed == nil {
			return nil, result, fmt.Errorf("malformed ip %q", text)
		}
		ipAddrs = append(ipAddrs, net.IPAddr{IP: parsed})
	}
	return ipAddrs, result, nil
}
168func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, adns.Result, error) {
169 // todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
170 mr := mockReq{"host", host}
171 _, result, err := r.result(ctx, mr)
173 return nil, result, err
176 addrs = append(addrs, r.A[host]...)
177 addrs = append(addrs, r.AAAA[host]...)
179 return nil, result, r.nxdomain(host)
181 return addrs, result, nil
184func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) {
185 mr := mockReq{"ip", host}
186 name, result, err := r.result(ctx, mr)
188 return nil, result, err
193 for _, ip := range r.A[name] {
194 ips = append(ips, net.ParseIP(ip))
199 for _, ip := range r.AAAA[name] {
200 ips = append(ips, net.ParseIP(ip))
204 return nil, result, r.nxdomain(host)
206 return ips, result, nil
// LookupMX returns the MX records for name after following CNAME aliases. A
// missing entry yields an nxdomain error for the resolved name.
func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) {
	target, result, err := r.result(ctx, mockReq{Type: "mx", Name: name})
	if err != nil {
		return nil, result, err
	}
	if records, found := r.MX[target]; found {
		return records, result, nil
	}
	return nil, result, r.nxdomain(target)
}
// LookupTXT returns the TXT strings for name after following CNAME aliases. A
// missing entry yields an nxdomain error for the resolved name.
func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) {
	target, result, err := r.result(ctx, mockReq{Type: "txt", Name: name})
	if err != nil {
		return nil, result, err
	}
	if texts, found := r.TXT[target]; found {
		return texts, result, nil
	}
	return nil, result, r.nxdomain(target)
}
// LookupTLSA returns the TLSA records for a service on host. With port 0 and
// an empty protocol, host is taken as the full TLSA name. Otherwise the name
// is built as _<port>._<protocol>.<host>. CNAME aliases are followed before
// r.TLSA is consulted, and a missing entry yields an nxdomain error for the
// resolved name.
func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string, host string) ([]adns.TLSA, adns.Result, error) {
	name := host
	if port != 0 || protocol != "" {
		name = fmt.Sprintf("_%d._%s.%s", port, protocol, host)
	}
	resolved, result, err := r.result(ctx, mockReq{Type: "tlsa", Name: name})
	if err != nil {
		return nil, result, err
	}
	records, found := r.TLSA[resolved]
	if !found {
		return nil, result, r.nxdomain(resolved)
	}
	return records, result, nil
}