1// Package junk implements a bayesian spam filter.
2//
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
8package junk
9
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
// todo: perhaps: check whether the anchor text of links in html differs from the link url
12
13import (
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "log/slog"
19 "math"
20 "os"
21 "path/filepath"
22 "sort"
23 "time"
24
25 "github.com/mjl-/bstore"
26
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
29)
30
// errClosed is returned by Filter methods after Close or CloseDiscard has been
// called.
//
// todo: add errBadContentType = errors.New("bad content-type") (a sure sign of
// spam) and use it.
var errClosed = errors.New("filter is closed")
35
// word holds the ham and spam counts of a word (or word combination), as kept
// in the filter's in-memory cache. Counts are incremented by training and
// decremented by untraining.
type word struct {
	Ham  uint32
	Spam uint32
}
40
// wordscore is the database record for a word, looked up by Word. The special
// word "-" holds the total number of trained ham and spam messages.
//
// Field order matters: records are built with positional composite literals.
type wordscore struct {
	Word string
	Ham  uint32
	Spam uint32
}
46
// Params holds parameters for the filter. The first three fields (Onegrams,
// Twograms, Threegrams) are used during parsing and training. The others are
// used when classifying.
type Params struct {
	Onegrams    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
	Twograms    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
	Threegrams  bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
	MaxPower    float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
	TopWords    int     `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
	IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
	RareWords   int     `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
}
58
// DBTypes lists the types stored in the filter database. They are passed to
// bstore.Open when opening or creating the database.
var DBTypes = []any{wordscore{}} // Stored in DB.
60
// Filter is a bayesian spam filter backed by a bstore database file and a bloom
// filter file. Create one with NewFilter, or open an existing one with
// OpenFilter. Modifications from training are kept in memory until Save (or
// Close).
type Filter struct {
	Params

	log               mlog.Log        // For logging cid.
	closed            bool            // Set by Close and CloseDiscard; methods then return errClosed.
	modified          bool            // Whether any modifications are pending. Cleared by Save.
	hams, spams       uint32          // Total trained message counts, stored in db under word "-".
	cache             map[string]word // Absolute word counts, read from database or updated during training.
	changed           map[string]word // Words modified during training (absolute counts), written by Save.
	dbPath, bloomPath string          // File paths of the database and the bloom filter.
	db                *bstore.DB      // Always open on a filter.
	bloom             *Bloom          // Loaded at open when requested, otherwise on demand for training.
	isNew             bool            // Set for new filters until their first sync to disk. For faster writing.
}
75
76func (f *Filter) ensureBloom() error {
77 if f.bloom != nil {
78 return nil
79 }
80 var err error
81 f.bloom, err = openBloom(f.bloomPath)
82 return err
83}
84
85// CloseDiscard closes the filter, discarding any changes.
86func (f *Filter) CloseDiscard() error {
87 if f.closed {
88 return errClosed
89 }
90 err := f.db.Close()
91 *f = Filter{log: f.log, closed: true}
92 return err
93}
94
95// Close first saves the filter if it has modifications, then closes the database
96// connection and releases the bloom filter.
97func (f *Filter) Close() error {
98 if f.closed {
99 return errClosed
100 }
101 var err error
102 if f.modified {
103 err = f.Save()
104 }
105 if err != nil {
106 f.db.Close()
107 } else {
108 err = f.db.Close()
109 }
110 *f = Filter{log: f.log, closed: true}
111 return err
112}
113
114func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
115 var bloom *Bloom
116 if loadBloom {
117 var err error
118 bloom, err = openBloom(bloomPath)
119 if err != nil {
120 return nil, err
121 }
122 } else if fi, err := os.Stat(bloomPath); err == nil {
123 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
124 return nil, fmt.Errorf("bloom: %s", err)
125 }
126 }
127
128 db, err := openDB(ctx, dbPath)
129 if err != nil {
130 return nil, fmt.Errorf("open database: %s", err)
131 }
132
133 f := &Filter{
134 Params: params,
135 log: log,
136 cache: map[string]word{},
137 changed: map[string]word{},
138 dbPath: dbPath,
139 bloomPath: bloomPath,
140 db: db,
141 bloom: bloom,
142 }
143 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
144 wc := wordscore{Word: "-"}
145 err := tx.Get(&wc)
146 f.hams = wc.Ham
147 f.spams = wc.Spam
148 return err
149 })
150 if err != nil {
151 cerr := f.Close()
152 log.Check(cerr, "closing filter after error")
153 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
154 }
155 return f, nil
156}
157
158// NewFilter creates a new filter with empty bloom filter and database files. The
159// filter is marked as new until the first save, will be done automatically if
160// TrainDirs is called. If the bloom and/or database files exist, an error is
161// returned.
162func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
163 var err error
164 if _, err := os.Stat(bloomPath); err == nil {
165 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
166 } else if _, err := os.Stat(dbPath); err == nil {
167 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
168 }
169
170 bloomSizeBytes := 4 * 1024 * 1024
171 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
172 return nil, fmt.Errorf("bloom: %s", err)
173 }
174 bf, err := os.Create(bloomPath)
175 if err != nil {
176 return nil, fmt.Errorf("creating bloom file: %w", err)
177 }
178 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
179 xerr := bf.Close()
180 log.Check(xerr, "closing bloom filter file after truncate error")
181 xerr = os.Remove(bloomPath)
182 log.Check(xerr, "removing bloom filter file after truncate error")
183 return nil, fmt.Errorf("making empty bloom filter: %s", err)
184 }
185 err = bf.Close()
186 log.Check(err, "closing bloomfilter file")
187
188 db, err := newDB(ctx, log, dbPath)
189 if err != nil {
190 xerr := os.Remove(bloomPath)
191 log.Check(xerr, "removing bloom filter file after db init error")
192 xerr = os.Remove(dbPath)
193 log.Check(xerr, "removing database file after db init error")
194 return nil, fmt.Errorf("open database: %s", err)
195 }
196
197 words := map[string]word{} // f.changed is set to new map after training
198 f := &Filter{
199 Params: params,
200 log: log,
201 modified: true, // Ensure ham/spam message count is added for new filter.
202 cache: words,
203 changed: words,
204 dbPath: dbPath,
205 bloomPath: bloomPath,
206 db: db,
207 isNew: true,
208 }
209 return f, nil
210}
211
// bloomK is the k parameter passed to NewBloom and BloomValid for every bloom
// filter of this package (presumably the number of hash functions per entry;
// confirm against the Bloom implementation).
const bloomK = 10
213
214func openBloom(path string) (*Bloom, error) {
215 buf, err := os.ReadFile(path)
216 if err != nil {
217 return nil, fmt.Errorf("reading bloom file: %w", err)
218 }
219 return NewBloom(buf, bloomK)
220}
221
222func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
223 // Remove any existing files.
224 os.Remove(path)
225
226 defer func() {
227 if rerr != nil {
228 err := os.Remove(path)
229 log.Check(err, "removing db file after init error")
230 }
231 }()
232
233 db, err := bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
234 if err != nil {
235 return nil, fmt.Errorf("open new database: %w", err)
236 }
237 return db, nil
238}
239
240func openDB(ctx context.Context, path string) (*bstore.DB, error) {
241 if _, err := os.Stat(path); err != nil {
242 return nil, fmt.Errorf("stat db file: %w", err)
243 }
244 return bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
245}
246
247// Save stores modifications, e.g. from training, to the database and bloom
248// filter files.
249func (f *Filter) Save() error {
250 if f.closed {
251 return errClosed
252 }
253 if !f.modified {
254 return nil
255 }
256
257 if f.bloom != nil && f.bloom.Modified() {
258 if err := f.bloom.Write(f.bloomPath); err != nil {
259 return fmt.Errorf("writing bloom filter: %w", err)
260 }
261 }
262
263 // We need to insert sequentially for reasonable performance.
264 words := make([]string, len(f.changed))
265 i := 0
266 for w := range f.changed {
267 words[i] = w
268 i++
269 }
270 sort.Slice(words, func(i, j int) bool {
271 return words[i] < words[j]
272 })
273
274 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
275 // start := time.Now()
276 if f.isNew {
277 if err := f.db.HintAppend(true, wordscore{}); err != nil {
278 f.log.Errorx("hint appendonly", err)
279 } else {
280 defer func() {
281 err := f.db.HintAppend(false, wordscore{})
282 f.log.Check(err, "restoring append hint")
283 }()
284 }
285 }
286 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
287 update := func(w string, ham, spam uint32) error {
288 if f.isNew {
289 return tx.Insert(&wordscore{w, ham, spam})
290 }
291
292 wc := wordscore{w, 0, 0}
293 err := tx.Get(&wc)
294 if err == bstore.ErrAbsent {
295 return tx.Insert(&wordscore{w, ham, spam})
296 } else if err != nil {
297 return err
298 }
299 return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
300 }
301 if err := update("-", f.hams, f.spams); err != nil {
302 return fmt.Errorf("storing total ham/spam message count: %s", err)
303 }
304
305 for _, w := range words {
306 c := f.changed[w]
307 if err := update(w, c.Ham, c.Spam); err != nil {
308 return fmt.Errorf("updating ham/spam count: %s", err)
309 }
310 }
311 return nil
312 })
313 if err != nil {
314 return fmt.Errorf("updating database: %w", err)
315 }
316
317 f.changed = map[string]word{}
318 f.modified = false
319 f.isNew = false
320 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
321 return nil
322}
323
324func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
325 sort.Slice(l, func(i, j int) bool {
326 return l[i] < l[j]
327 })
328
329 err := db.Read(ctx, func(tx *bstore.Tx) error {
330 for _, w := range l {
331 wc := wordscore{Word: w}
332 if err := tx.Get(&wc); err == nil {
333 dst[w] = word{wc.Ham, wc.Spam}
334 }
335 }
336 return nil
337 })
338 if err != nil {
339 return fmt.Errorf("fetching words: %s", err)
340 }
341 return nil
342}
343
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
//
// Lookup: if the bloom filter is loaded, words it does not contain are treated
// as unknown and are not looked up. Other words missing from the cache are
// loaded from the database into the cache.
//
// Scoring: each cached word gets a spaminess r computed from its spam and ham
// frequencies relative to the total spam/ham message counts. r is then:
//   - clamped to [MaxPower, 1-MaxPower];
//   - pulled towards 0.5 if the word is rare (at most RareWords occurrences);
//   - ignored if it is closer than IgnoreWords to 0.5.
//
// Combining: at most TopWords hammy and TopWords spammy words are combined
// into the returned probability, which lies between 0 and 1. nham and nspam
// are the number of hammy and spammy words actually used.
func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
	if f.closed {
		return 0, 0, 0, errClosed
	}

	type xword struct {
		Word string
		R    float64
	}

	// hamHigh is the weakest (highest) r among topHam, spamLow the weakest
	// (lowest) r among topSpam. They are used to skip weaker candidates once a
	// list holds TopWords entries.
	var hamHigh float64 = 0
	var spamLow float64 = 1
	var topHam []xword
	var topSpam []xword

	// Find words that should be in the database.
	lookupWords := []string{}
	expect := map[string]struct{}{}
	unknowns := map[string]struct{}{} // Sample of at most 50 words, for the debug log only.
	totalUnknown := 0
	for w := range words {
		if f.bloom != nil && !f.bloom.Has(w) {
			totalUnknown++
			if len(unknowns) < 50 {
				unknowns[w] = struct{}{}
			}
			continue
		}
		if _, ok := f.cache[w]; ok {
			continue
		}
		lookupWords = append(lookupWords, w)
		expect[w] = struct{}{}
	}
	if len(unknowns) > 0 {
		f.log.Debug("unknown words in bloom filter, showing max 50",
			slog.Any("words", unknowns),
			slog.Any("totalunknown", totalUnknown),
			slog.Any("totalwords", len(words)))
	}

	// Fetch words from database.
	fetched := map[string]word{}
	if len(lookupWords) > 0 {
		if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
			return 0, 0, 0, err
		}
		// Whatever remains in expect was expected in the database but not found.
		for w, c := range fetched {
			delete(expect, w)
			f.cache[w] = c
		}
		f.log.Debug("unknown words in db",
			slog.Any("words", expect),
			slog.Any("totalunknown", len(expect)),
			slog.Any("totalwords", len(words)))
	}

	for w := range words {
		c, ok := f.cache[w]
		if !ok {
			continue
		}
		// Per-message frequencies, so an imbalance between the number of ham and
		// spam messages trained does not skew r.
		var wS, wH float64
		if f.spams > 0 {
			wS = float64(c.Spam) / float64(f.spams)
		}
		if f.hams > 0 {
			wH = float64(c.Ham) / float64(f.hams)
		}
		// If wS+wH is 0, r is NaN. All comparisons below are then false, so the
		// word ends up in neither topHam nor topSpam.
		r := wS / (wS + wH)

		if r < f.MaxPower {
			r = f.MaxPower
		} else if r >= 1-f.MaxPower {
			r = 1 - f.MaxPower
		}

		if c.Ham+c.Spam <= uint32(f.RareWords) {
			// Reduce the power of rare words.
			r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
		}
		if math.Abs(0.5-r) < f.IgnoreWords {
			continue
		}
		// The lists may grow beyond TopWords (candidates at least as strong as the
		// current weakest are still appended); they are sorted and truncated below.
		if r < 0.5 {
			if len(topHam) >= f.TopWords && r > hamHigh {
				continue
			}
			topHam = append(topHam, xword{w, r})
			if r > hamHigh {
				hamHigh = r
			}
		} else if r > 0.5 {
			if len(topSpam) >= f.TopWords && r < spamLow {
				continue
			}
			topSpam = append(topSpam, xword{w, r})
			if r < spamLow {
				spamLow = r
			}
		}
	}

	// Strongest words first; on equal r, prefer longer words.
	sort.Slice(topHam, func(i, j int) bool {
		a, b := topHam[i], topHam[j]
		if a.R == b.R {
			return len(a.Word) > len(b.Word)
		}
		return a.R < b.R
	})
	sort.Slice(topSpam, func(i, j int) bool {
		a, b := topSpam[i], topSpam[j]
		if a.R == b.R {
			return len(a.Word) > len(b.Word)
		}
		return a.R > b.R
	})

	nham = f.TopWords
	if nham > len(topHam) {
		nham = len(topHam)
	}
	nspam = f.TopWords
	if nspam > len(topSpam) {
		nspam = len(topSpam)
	}
	topHam = topHam[:nham]
	topSpam = topSpam[:nspam]

	// Combine in log space: 1/(1+e^eta) with eta = sum(ln(1-r) - ln(r)) equals
	// prod(r) / (prod(r) + prod(1-r)), without underflow of the products.
	var eta float64
	for _, x := range topHam {
		eta += math.Log(1-x.R) - math.Log(x.R)
	}
	for _, x := range topSpam {
		eta += math.Log(1-x.R) - math.Log(x.R)
	}

	f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))

	prob := 1 / (1 + math.Pow(math.E, eta))
	return prob, len(topHam), len(topSpam), nil
}
487
488// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
489func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
490 if f.closed {
491 return 0, nil, 0, 0, errClosed
492 }
493
494 mf, err := os.Open(path)
495 if err != nil {
496 return 0, nil, 0, 0, err
497 }
498 defer func() {
499 err := mf.Close()
500 f.log.Check(err, "closing file after classify")
501 }()
502 fi, err := mf.Stat()
503 if err != nil {
504 return 0, nil, 0, 0, err
505 }
506 return f.ClassifyMessageReader(ctx, mf, fi.Size())
507}
508
509func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
510 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
511 if err != nil && errors.Is(err, message.ErrBadContentType) {
512 // Invalid content-type header is a sure sign of spam.
513 //f.log.Infox("parsing content", err)
514 return 1, nil, 0, 0, nil
515 }
516 return f.ClassifyMessage(ctx, m)
517}
518
519// ClassifyMessage parses the mail message in r and returns the spam probability
520// (between 0 and 1), along with the tokenized words found in the message, and the
521// number of recognized ham and spam words.
522func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
523 var err error
524 words, err = f.ParseMessage(m)
525 if err != nil {
526 return 0, nil, 0, 0, err
527 }
528
529 probability, nham, nspam, err = f.ClassifyWords(ctx, words)
530 return probability, words, nham, nspam, err
531}
532
533// Train adds the words of a single message to the filter.
534func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
535 if err := f.ensureBloom(); err != nil {
536 return err
537 }
538
539 var lwords []string
540
541 for w := range words {
542 if !f.bloom.Has(w) {
543 f.bloom.Add(w)
544 continue
545 }
546 if _, ok := f.cache[w]; !ok {
547 lwords = append(lwords, w)
548 }
549 }
550
551 if err := f.loadCache(ctx, lwords); err != nil {
552 return err
553 }
554
555 f.modified = true
556 if ham {
557 f.hams++
558 } else {
559 f.spams++
560 }
561
562 for w := range words {
563 c := f.cache[w]
564 if ham {
565 c.Ham++
566 } else {
567 c.Spam++
568 }
569 f.cache[w] = c
570 f.changed[w] = c
571 }
572 return nil
573}
574
575func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
576 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
577 words, err := f.ParseMessage(p)
578 if err != nil {
579 return fmt.Errorf("parsing mail contents: %v", err)
580 }
581 return f.Train(ctx, ham, words)
582}
583
584func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
585 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
586 words, err := f.ParseMessage(p)
587 if err != nil {
588 return fmt.Errorf("parsing mail contents: %v", err)
589 }
590 return f.Untrain(ctx, ham, words)
591}
592
593func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
594 if len(lwords) == 0 {
595 return nil
596 }
597
598 return loadWords(ctx, f.db, lwords, f.cache)
599}
600
601// Untrain adjusts the filter to undo a previous training of the words.
602func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
603 if err := f.ensureBloom(); err != nil {
604 return err
605 }
606
607 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
608 var lwords []string
609 for w := range words {
610 if _, ok := f.cache[w]; !ok {
611 lwords = append(lwords, w)
612 }
613 }
614 if err := f.loadCache(ctx, lwords); err != nil {
615 return err
616 }
617
618 // Modify the message count.
619 f.modified = true
620 if ham {
621 f.hams--
622 } else {
623 f.spams--
624 }
625
626 // Decrease the word counts.
627 for w := range words {
628 c, ok := f.cache[w]
629 if !ok {
630 continue
631 }
632 if ham {
633 c.Ham--
634 } else {
635 c.Spam--
636 }
637 f.cache[w] = c
638 f.changed[w] = c
639 }
640 return nil
641}
642
643// TrainDir parses mail messages from files and trains the filter.
644func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
645 if f.closed {
646 return 0, 0, errClosed
647 }
648 if err := f.ensureBloom(); err != nil {
649 return 0, 0, err
650 }
651
652 for _, name := range files {
653 p := filepath.Join(dir, name)
654 valid, words, err := f.tokenizeMail(p)
655 if err != nil {
656 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
657 malformed++
658 continue
659 }
660 if !valid {
661 continue
662 }
663 n++
664 for w := range words {
665 if !f.bloom.Has(w) {
666 f.bloom.Add(w)
667 continue
668 }
669 c := f.cache[w]
670 f.modified = true
671 if ham {
672 c.Ham++
673 } else {
674 c.Spam++
675 }
676 f.cache[w] = c
677 f.changed[w] = c
678 }
679 }
680 return
681}
682
683// TrainDirs trains and saves a filter with mail messages from different types
684// of directories.
685func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
686 if f.closed {
687 return errClosed
688 }
689
690 var err error
691
692 var start time.Time
693 var hamMalformed, sentMalformed, spamMalformed uint32
694
695 start = time.Now()
696 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
697 if err != nil {
698 return err
699 }
700 tham := time.Since(start)
701
702 var sent uint32
703 start = time.Now()
704 if sentDir != "" {
705 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
706 if err != nil {
707 return err
708 }
709 }
710 tsent := time.Since(start)
711
712 start = time.Now()
713 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
714 if err != nil {
715 return err
716 }
717 tspam := time.Since(start)
718
719 hams := f.hams
720 f.hams += sent
721 if err := f.Save(); err != nil {
722 return fmt.Errorf("saving filter: %s", err)
723 }
724
725 dbSize := f.fileSize(f.dbPath)
726 bloomSize := f.fileSize(f.bloomPath)
727
728 f.log.Print("training done",
729 slog.Any("hams", hams),
730 slog.Any("hamtime", tham),
731 slog.Any("hammalformed", hamMalformed),
732 slog.Any("sent", sent),
733 slog.Any("senttime", tsent),
734 slog.Any("sentmalformed", sentMalformed),
735 slog.Any("spams", f.spams),
736 slog.Any("spamtime", tspam),
737 slog.Any("spammalformed", spamMalformed),
738 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
739 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
740 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
741 )
742
743 return nil
744}
745
746func (f *Filter) fileSize(p string) int {
747 fi, err := os.Stat(p)
748 if err != nil {
749 f.log.Infox("stat", err, slog.Any("path", p))
750 return 0
751 }
752 return int(fi.Size())
753}
754
755// DB returns the database, for backups.
756func (f *Filter) DB() *bstore.DB {
757 return f.db
758}
759