1package main
2
3/*
4note: these testdata paths are not in the repo, you should gather some of your
5own ham/spam emails.
6
7./mox junk train testdata/train/ham testdata/train/spam
8./mox junk train -sent-dir testdata/sent testdata/train/ham testdata/train/spam
9./mox junk check 'testdata/check/ham/mail1'
10./mox junk test testdata/check/ham testdata/check/spam
11./mox junk analyze testdata/train/ham testdata/train/spam
12./mox junk analyze -top-words 10 -train-ratio 0.5 -spam-threshold 0.85 -max-power 0.01 -sent-dir testdata/sent testdata/train/ham testdata/train/spam
13./mox junk play -top-words 10 -train-ratio 0.5 -spam-threshold 0.85 -max-power 0.01 -sent-dir testdata/sent testdata/train/ham testdata/train/spam
14*/
15
16import (
17 "context"
18 "flag"
19 "fmt"
20 "log"
21 mathrand "math/rand"
22 "os"
23 "path/filepath"
24 "sort"
25 "time"
26
27 "github.com/mjl-/mox/junk"
28 "github.com/mjl-/mox/message"
29 "github.com/mjl-/mox/mlog"
30 "github.com/mjl-/mox/mox-"
31)
32
33type junkArgs struct {
34 params junk.Params
35 spamThreshold float64
36 trainRatio float64
37 seed bool
38 sentDir string
39 databasePath, bloomfilterPath string
40 debug bool
41}
42
43func (a junkArgs) SetLogLevel() {
44 mox.Conf.Log[""] = mlog.LevelInfo
45 if a.debug {
46 mox.Conf.Log[""] = mlog.LevelDebug
47 }
48 mlog.SetConfig(mox.Conf.Log)
49}
50
51func junkFlags(fs *flag.FlagSet) (a junkArgs) {
52 fs.BoolVar(&a.params.Onegrams, "one-grams", false, "use 1-grams, i.e. single words, for scoring")
53 fs.BoolVar(&a.params.Twograms, "two-grams", true, "use 2-grams, i.e. word pairs, for scoring")
54 fs.BoolVar(&a.params.Threegrams, "three-grams", false, "use 3-grams, i.e. word triplets, for scoring")
55 fs.Float64Var(&a.params.MaxPower, "max-power", 0.05, "maximum word power, e.g. min 0.05/max 0.95")
56 fs.Float64Var(&a.params.IgnoreWords, "ignore-words", 0.1, "ignore words with ham/spaminess within this distance from 0.5")
57 fs.IntVar(&a.params.TopWords, "top-words", 10, "number of top spam and number of top ham words from email to use")
58 fs.IntVar(&a.params.RareWords, "rare-words", 1, "words are rare if encountered this number during training, and skipped for scoring")
59 fs.BoolVar(&a.debug, "debug", false, "print debug logging when calculating spam probability")
60
61 fs.Float64Var(&a.spamThreshold, "spam-threshold", 0.95, "probability where message is seen as spam")
62 fs.Float64Var(&a.trainRatio, "train-ratio", 0.5, "part of data to use for training versus analyzing (for analyze only)")
63 fs.StringVar(&a.sentDir, "sent-dir", "", "directory with sent mails, for training")
64 fs.BoolVar(&a.seed, "seed", false, "seed prng before analysis")
65 fs.StringVar(&a.databasePath, "dbpath", "filter.db", "database file for ham/spam words")
66 fs.StringVar(&a.bloomfilterPath, "bloompath", "filter.bloom", "bloom filter for ignoring unique strings")
67
68 return
69}
70
71func listDir(dir string) (l []string) {
72 files, err := os.ReadDir(dir)
73 xcheckf(err, "listing directory %q", dir)
74 for _, f := range files {
75 l = append(l, f.Name())
76 }
77 return l
78}
79
80func must(f *junk.Filter, err error) *junk.Filter {
81 xcheckf(err, "filter")
82 return f
83}
84
85func cmdJunkTrain(c *cmd) {
86 c.unlisted = true
87 c.params = "hamdir spamdir"
88 c.help = "Train a junk filter with messages from hamdir and spamdir."
89 a := junkFlags(c.flag)
90 args := c.Parse()
91 if len(args) != 2 {
92 c.Usage()
93 }
94 a.SetLogLevel()
95
96 f := must(junk.NewFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath))
97 defer func() {
98 if err := f.Close(); err != nil {
99 log.Printf("closing junk filter: %v", err)
100 }
101 }()
102
103 hamFiles := listDir(args[0])
104 spamFiles := listDir(args[1])
105 var sentFiles []string
106 if a.sentDir != "" {
107 sentFiles = listDir(a.sentDir)
108 }
109
110 err := f.TrainDirs(args[0], a.sentDir, args[1], hamFiles, sentFiles, spamFiles)
111 xcheckf(err, "train")
112}
113
114func cmdJunkCheck(c *cmd) {
115 c.unlisted = true
116 c.params = "mailfile"
117 c.help = "Check an email message against a junk filter, printing the probability of spam on a scale from 0 to 1."
118 a := junkFlags(c.flag)
119 args := c.Parse()
120 if len(args) != 1 {
121 c.Usage()
122 }
123 a.SetLogLevel()
124
125 f := must(junk.OpenFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath, false))
126 defer func() {
127 if err := f.Close(); err != nil {
128 log.Printf("closing junk filter: %v", err)
129 }
130 }()
131
132 result, err := f.ClassifyMessagePath(context.Background(), args[0])
133 xcheckf(err, "testing mail")
134
135 sig := "significant"
136 if !result.Significant {
137 sig = "not significant"
138 }
139 fmt.Printf("%.6f, %s\n", result.Probability, sig)
140}
141
142func cmdJunkTest(c *cmd) {
143 c.unlisted = true
144 c.params = "hamdir spamdir"
145 c.help = "Check a directory with hams and one with spams against the junk filter, and report the success ratio."
146 a := junkFlags(c.flag)
147 args := c.Parse()
148 if len(args) != 2 {
149 c.Usage()
150 }
151 a.SetLogLevel()
152
153 f := must(junk.OpenFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath, false))
154 defer func() {
155 if err := f.Close(); err != nil {
156 log.Printf("closing junk filter: %v", err)
157 }
158 }()
159
160 testDir := func(dir string, ham bool) (int, int) {
161 ok, bad := 0, 0
162 files, err := os.ReadDir(dir)
163 xcheckf(err, "readdir %q", dir)
164 for _, fi := range files {
165 path := filepath.Join(dir, fi.Name())
166 result, err := f.ClassifyMessagePath(context.Background(), path)
167 if err != nil {
168 log.Printf("classify message %q: %s", path, err)
169 continue
170 }
171 if ham && result.Probability < a.spamThreshold || !ham && result.Probability > a.spamThreshold {
172 ok++
173 } else {
174 bad++
175 }
176 if ham && result.Probability > a.spamThreshold {
177 fmt.Printf("ham %q: %.4f\n", path, result.Probability)
178 }
179 if !ham && result.Probability < a.spamThreshold {
180 fmt.Printf("spam %q: %.4f\n", path, result.Probability)
181 }
182 }
183 return ok, bad
184 }
185
186 nhamok, nhambad := testDir(args[0], true)
187 nspamok, nspambad := testDir(args[1], false)
188 fmt.Printf("total ham, ok %d, bad %d\n", nhamok, nhambad)
189 fmt.Printf("total spam, ok %d, bad %d\n", nspamok, nspambad)
190 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
191 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
192 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
193}
194
195func cmdJunkAnalyze(c *cmd) {
196 c.unlisted = true
197 c.params = "hamdir spamdir"
198 c.help = `Analyze a directory with ham messages and one with spam messages.
199
200A part of the messages is used for training, and remaining for testing. The
201messages are shuffled, with optional random seed.`
202 a := junkFlags(c.flag)
203 args := c.Parse()
204 if len(args) != 2 {
205 c.Usage()
206 }
207 a.SetLogLevel()
208
209 f := must(junk.NewFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath))
210 defer func() {
211 if err := f.Close(); err != nil {
212 log.Printf("closing junk filter: %v", err)
213 }
214 }()
215
216 hamDir := args[0]
217 spamDir := args[1]
218 hamFiles := listDir(hamDir)
219 spamFiles := listDir(spamDir)
220
221 var seed int64
222 if a.seed {
223 seed = time.Now().UnixMilli()
224 }
225 // Still at math/rand (v1 instead of v2) for potential comparison to earlier test results.
226 rand := mathrand.New(mathrand.NewSource(seed))
227
228 shuffle := func(l []string) {
229 count := len(l)
230 for i := range l {
231 n := rand.Intn(count)
232 l[i], l[n] = l[n], l[i]
233 }
234 }
235
236 shuffle(hamFiles)
237 shuffle(spamFiles)
238
239 ntrainham := int(a.trainRatio * float64(len(hamFiles)))
240 ntrainspam := int(a.trainRatio * float64(len(spamFiles)))
241
242 trainHam := hamFiles[:ntrainham]
243 trainSpam := spamFiles[:ntrainspam]
244 testHam := hamFiles[ntrainham:]
245 testSpam := spamFiles[ntrainspam:]
246
247 var trainSent []string
248 if a.sentDir != "" {
249 trainSent = listDir(a.sentDir)
250 }
251
252 err := f.TrainDirs(hamDir, a.sentDir, spamDir, trainHam, trainSent, trainSpam)
253 xcheckf(err, "train")
254
255 testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) {
256 for _, name := range files {
257 path := filepath.Join(dir, name)
258 result, err := f.ClassifyMessagePath(context.Background(), path)
259 if err != nil {
260 // log.Infof("%s: %s", path, err)
261 malformed++
262 continue
263 }
264 if ham && result.Probability < a.spamThreshold || !ham && result.Probability > a.spamThreshold {
265 ok++
266 } else {
267 bad++
268 }
269 if ham && result.Probability > a.spamThreshold {
270 fmt.Printf("ham %q: %.4f\n", path, result.Probability)
271 }
272 if !ham && result.Probability < a.spamThreshold {
273 fmt.Printf("spam %q: %.4f\n", path, result.Probability)
274 }
275 }
276 return
277 }
278
279 nhamok, nhambad, nmalformedham := testDir(args[0], testHam, true)
280 nspamok, nspambad, nmalformedspam := testDir(args[1], testSpam, false)
281 fmt.Printf("training done, nham %d, nsent %d, nspam %d\n", ntrainham, len(trainSent), ntrainspam)
282 fmt.Printf("total ham, ok %d, bad %d, malformed %d\n", nhamok, nhambad, nmalformedham)
283 fmt.Printf("total spam, ok %d, bad %d, malformed %d\n", nspamok, nspambad, nmalformedspam)
284 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
285 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
286 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
287}
288
289func cmdJunkPlay(c *cmd) {
290 c.unlisted = true
291 c.params = "hamdir spamdir"
292 c.help = "Play messages from ham and spam directory according to their time of arrival and report on junk filter performance."
293 a := junkFlags(c.flag)
294 args := c.Parse()
295 if len(args) != 2 {
296 c.Usage()
297 }
298 a.SetLogLevel()
299
300 f := must(junk.NewFilter(context.Background(), c.log, a.params, a.databasePath, a.bloomfilterPath))
301 defer func() {
302 if err := f.Close(); err != nil {
303 log.Printf("closing junk filter: %v", err)
304 }
305 }()
306
307 // We'll go through all emails to find their dates.
308 type msg struct {
309 dir, filename string
310 ham, sent bool
311 t time.Time
312 }
313 var msgs []msg
314
315 var nbad, nnodate, nham, nspam, nsent int
316
317 scanDir := func(dir string, ham, sent bool) {
318 for _, name := range listDir(dir) {
319 path := filepath.Join(dir, name)
320 mf, err := os.Open(path)
321 xcheckf(err, "open %q", path)
322 fi, err := mf.Stat()
323 xcheckf(err, "stat %q", path)
324 p, err := message.EnsurePart(c.log.Logger, false, mf, fi.Size())
325 if err != nil {
326 nbad++
327 if err := mf.Close(); err != nil {
328 log.Printf("closing message file: %v", err)
329 }
330 continue
331 }
332 if p.Envelope.Date.IsZero() {
333 nnodate++
334 if err := mf.Close(); err != nil {
335 log.Printf("closing message file: %v", err)
336 }
337 continue
338 }
339 if err := mf.Close(); err != nil {
340 log.Printf("closing message file: %v", err)
341 }
342 msgs = append(msgs, msg{dir, name, ham, sent, p.Envelope.Date})
343 if sent {
344 nsent++
345 } else if ham {
346 nham++
347 } else {
348 nspam++
349 }
350 }
351 }
352
353 hamDir := args[0]
354 spamDir := args[1]
355 scanDir(hamDir, true, false)
356 scanDir(spamDir, false, false)
357 if a.sentDir != "" {
358 scanDir(a.sentDir, true, true)
359 }
360
361 // Sort the messages, earliest first.
362 sort.Slice(msgs, func(i, j int) bool {
363 return msgs[i].t.Before(msgs[j].t)
364 })
365
366 // Play all messages as if they are coming in. We predict their spaminess, check if
367 // we are right. And we train the system with the result.
368 var nhamok, nhambad, nspamok, nspambad int
369
370 play := func(msg msg) {
371 var words map[string]struct{}
372 path := filepath.Join(msg.dir, msg.filename)
373 if !msg.sent {
374 result, err := f.ClassifyMessagePath(context.Background(), path)
375 if err != nil {
376 nbad++
377 return
378 }
379 if msg.ham {
380 if result.Probability < a.spamThreshold {
381 nhamok++
382 } else {
383 nhambad++
384 }
385 } else {
386 if result.Probability > a.spamThreshold {
387 nspamok++
388 } else {
389 nspambad++
390 }
391 }
392 } else {
393 mf, err := os.Open(path)
394 xcheckf(err, "open %q", path)
395 defer func() {
396 if err := mf.Close(); err != nil {
397 log.Printf("closing message file: %v", err)
398 }
399 }()
400 fi, err := mf.Stat()
401 xcheckf(err, "stat %q", path)
402 p, err := message.EnsurePart(c.log.Logger, false, mf, fi.Size())
403 if err != nil {
404 log.Printf("bad sent message %q: %s", path, err)
405 return
406 }
407
408 words, err = f.ParseMessage(p)
409 if err != nil {
410 log.Printf("bad sent message %q: %s", path, err)
411 return
412 }
413 }
414
415 if err := f.Train(context.Background(), msg.ham, words); err != nil {
416 log.Printf("train: %s", err)
417 }
418 }
419
420 for _, m := range msgs {
421 play(m)
422 }
423
424 err := f.Save()
425 xcheckf(err, "saving filter")
426
427 fmt.Printf("completed, nham %d, nsent %d, nspam %d, nbad %d, nwithoutdate %d\n", nham, nsent, nspam, nbad, nnodate)
428 fmt.Printf("total ham, ok %d, bad %d\n", nhamok, nhambad)
429 fmt.Printf("total spam, ok %d, bad %d\n", nspamok, nspambad)
430 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
431 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
432 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
433}
434