1package main
2
3/*
4note: these testdata paths are not in the repo, you should gather some of your
5own ham/spam emails.
6
7./mox junk train testdata/train/ham testdata/train/spam
8./mox junk train -sent-dir testdata/sent testdata/train/ham testdata/train/spam
9./mox junk check 'testdata/check/ham/mail1'
10./mox junk test testdata/check/ham testdata/check/spam
11./mox junk analyze testdata/train/ham testdata/train/spam
12./mox junk analyze -top-words 10 -train-ratio 0.5 -spam-threshold 0.85 -max-power 0.01 -sent-dir testdata/sent testdata/train/ham testdata/train/spam
13./mox junk play -top-words 10 -train-ratio 0.5 -spam-threshold 0.85 -max-power 0.01 -sent-dir testdata/sent testdata/train/ham testdata/train/spam
14*/
15
16import (
17 "context"
18 "flag"
19 "fmt"
20 "log"
21 mathrand "math/rand"
22 "os"
23 "path/filepath"
24 "sort"
25 "time"
26
27 "github.com/mjl-/mox/junk"
28 "github.com/mjl-/mox/message"
29 "github.com/mjl-/mox/mlog"
30 "github.com/mjl-/mox/mox-"
31)
32
33type junkArgs struct {
34 params junk.Params
35 spamThreshold float64
36 trainRatio float64
37 seed bool
38 sentDir string
39 databasePath, bloomfilterPath string
40 debug bool
41}
42
43func (a junkArgs) SetLogLevel() {
44 mox.Conf.Log[""] = mlog.LevelInfo
45 if a.debug {
46 mox.Conf.Log[""] = mlog.LevelDebug
47 }
48 mlog.SetConfig(mox.Conf.Log)
49}
50
51func junkFlags(fs *flag.FlagSet) (a junkArgs) {
52 fs.BoolVar(&a.params.Onegrams, "one-grams", false, "use 1-grams, i.e. single words, for scoring")
53 fs.BoolVar(&a.params.Twograms, "two-grams", true, "use 2-grams, i.e. word pairs, for scoring")
54 fs.BoolVar(&a.params.Threegrams, "three-grams", false, "use 3-grams, i.e. word triplets, for scoring")
55 fs.Float64Var(&a.params.MaxPower, "max-power", 0.05, "maximum word power, e.g. min 0.05/max 0.95")
56 fs.Float64Var(&a.params.IgnoreWords, "ignore-words", 0.1, "ignore words with ham/spaminess within this distance from 0.5")
57 fs.IntVar(&a.params.TopWords, "top-words", 10, "number of top spam and number of top ham words from email to use")
58 fs.IntVar(&a.params.RareWords, "rare-words", 1, "words are rare if encountered this number during training, and skipped for scoring")
59 fs.BoolVar(&a.debug, "debug", false, "print debug logging when calculating spam probability")
60
61 fs.Float64Var(&a.spamThreshold, "spam-threshold", 0.95, "probability where message is seen as spam")
62 fs.Float64Var(&a.trainRatio, "train-ratio", 0.5, "part of data to use for training versus analyzing (for analyze only)")
63 fs.StringVar(&a.sentDir, "sent-dir", "", "directory with sent mails, for training")
64 fs.BoolVar(&a.seed, "seed", false, "seed prng before analysis")
65 fs.StringVar(&a.databasePath, "dbpath", "filter.db", "database file for ham/spam words")
66 fs.StringVar(&a.bloomfilterPath, "bloompath", "filter.bloom", "bloom filter for ignoring unique strings")
67
68 return
69}
70
71func listDir(dir string) (l []string) {
72 files, err := os.ReadDir(dir)
73 xcheckf(err, "listing directory %q", dir)
74 for _, f := range files {
75 l = append(l, f.Name())
76 }
77 return l
78}
79
// must passes err to xcheckf with the context "filter" and then returns f. It
// allows wrapping a filter constructor call directly, e.g.
// must(junk.NewFilter(...)).
// NOTE(review): relies on xcheckf not returning for a non-nil error
// (presumably fatal, as its other uses in this file suggest).
func must(f *junk.Filter, err error) *junk.Filter {
	xcheckf(err, "filter")
	return f
}
84
85func cmdJunkTrain(c *cmd) {
86 c.unlisted = true
87 c.params = "hamdir spamdir"
88 c.help = "Train a junk filter with messages from hamdir and spamdir."
89 a := junkFlags(c.flag)
90 args := c.Parse()
91 if len(args) != 2 {
92 c.Usage()
93 }
94 a.SetLogLevel()
95
96 f := must(junk.NewFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath))
97 defer func() {
98 if err := f.Close(); err != nil {
99 log.Printf("closing junk filter: %v", err)
100 }
101 }()
102
103 hamFiles := listDir(args[0])
104 spamFiles := listDir(args[1])
105 var sentFiles []string
106 if a.sentDir != "" {
107 sentFiles = listDir(a.sentDir)
108 }
109
110 err := f.TrainDirs(args[0], a.sentDir, args[1], hamFiles, sentFiles, spamFiles)
111 xcheckf(err, "train")
112}
113
114func cmdJunkCheck(c *cmd) {
115 c.unlisted = true
116 c.params = "mailfile"
117 c.help = "Check an email message against a junk filter, printing the probability of spam on a scale from 0 to 1."
118 a := junkFlags(c.flag)
119 args := c.Parse()
120 if len(args) != 1 {
121 c.Usage()
122 }
123 a.SetLogLevel()
124
125 f := must(junk.OpenFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath, false))
126 defer func() {
127 if err := f.Close(); err != nil {
128 log.Printf("closing junk filter: %v", err)
129 }
130 }()
131
132 prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), args[0])
133 xcheckf(err, "testing mail")
134
135 fmt.Printf("%.6f\n", prob)
136}
137
138func cmdJunkTest(c *cmd) {
139 c.unlisted = true
140 c.params = "hamdir spamdir"
141 c.help = "Check a directory with hams and one with spams against the junk filter, and report the success ratio."
142 a := junkFlags(c.flag)
143 args := c.Parse()
144 if len(args) != 2 {
145 c.Usage()
146 }
147 a.SetLogLevel()
148
149 f := must(junk.OpenFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath, false))
150 defer func() {
151 if err := f.Close(); err != nil {
152 log.Printf("closing junk filter: %v", err)
153 }
154 }()
155
156 testDir := func(dir string, ham bool) (int, int) {
157 ok, bad := 0, 0
158 files, err := os.ReadDir(dir)
159 xcheckf(err, "readdir %q", dir)
160 for _, fi := range files {
161 path := filepath.Join(dir, fi.Name())
162 prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path)
163 if err != nil {
164 log.Printf("classify message %q: %s", path, err)
165 continue
166 }
167 if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold {
168 ok++
169 } else {
170 bad++
171 }
172 if ham && prob > a.spamThreshold {
173 fmt.Printf("ham %q: %.4f\n", path, prob)
174 }
175 if !ham && prob < a.spamThreshold {
176 fmt.Printf("spam %q: %.4f\n", path, prob)
177 }
178 }
179 return ok, bad
180 }
181
182 nhamok, nhambad := testDir(args[0], true)
183 nspamok, nspambad := testDir(args[1], false)
184 fmt.Printf("total ham, ok %d, bad %d\n", nhamok, nhambad)
185 fmt.Printf("total spam, ok %d, bad %d\n", nspamok, nspambad)
186 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
187 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
188 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
189}
190
191func cmdJunkAnalyze(c *cmd) {
192 c.unlisted = true
193 c.params = "hamdir spamdir"
194 c.help = `Analyze a directory with ham messages and one with spam messages.
195
196A part of the messages is used for training, and remaining for testing. The
197messages are shuffled, with optional random seed.`
198 a := junkFlags(c.flag)
199 args := c.Parse()
200 if len(args) != 2 {
201 c.Usage()
202 }
203 a.SetLogLevel()
204
205 f := must(junk.NewFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath))
206 defer func() {
207 if err := f.Close(); err != nil {
208 log.Printf("closing junk filter: %v", err)
209 }
210 }()
211
212 hamDir := args[0]
213 spamDir := args[1]
214 hamFiles := listDir(hamDir)
215 spamFiles := listDir(spamDir)
216
217 var seed int64
218 if a.seed {
219 seed = time.Now().UnixMilli()
220 }
221 // Still at math/rand (v1 instead of v2) for potential comparison to earlier test results.
222 rand := mathrand.New(mathrand.NewSource(seed))
223
224 shuffle := func(l []string) {
225 count := len(l)
226 for i := range l {
227 n := rand.Intn(count)
228 l[i], l[n] = l[n], l[i]
229 }
230 }
231
232 shuffle(hamFiles)
233 shuffle(spamFiles)
234
235 ntrainham := int(a.trainRatio * float64(len(hamFiles)))
236 ntrainspam := int(a.trainRatio * float64(len(spamFiles)))
237
238 trainHam := hamFiles[:ntrainham]
239 trainSpam := spamFiles[:ntrainspam]
240 testHam := hamFiles[ntrainham:]
241 testSpam := spamFiles[ntrainspam:]
242
243 var trainSent []string
244 if a.sentDir != "" {
245 trainSent = listDir(a.sentDir)
246 }
247
248 err := f.TrainDirs(hamDir, a.sentDir, spamDir, trainHam, trainSent, trainSpam)
249 xcheckf(err, "train")
250
251 testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) {
252 for _, name := range files {
253 path := filepath.Join(dir, name)
254 prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path)
255 if err != nil {
256 // log.Infof("%s: %s", path, err)
257 malformed++
258 continue
259 }
260 if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold {
261 ok++
262 } else {
263 bad++
264 }
265 if ham && prob > a.spamThreshold {
266 fmt.Printf("ham %q: %.4f\n", path, prob)
267 }
268 if !ham && prob < a.spamThreshold {
269 fmt.Printf("spam %q: %.4f\n", path, prob)
270 }
271 }
272 return
273 }
274
275 nhamok, nhambad, nmalformedham := testDir(args[0], testHam, true)
276 nspamok, nspambad, nmalformedspam := testDir(args[1], testSpam, false)
277 fmt.Printf("training done, nham %d, nsent %d, nspam %d\n", ntrainham, len(trainSent), ntrainspam)
278 fmt.Printf("total ham, ok %d, bad %d, malformed %d\n", nhamok, nhambad, nmalformedham)
279 fmt.Printf("total spam, ok %d, bad %d, malformed %d\n", nspamok, nspambad, nmalformedspam)
280 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
281 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
282 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
283}
284
285func cmdJunkPlay(c *cmd) {
286 c.unlisted = true
287 c.params = "hamdir spamdir"
288 c.help = "Play messages from ham and spam directory according to their time of arrival and report on junk filter performance."
289 a := junkFlags(c.flag)
290 args := c.Parse()
291 if len(args) != 2 {
292 c.Usage()
293 }
294 a.SetLogLevel()
295
296 f := must(junk.NewFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath))
297 defer func() {
298 if err := f.Close(); err != nil {
299 log.Printf("closing junk filter: %v", err)
300 }
301 }()
302
303 // We'll go through all emails to find their dates.
304 type msg struct {
305 dir, filename string
306 ham, sent bool
307 t time.Time
308 }
309 var msgs []msg
310
311 var nbad, nnodate, nham, nspam, nsent int
312
313 scanDir := func(dir string, ham, sent bool) {
314 for _, name := range listDir(dir) {
315 path := filepath.Join(dir, name)
316 mf, err := os.Open(path)
317 xcheckf(err, "open %q", path)
318 fi, err := mf.Stat()
319 xcheckf(err, "stat %q", path)
320 p, err := message.EnsurePart(c.log.Logger, false, mf, fi.Size())
321 if err != nil {
322 nbad++
323 if err := mf.Close(); err != nil {
324 log.Printf("closing message file: %v", err)
325 }
326 continue
327 }
328 if p.Envelope.Date.IsZero() {
329 nnodate++
330 if err := mf.Close(); err != nil {
331 log.Printf("closing message file: %v", err)
332 }
333 continue
334 }
335 if err := mf.Close(); err != nil {
336 log.Printf("closing message file: %v", err)
337 }
338 msgs = append(msgs, msg{dir, name, ham, sent, p.Envelope.Date})
339 if sent {
340 nsent++
341 } else if ham {
342 nham++
343 } else {
344 nspam++
345 }
346 }
347 }
348
349 hamDir := args[0]
350 spamDir := args[1]
351 scanDir(hamDir, true, false)
352 scanDir(spamDir, false, false)
353 if a.sentDir != "" {
354 scanDir(a.sentDir, true, true)
355 }
356
357 // Sort the messages, earliest first.
358 sort.Slice(msgs, func(i, j int) bool {
359 return msgs[i].t.Before(msgs[j].t)
360 })
361
362 // Play all messages as if they are coming in. We predict their spaminess, check if
363 // we are right. And we train the system with the result.
364 var nhamok, nhambad, nspamok, nspambad int
365
366 play := func(msg msg) {
367 var words map[string]struct{}
368 path := filepath.Join(msg.dir, msg.filename)
369 if !msg.sent {
370 var prob float64
371 var err error
372 prob, words, _, _, err = f.ClassifyMessagePath(context.Background(), path)
373 if err != nil {
374 nbad++
375 return
376 }
377 if msg.ham {
378 if prob < a.spamThreshold {
379 nhamok++
380 } else {
381 nhambad++
382 }
383 } else {
384 if prob > a.spamThreshold {
385 nspamok++
386 } else {
387 nspambad++
388 }
389 }
390 } else {
391 mf, err := os.Open(path)
392 xcheckf(err, "open %q", path)
393 defer func() {
394 if err := mf.Close(); err != nil {
395 log.Printf("closing message file: %v", err)
396 }
397 }()
398 fi, err := mf.Stat()
399 xcheckf(err, "stat %q", path)
400 p, err := message.EnsurePart(c.log.Logger, false, mf, fi.Size())
401 if err != nil {
402 log.Printf("bad sent message %q: %s", path, err)
403 return
404 }
405
406 words, err = f.ParseMessage(p)
407 if err != nil {
408 log.Printf("bad sent message %q: %s", path, err)
409 return
410 }
411 }
412
413 if err := f.Train(context.Background(), msg.ham, words); err != nil {
414 log.Printf("train: %s", err)
415 }
416 }
417
418 for _, m := range msgs {
419 play(m)
420 }
421
422 err := f.Save()
423 xcheckf(err, "saving filter")
424
425 fmt.Printf("completed, nham %d, nsent %d, nspam %d, nbad %d, nwithoutdate %d\n", nham, nsent, nspam, nbad, nnodate)
426 fmt.Printf("total ham, ok %d, bad %d\n", nhamok, nhambad)
427 fmt.Printf("total spam, ok %d, bad %d\n", nspamok, nspambad)
428 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
429 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
430 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
431}
432