1package smtp
2
3import (
4 "bufio"
5 "errors"
6 "io"
7 "strings"
8 "testing"
9)
10
11func TestDataWrite(t *testing.T) {
12 checkBad := func(s string, expErr error) {
13 t.Helper()
14 if err := DataWrite(io.Discard, strings.NewReader(s)); err == nil || !errors.Is(err, expErr) {
15 t.Fatalf("got err %v, expected %v", err, expErr)
16 }
17 }
18
19 checkBad("bad", errMissingCRLF)
20 checkBad(".", errMissingCRLF)
21 checkBad("bare \r is bad\r\n", ErrCRLF)
22 checkBad("bare \n is bad\r\n", ErrCRLF)
23 checkBad("\n.\nis bad\r\n", ErrCRLF)
24 checkBad("\r.\ris bad\r\n", ErrCRLF)
25 checkBad("\r\n.\ris bad\r\n", ErrCRLF)
26 checkBad("\r\n.\nis bad\r\n", ErrCRLF)
27 checkBad("\n.\ris bad\r\n", ErrCRLF)
28 checkBad("\n.\r\nis bad\r\n", ErrCRLF)
29
30 check := func(msg, want string) {
31 t.Helper()
32 w := &strings.Builder{}
33 if err := DataWrite(w, strings.NewReader(msg)); err != nil {
34 t.Fatalf("writing smtp data: %s", err)
35 }
36 got := w.String()
37 if got != want {
38 t.Fatalf("got %q, expected %q, for msg %q", got, want, msg)
39 }
40 }
41
42 check("", ".\r\n")
43 check(".\r\n", "..\r\n.\r\n")
44 check("header: abc\r\n\r\nmessage\r\n", "header: abc\r\n\r\nmessage\r\n.\r\n")
45}
46
47func TestDataReader(t *testing.T) {
48 // Copy with a 1 byte buffer for reading.
49 smallCopy := func(d io.Writer, r io.Reader) (int, error) {
50 var wrote int
51 buf := make([]byte, 1)
52 for {
53 n, err := r.Read(buf)
54 if n > 0 {
55 nn, err := d.Write(buf)
56 if nn > 0 {
57 wrote += nn
58 }
59 if err != nil {
60 return wrote, err
61 }
62 }
63 if err == io.EOF {
64 break
65 } else if err != nil {
66 return wrote, err
67 }
68 }
69 return wrote, nil
70 }
71
72 check := func(data, want string, expErr error) {
73 t.Helper()
74
75 s := &strings.Builder{}
76 dr := NewDataReader(bufio.NewReader(strings.NewReader(data)))
77 if _, err := io.Copy(s, dr); err != nil {
78 if expErr == nil || !errors.Is(err, expErr) {
79 t.Fatalf("got err %v, expected %v", err, expErr)
80 }
81 } else if got := s.String(); got != want {
82 t.Fatalf("got %q, expected %q, for %q", got, want, data)
83 }
84
85 s = &strings.Builder{}
86 dr = NewDataReader(bufio.NewReader(strings.NewReader(data)))
87 if _, err := smallCopy(s, dr); err != nil {
88 if expErr == nil || !errors.Is(err, expErr) {
89 t.Fatalf("got err %v, expected %v", err, expErr)
90 }
91 } else if got := s.String(); got != want {
92 t.Fatalf("got %q, expected %q, for %q", got, want, data)
93 }
94 }
95
96 check("test\r\n.\r\n", "test\r\n", nil)
97 check(".\r\n", "", nil)
98 check(".test\r\n.\r\n", "test\r\n", nil) // Unnecessary dot, but valid in SMTP.
99 check("..test\r\n.\r\n", ".test\r\n", nil)
100
101 check("..test\ntest.\n\r\n.\r\n", ".test\ntest.\n\r\n", nil) // Bare newlines are allowed.
102 check("..test\ntest\n", "", io.ErrUnexpectedEOF) // Missing end-of-message.
103
104 // Bare \r is rejected.
105 check("bare \r is rejected\r\n.\r\n", "", ErrCRLF)
106 check("bad:\r.\ris rejected\r\n.\r\n", "", ErrCRLF)
107 check("bad:\r.\nis rejected\r\n.\r\n", "", ErrCRLF)
108
109 // Suspicious bare newlines around a dot are rejected.
110 check("bad:\n.\nis rejected\r\n.\r\n", "", ErrCRLF)
111 check("bad:\n.\r\nis rejected\r\n.\r\n", "", ErrCRLF)
112 check("bad:\r\n.\nis rejected\r\n.\r\n", "", ErrCRLF)
113
114 // Suspicious near-smtp-endings at start of message.
115 check(".\ris rejected\r\n.\r\n", "", ErrCRLF)
116 check(".\nis rejected\r\n.\r\n", "", ErrCRLF)
117 check("\n.\ris rejected\r\n.\r\n", "", ErrCRLF)
118 check("\r.\ris rejected\r\n.\r\n", "", ErrCRLF)
119 check("\n.\nis rejected\r\n.\r\n", "", ErrCRLF)
120 check("\r.\nis rejected\r\n.\r\n", "", ErrCRLF)
121 check("\r.\r\nis rejected\r\n.\r\n", "", ErrCRLF)
122 check("\n.\r\nis rejected\r\n.\r\n", "", ErrCRLF)
123 check("\r\n.\ris rejected\r\n.\r\n", "", ErrCRLF)
124 check("\r\n.\nis rejected\r\n.\r\n", "", ErrCRLF)
125
126 s := &strings.Builder{}
127 dr := NewDataReader(bufio.NewReader(strings.NewReader("no end")))
128 if _, err := io.Copy(s, dr); err != io.ErrUnexpectedEOF {
129 t.Fatalf("got err %v, expected io.ErrUnexpectedEOF", err)
130 }
131}
132
133func TestDataWriteLineBoundaries(t *testing.T) {
134 const valid = "Subject: test\r\n\r\nbody\r\n"
135 if err := DataWrite(io.Discard, &oneReader{[]byte(valid)}); err != nil {
136 t.Fatalf("data write: %v", err)
137 }
138}
139
// oneReader is an io.Reader that hands out its data a single byte per Read
// call.
type oneReader struct {
	buf []byte // Bytes not yet returned by Read.
}

// Read stores the next pending byte in p[0] and returns 1. Once all data has
// been consumed it returns 0 and io.EOF, regardless of the size of p. While
// data remains, a zero-length p yields 0 and a nil error without consuming
// anything.
func (r *oneReader) Read(p []byte) (int, error) {
	switch {
	case len(r.buf) == 0:
		return 0, io.EOF
	case len(p) == 0:
		return 0, nil
	}
	p[0], r.buf = r.buf[0], r.buf[1:]
	return 1, nil
}
156