1package store
2
3import (
4 "bytes"
5 "io"
6 "strings"
7 "unicode"
8 "unicode/utf8"
9
10 "github.com/mjl-/mox/message"
11 "github.com/mjl-/mox/mlog"
12)
13
// WordSearch holds context for a search, with scratch buffers to prevent
// allocations for each message.
//
// A WordSearch must not be used concurrently: copies of the value share the
// same scratch buffers.
type WordSearch struct {
	words    [][]byte // Lower-cased words that must all occur for a match.
	notWords [][]byte // Lower-cased words that must not occur for a match.

	// Scratch buffers, reused for every search. searchBuf holds the chunk of
	// content being searched, keepBuf the tail of the previous chunk that is
	// searched again together with the next chunk, for words spanning chunks.
	searchBuf []byte
	keepBuf   []byte
}

// PrepareWordSearch returns a search context that can be used to match multiple
// messages (after each other, not concurrently).
//
// Words and notWords are lower-cased here (with strings.ToLower), for matching
// against lower-cased message content.
func PrepareWordSearch(words, notWords []string) WordSearch {
	// Longest word, counted in runes after lower-casing.
	maxRunes := 0
	lowerAll := func(l []string) [][]byte {
		var r [][]byte
		for _, w := range l {
			lw := strings.ToLower(w)
			r = append(r, []byte(lw))
			if n := utf8.RuneCountInString(lw); n > maxRunes {
				maxRunes = n
			}
		}
		return r
	}
	wl := lowerAll(words)
	nwl := lowerAll(notWords)

	// keep is the number of bytes carried over from the end of one chunk to the
	// start of the next, so a word spanning two chunks is still found. Content is
	// lower-cased rune by rune, but the upper and lower case forms of a rune can
	// have different encoded lengths (e.g. the Kelvin sign U+212A is 3 bytes, its
	// lower case "k" only 1; Ⱥ U+023A is 2 bytes, its lower case ⱥ 3). So the
	// byte length of a word does not bound the length of content matching it,
	// but each content rune is at most utf8.UTFMax bytes. One extra rune of
	// slack is added.
	keep := utf8.UTFMax*maxRunes + utf8.UTFMax

	// Make the search buffer large compared to the carried-over part, so little
	// content is searched twice.
	bufSize := 8 * 1024
	for bufSize/keep < 8 {
		bufSize *= 2
	}

	return WordSearch{
		words:     wl,
		notWords:  nwl,
		searchBuf: make([]byte, bufSize),
		keepBuf:   make([]byte, keep),
	}
}
55
56// MatchPart returns whether the part/mail message p matches the search.
57// The search terms are matched against content-transfer-decoded and
58// charset-decoded bodies and optionally headers.
59// HTML parts are currently treated as regular text, without parsing HTML.
60func (ws WordSearch) MatchPart(log mlog.Log, p *message.Part, headerToo bool) (bool, error) {
61 seen := map[int]bool{}
62 miss, err := ws.matchPart(log, p, headerToo, seen)
63 match := err == nil && !miss && len(seen) == len(ws.words)
64 return match, err
65}
66
67// If all words are seen, and we there are no not-words that force us to search
68// till the end, we know we have a match.
69func (ws WordSearch) isQuickHit(seen map[int]bool) bool {
70 return len(seen) == len(ws.words) && len(ws.notWords) == 0
71}
72
73// search a part as text and/or its subparts, recursively. Once we know we have
74// a miss, we stop (either due to not-word match or error). In case of
75// non-miss, the caller checks if there was a hit.
76func (ws WordSearch) matchPart(log mlog.Log, p *message.Part, headerToo bool, seen map[int]bool) (miss bool, rerr error) {
77 if headerToo {
78 miss, err := ws.searchReader(log, p.HeaderReader(), seen)
79 if miss || err != nil || ws.isQuickHit(seen) {
80 return miss, err
81 }
82 }
83
84 if len(p.Parts) == 0 {
85 if p.MediaType != "TEXT" {
86 // todo: for other types we could try to find a library for parsing and search in there too.
87 return false, nil
88 }
89 tp := p.ReaderUTF8OrBinary()
90 // todo: for html and perhaps other types, we could try to parse as text and filter on the text.
91 miss, err := ws.searchReader(log, tp, seen)
92 if miss || err != nil || ws.isQuickHit(seen) {
93 return miss, err
94 }
95 }
96 for _, pp := range p.Parts {
97 if pp.Message != nil {
98 if err := pp.SetMessageReaderAt(); err != nil {
99 return false, err
100 }
101 pp = *pp.Message
102 }
103 miss, err := ws.matchPart(log, &pp, headerToo, seen)
104 if miss || err != nil || ws.isQuickHit(seen) {
105 return miss, err
106 }
107 }
108 return false, nil
109}
110
111func (ws WordSearch) searchReader(log mlog.Log, r io.Reader, seen map[int]bool) (miss bool, rerr error) {
112 // We will be reading through the content, stopping as soon as we known an answer:
113 // when all words have been seen and there are no "not words" (true), or one "not
114 // word" has been seen (false). We use bytes.Contains to look for the words. We
115 // advance our buffer in largish chunks, keeping the end of the buffer the size of
116 // the largest word plus the max of an utf-8 character to account for words
117 // spanning chunks.
118
119 have := 0
120 for {
121 n, err := io.ReadFull(r, ws.searchBuf[have:])
122 if n > 0 {
123 have += n
124 }
125 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
126 return true, err
127 }
128 if err == nil {
129 copy(ws.keepBuf, ws.searchBuf[have-len(ws.keepBuf):])
130 }
131
132 lower := toLower(ws.searchBuf[:have])
133
134 for i, w := range ws.words {
135 if !seen[i] && bytes.Contains(lower, w) {
136 seen[i] = true
137 if len(seen) == len(ws.words) && len(ws.notWords) == 0 {
138 return false, nil
139 }
140 }
141 }
142 for _, w := range ws.notWords {
143 if bytes.Contains(lower, w) {
144 return true, nil
145 }
146 }
147 if err != nil {
148 // Must be EOF or UnexpectedEOF now.
149 break
150 }
151 copy(ws.searchBuf, ws.keepBuf)
152 have = len(ws.keepBuf)
153 }
154 return false, nil
155}
156
// toLower lower-cases buf and returns the result. Lower-casing happens in place
// in buf. A new slice is only allocated when the lower-cased text would
// overwrite the part of buf that has not been processed yet (e.g. Ⱥ U+023A,
// 2 bytes, lower-cases to ⱥ U+2C65, 3 bytes). In either case buf is modified,
// and callers must use the returned slice.
//
// Each input rune produces exactly one output rune. Invalid UTF-8 bytes and
// U+FFFD (utf8.RuneError) are replaced by byte value 0: lower-casing
// RuneError would often increase size, and we assume no one wants to match it.
func toLower(buf []byte) []byte {
	r := buf[:0]
	copied := false
	for i := 0; i < len(buf); {
		if b := buf[i]; b < utf8.RuneSelf {
			// ASCII fast path.
			if b >= 'A' && b <= 'Z' {
				b += 'a' - 'A'
			}
			r = append(r, b)
			i++
			continue
		}
		c, size := utf8.DecodeRune(buf[i:])
		i += size
		if c == utf8.RuneError {
			r = append(r, 0)
			continue
		}
		nc := unicode.ToLower(c)
		nsize := utf8.RuneLen(nc)
		// Take care not to overwrite the part of the buffer we still have to
		// process: switch to a separate slice once the output would outgrow it.
		if !copied && len(r)+nsize > i {
			copied = true
			nr := make([]byte, len(r), len(r)+nsize+len(buf)-i)
			copy(nr, r)
			r = nr
		}
		r = utf8.AppendRune(r, nc)
	}
	return r
}
196