1package smtpclient
2
3import (
4 "context"
5 "crypto/sha256"
6 "errors"
7 "fmt"
8 "net"
9 "reflect"
10 "testing"
11
12 "github.com/mjl-/adns"
13
14 "github.com/mjl-/mox/dns"
15 "github.com/mjl-/mox/mlog"
16)
17
18func domain(s string) dns.Domain {
19 d, err := dns.ParseDomain(s)
20 if err != nil {
21 panic("parse domain: " + err.Error())
22 }
23 return d
24}
25
26func ipdomain(s string) dns.IPDomain {
27 ip := net.ParseIP(s)
28 if ip != nil {
29 return dns.IPDomain{IP: ip}
30 }
31 d, err := dns.ParseDomain(s)
32 if err != nil {
33 panic(fmt.Sprintf("parse domain %q: %v", s, err))
34 }
35 return dns.IPDomain{Domain: d}
36}
37
38func ipdomains(s ...string) (l []dns.IPDomain) {
39 for _, e := range s {
40 l = append(l, ipdomain(e))
41 }
42 return
43}
44
45// Test basic MX lookup case, but also following CNAME, detecting CNAME loops and
46// having a CNAME limit, connecting directly to a host, and domain that does not
47// exist or has temporary error.
48func TestGatherDestinations(t *testing.T) {
49 ctxbg := context.Background()
50 log := mlog.New("smtpclient", nil)
51
52 resolver := dns.MockResolver{
53 MX: map[string][]*net.MX{
54 "basic.example.": {{Host: "mail.basic.example.", Pref: 10}},
55 "multimx.example.": {{Host: "mail1.multimx.example.", Pref: 10}, {Host: "mail2.multimx.example.", Pref: 10}},
56 "nullmx.example.": {{Host: ".", Pref: 10}},
57 "temperror-mx.example.": {{Host: "absent.example.", Pref: 10}},
58 },
59 A: map[string][]string{
60 "mail.basic.example": {"10.0.0.1"},
61 "justhost.example.": {"10.0.0.1"}, // No MX record for domain, only an A record.
62 "temperror-a.example.": {"10.0.0.1"},
63 },
64 AAAA: map[string][]string{
65 "justhost6.example.": {"2001:db8::1"}, // No MX record for domain, only an AAAA record.
66 },
67 CNAME: map[string]string{
68 "cname.example.": "basic.example.",
69 "cname-to-inauthentic.example.": "cnameinauthentic.example.",
70 "cnameinauthentic.example.": "basic.example.",
71 "cnameloop.example.": "cnameloop2.example.",
72 "cnameloop2.example.": "cnameloop.example.",
73 "danglingcname.example.": "absent.example.", // Points to missing name.
74 "temperror-cname.example.": "absent.example.",
75 },
76 Fail: []string{
77 "mx temperror-mx.example.",
78 "host temperror-a.example.",
79 "cname temperror-cname.example.",
80 },
81 Inauthentic: []string{"cname cnameinauthentic.example."},
82 }
83 for i := 0; i <= 16; i++ {
84 s := fmt.Sprintf("cnamelimit%d.example.", i)
85 next := fmt.Sprintf("cnamelimit%d.example.", i+1)
86 resolver.CNAME[s] = next
87 }
88
89 test := func(ipd dns.IPDomain, expHosts []dns.IPDomain, expDomain dns.Domain, expPerm, expAuthic, expExpAuthic bool, expErr error) {
90 t.Helper()
91
92 _, authic, authicExp, ed, hosts, perm, err := GatherDestinations(ctxbg, log.Logger, resolver, ipd)
93 if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
94 // todo: could also check the individual errors? code currently does not have structured errors.
95 t.Fatalf("gather hosts: %v, expected %v", err, expErr)
96 }
97 if err != nil {
98 return
99 }
100 if !reflect.DeepEqual(hosts, expHosts) || ed != expDomain || perm != expPerm || authic != expAuthic || authicExp != expExpAuthic {
101 t.Fatalf("got hosts %#v, effectiveDomain %#v, permanent %#v, authic %v %v, expected %#v %#v %#v %v %v", hosts, ed, perm, authic, authicExp, expHosts, expDomain, expPerm, expAuthic, expExpAuthic)
102 }
103 }
104
105 var zerodom dns.Domain
106
107 for i := 0; i < 2; i++ {
108 authic := i == 1
109 resolver.AllAuthentic = authic
110 // Basic with simple MX.
111 test(ipdomain("basic.example"), ipdomains("mail.basic.example"), domain("basic.example"), false, authic, authic, nil)
112 test(ipdomain("multimx.example"), ipdomains("mail1.multimx.example", "mail2.multimx.example"), domain("multimx.example"), false, authic, authic, nil)
113 // Only an A record.
114 test(ipdomain("justhost.example"), ipdomains("justhost.example"), domain("justhost.example"), false, authic, authic, nil)
115 // Only an AAAA record.
116 test(ipdomain("justhost6.example"), ipdomains("justhost6.example"), domain("justhost6.example"), false, authic, authic, nil)
117 // Follow CNAME.
118 test(ipdomain("cname.example"), ipdomains("mail.basic.example"), domain("basic.example"), false, authic, authic, nil)
119 // No MX/CNAME, non-existence of host will be found out later.
120 test(ipdomain("absent.example"), ipdomains("absent.example"), domain("absent.example"), false, authic, authic, nil)
121 // Followed CNAME, has no MX, non-existence of host will be found out later.
122 test(ipdomain("danglingcname.example"), ipdomains("absent.example"), domain("absent.example"), false, authic, authic, nil)
123 test(ipdomain("cnamelimit1.example"), nil, zerodom, true, authic, authic, errCNAMELimit)
124 test(ipdomain("cnameloop.example"), nil, zerodom, true, authic, authic, errCNAMELoop)
125 test(ipdomain("nullmx.example"), nil, zerodom, true, authic, authic, errNoMail)
126 test(ipdomain("temperror-mx.example"), nil, zerodom, false, authic, authic, errDNS)
127 test(ipdomain("temperror-cname.example"), nil, zerodom, false, authic, authic, errDNS)
128 }
129
130 test(ipdomain("10.0.0.1"), ipdomains("10.0.0.1"), zerodom, false, false, false, nil)
131 test(ipdomain("cnameinauthentic.example"), ipdomains("mail.basic.example"), domain("basic.example"), false, false, false, nil)
132 test(ipdomain("cname-to-inauthentic.example"), ipdomains("mail.basic.example"), domain("basic.example"), false, true, false, nil)
133}
134
135func TestGatherIPs(t *testing.T) {
136 ctxbg := context.Background()
137 log := mlog.New("smtpclient", nil)
138
139 resolver := dns.MockResolver{
140 A: map[string][]string{
141 "host1.example.": {"10.0.0.1"},
142 "host2.example.": {"10.0.0.2"},
143 "temperror-a.example.": {"10.0.0.3"},
144 },
145 AAAA: map[string][]string{
146 "host2.example.": {"2001:db8::1"},
147 },
148 CNAME: map[string]string{
149 "cname1.example.": "host1.example.",
150 "cname-to-inauthentic.example.": "cnameinauthentic.example.",
151 "cnameinauthentic.example.": "host1.example.",
152 "cnameloop.example.": "cnameloop2.example.",
153 "cnameloop2.example.": "cnameloop.example.",
154 "danglingcname.example.": "absent.example.", // Points to missing name.
155 "temperror-cname.example.": "absent.example.",
156 },
157 Fail: []string{
158 "ip temperror-a.example.",
159 "cname temperror-cname.example.",
160 },
161 Inauthentic: []string{"cname cnameinauthentic.example."},
162 }
163
164 test := func(host dns.IPDomain, expAuthic, expAuthicExp bool, expHostExp dns.Domain, expIPs []net.IP, expErr any, network string) {
165 t.Helper()
166
167 authic, authicExp, hostExp, ips, _, err := GatherIPs(ctxbg, log.Logger, resolver, network, host, nil)
168 if (err == nil) != (expErr == nil) || err != nil && !(errors.Is(err, expErr.(error)) || errors.As(err, &expErr)) {
169 // todo: could also check the individual errors?
170 t.Fatalf("gather hosts: %v, expected %v", err, expErr)
171 }
172 if err != nil {
173 return
174 }
175 if expHostExp == zerohost {
176 expHostExp = host.Domain
177 }
178 if authic != expAuthic || authicExp != expAuthicExp || hostExp != expHostExp || !reflect.DeepEqual(ips, expIPs) {
179 t.Fatalf("got authic %v %v, host %v, ips %v, expected %v %v %v %v", authic, authicExp, hostExp, ips, expAuthic, expAuthicExp, expHostExp, expIPs)
180 }
181 }
182
183 ips := func(l ...string) (r []net.IP) {
184 for _, s := range l {
185 r = append(r, net.ParseIP(s))
186 }
187 return r
188 }
189
190 for i := 0; i < 2; i++ {
191 authic := i == 1
192 resolver.AllAuthentic = authic
193
194 test(ipdomain("host1.example"), authic, authic, zerohost, ips("10.0.0.1"), nil, "ip")
195 test(ipdomain("host1.example"), authic, authic, zerohost, ips("10.0.0.1"), nil, "ip4")
196 test(ipdomain("host1.example"), authic, authic, zerohost, nil, &adns.DNSError{}, "ip6")
197 test(ipdomain("host2.example"), authic, authic, zerohost, ips("10.0.0.2", "2001:db8::1"), nil, "ip")
198 test(ipdomain("host2.example"), authic, authic, zerohost, ips("10.0.0.2"), nil, "ip4")
199 test(ipdomain("host2.example"), authic, authic, zerohost, ips("2001:db8::1"), nil, "ip6")
200 test(ipdomain("cname-to-inauthentic.example"), authic, false, domain("host1.example"), ips("10.0.0.1"), nil, "ip")
201 test(ipdomain("cnameloop.example"), authic, authic, zerohost, nil, errCNAMELimit, "ip")
202 test(ipdomain("bogus.example"), authic, authic, zerohost, nil, &adns.DNSError{}, "ip")
203 test(ipdomain("danglingcname.example"), authic, authic, zerohost, nil, &adns.DNSError{}, "ip")
204 test(ipdomain("temperror-a.example"), authic, authic, zerohost, nil, &adns.DNSError{}, "ip")
205 test(ipdomain("temperror-cname.example"), authic, authic, zerohost, nil, &adns.DNSError{}, "ip")
206
207 }
208 test(ipdomain("cnameinauthentic.example"), false, false, domain("host1.example"), ips("10.0.0.1"), nil, "ip")
209 test(ipdomain("cname-to-inauthentic.example"), true, false, domain("host1.example"), ips("10.0.0.1"), nil, "ip")
210}
211
212func TestGatherTLSA(t *testing.T) {
213 ctxbg := context.Background()
214 log := mlog.New("smtpclient", nil)
215
216 record := func(usage, selector, matchType uint8) adns.TLSA {
217 return adns.TLSA{
218 Usage: adns.TLSAUsage(usage),
219 Selector: adns.TLSASelector(selector),
220 MatchType: adns.TLSAMatchType(matchType),
221 CertAssoc: make([]byte, sha256.Size), // Assume sha256.
222 }
223 }
224 records := func(l ...adns.TLSA) []adns.TLSA {
225 return l
226 }
227
228 record0 := record(3, 1, 1)
229 list0 := records(record0)
230 record1 := record(3, 0, 1)
231 list1 := records(record1)
232
233 resolver := dns.MockResolver{
234 TLSA: map[string][]adns.TLSA{
235 "_25._tcp.host0.example.": list0,
236 "_25._tcp.host1.example.": list1,
237 "_25._tcp.inauthentic.example.": list1,
238 "_25._tcp.temperror-cname.example.": list1,
239 },
240 CNAME: map[string]string{
241 "_25._tcp.cname.example.": "_25._tcp.host1.example.",
242 "_25._tcp.cnameloop.example.": "_25._tcp.cnameloop2.example.",
243 "_25._tcp.cnameloop2.example.": "_25._tcp.cnameloop.example.",
244 "_25._tcp.cname-to-inauthentic.example.": "_25._tcp.cnameinauthentic.example.",
245 "_25._tcp.cnameinauthentic.example.": "_25._tcp.host1.example.",
246 "_25._tcp.danglingcname.example.": "_25._tcp.absent.example.", // Points to missing name.
247 },
248 Fail: []string{
249 "cname _25._tcp.temperror-cname.example.",
250 },
251 Inauthentic: []string{
252 "cname _25._tcp.cnameinauthentic.example.",
253 "tlsa _25._tcp.inauthentic.example.",
254 },
255 }
256
257 test := func(host dns.Domain, expandedAuthentic bool, expandedHost dns.Domain, expDANERequired bool, expRecords []adns.TLSA, expBaseDom dns.Domain, expErr any) {
258 t.Helper()
259
260 daneReq, records, baseDom, err := GatherTLSA(ctxbg, log.Logger, resolver, host, expandedAuthentic, expandedHost)
261 if (err == nil) != (expErr == nil) || err != nil && !(errors.Is(err, expErr.(error)) || errors.As(err, &expErr)) {
262 // todo: could also check the individual errors?
263 t.Fatalf("gather tlsa: %v, expected %v", err, expErr)
264 }
265 if daneReq != expDANERequired {
266 t.Fatalf("got daneRequired %v, expected %v", daneReq, expDANERequired)
267 }
268 if err != nil {
269 return
270 }
271 if !reflect.DeepEqual(records, expRecords) || baseDom != expBaseDom {
272 t.Fatalf("got records, baseDomain %v %v, expected %v %v", records, baseDom, expRecords, expBaseDom)
273 }
274 }
275
276 resolver.AllAuthentic = true
277 test(domain("host1.example"), false, domain("host1.example"), true, list1, domain("host1.example"), nil)
278 test(domain("host1.example"), true, domain("host1.example"), true, list1, domain("host1.example"), nil)
279 test(domain("host0.example"), true, domain("host1.example"), true, list1, domain("host1.example"), nil)
280 test(domain("host0.example"), false, domain("host1.example"), true, list0, domain("host0.example"), nil)
281
282 // CNAME for TLSA at cname.example should be followed.
283 test(domain("host0.example"), true, domain("cname.example"), true, list1, domain("cname.example"), nil)
284 // TLSA records at original domain should be followed.
285 test(domain("host0.example"), false, domain("cname.example"), true, list0, domain("host0.example"), nil)
286
287 test(domain("cnameloop.example"), false, domain("cnameloop.example"), true, nil, zerohost, errCNAMELimit)
288
289 test(domain("host0.example"), false, domain("inauthentic.example"), true, list0, domain("host0.example"), nil)
290 test(domain("inauthentic.example"), false, domain("inauthentic.example"), false, nil, domain("inauthentic.example"), nil)
291 test(domain("temperror-cname.example"), false, domain("temperror-cname.example"), true, nil, domain("temperror-cname.example"), &adns.DNSError{})
292
293 test(domain("host1.example"), true, domain("cname-to-inauthentic.example"), true, list1, domain("host1.example"), nil)
294 test(domain("host1.example"), true, domain("danglingcname.example"), true, list1, domain("host1.example"), nil)
295 test(domain("danglingcname.example"), true, domain("danglingcname.example"), false, nil, domain("danglingcname.example"), nil)
296}
297
298func TestGatherTLSANames(t *testing.T) {
299 a, b, c, d := domain("nexthop.example"), domain("nexthopexpanded.example"), domain("base.example"), domain("baseexpanded.example")
300 test := func(haveMX, nexthopExpAuth, tlsabaseExpAuth bool, expDoms ...dns.Domain) {
301 t.Helper()
302 doms := GatherTLSANames(haveMX, nexthopExpAuth, tlsabaseExpAuth, a, b, c, d)
303 if !reflect.DeepEqual(doms, expDoms) {
304 t.Fatalf("got domains %v, expected %v", doms, expDoms)
305 }
306 }
307
308 test(false, false, false, c)
309 test(false, false, true, d, c)
310 test(true, true, true, d, c, a, b)
311 test(true, true, false, c, a, b)
312 test(true, false, false, a)
313}
314