1// Package junk implements a bayesian spam filter.
2//
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
8package junk
9
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
// todo: perhaps: check whether the anchor text of links in html differs from the url
12
13import (
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "log/slog"
19 "math"
20 "os"
21 "path/filepath"
22 "sort"
23 "time"
24
25 "github.com/mjl-/bstore"
26
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
29)
30
// errClosed is returned by Filter methods called after the filter was closed
// with Close or CloseDiscard.
var errClosed = errors.New("filter is closed")

// todo: add errBadContentType = errors.New("bad content-type") and use it; a bad
// content-type is a sure sign of spam.
35
// word holds the ham and spam counts of a single word (or word combination) in
// memory. Field order matters: values are constructed positionally, e.g.
// word{Ham, Spam}.
type word struct {
	Ham  uint32
	Spam uint32
}

// wordscore is the database record for a word (or word combination) with its
// ham/spam counts. Records are looked up by Word. The special word "-" stores
// the total number of trained ham and spam messages. Field order matters:
// records are constructed positionally, e.g. wordscore{w, ham, spam}.
type wordscore struct {
	Word string
	Ham  uint32
	Spam uint32
}
46
// Params holds parameters for the filter. The n-gram settings (Onegrams,
// Twograms, Threegrams) are used while parsing messages into words, so they
// affect both training and classification. The remaining fields (MaxPower,
// TopWords, IgnoreWords, RareWords) are only used when classifying.
type Params struct {
	Onegrams    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
	Twograms    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
	Threegrams  bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
	MaxPower    float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
	TopWords    int     `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
	IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
	RareWords   int     `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
}
58
// DBTypes lists the types stored in the junk filter database. It is passed to
// bstore.Open when creating or opening the database.
var DBTypes = []any{wordscore{}} // Stored in DB.
60
// Filter is a bayesian spam filter. Word ham/spam counts are kept in a bstore
// database, and a bloom filter file records which words have been seen. Create
// one with NewFilter or OpenFilter. Release it with Close, which saves pending
// changes, or CloseDiscard, which drops them.
type Filter struct {
	Params

	log               mlog.Log        // For logging cid.
	closed            bool            // Set by Close/CloseDiscard; most methods then return errClosed.
	modified          bool            // Whether any modifications are pending. Cleared by Save.
	hams, spams       uint32          // Message count, stored in db under word "-".
	cache             map[string]word // Words read from database or during training.
	changed           map[string]word // Words modified during training, with their full (not incremental) counts.
	dbPath, bloomPath string
	db                *bstore.DB // Always open on a filter.
	bloom             *Bloom     // Only opened when writing.
	isNew             bool       // Set for new filters until their first sync to disk. For faster writing.
}
75
76func (f *Filter) ensureBloom() error {
77 if f.bloom != nil {
78 return nil
79 }
80 var err error
81 f.bloom, err = openBloom(f.bloomPath)
82 return err
83}
84
85// CloseDiscard closes the filter, discarding any changes.
86func (f *Filter) CloseDiscard() error {
87 if f.closed {
88 return errClosed
89 }
90 err := f.db.Close()
91 *f = Filter{log: f.log, closed: true}
92 return err
93}
94
95// Close first saves the filter if it has modifications, then closes the database
96// connection and releases the bloom filter.
97func (f *Filter) Close() error {
98 if f.closed {
99 return errClosed
100 }
101 var err error
102 if f.modified {
103 err = f.Save()
104 }
105 if err != nil {
106 f.db.Close()
107 } else {
108 err = f.db.Close()
109 }
110 *f = Filter{log: f.log, closed: true}
111 return err
112}
113
114func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
115 var bloom *Bloom
116 if loadBloom {
117 var err error
118 bloom, err = openBloom(bloomPath)
119 if err != nil {
120 return nil, err
121 }
122 } else if fi, err := os.Stat(bloomPath); err == nil {
123 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
124 return nil, fmt.Errorf("bloom: %s", err)
125 }
126 }
127
128 db, err := openDB(ctx, log, dbPath)
129 if err != nil {
130 return nil, fmt.Errorf("open database: %s", err)
131 }
132
133 f := &Filter{
134 Params: params,
135 log: log,
136 cache: map[string]word{},
137 changed: map[string]word{},
138 dbPath: dbPath,
139 bloomPath: bloomPath,
140 db: db,
141 bloom: bloom,
142 }
143 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
144 wc := wordscore{Word: "-"}
145 err := tx.Get(&wc)
146 f.hams = wc.Ham
147 f.spams = wc.Spam
148 return err
149 })
150 if err != nil {
151 cerr := f.Close()
152 log.Check(cerr, "closing filter after error")
153 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
154 }
155 return f, nil
156}
157
158// NewFilter creates a new filter with empty bloom filter and database files. The
159// filter is marked as new until the first save, will be done automatically if
160// TrainDirs is called. If the bloom and/or database files exist, an error is
161// returned.
162func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
163 var err error
164 if _, err := os.Stat(bloomPath); err == nil {
165 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
166 } else if _, err := os.Stat(dbPath); err == nil {
167 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
168 }
169
170 bloomSizeBytes := 4 * 1024 * 1024
171 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
172 return nil, fmt.Errorf("bloom: %s", err)
173 }
174 bf, err := os.Create(bloomPath)
175 if err != nil {
176 return nil, fmt.Errorf("creating bloom file: %w", err)
177 }
178 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
179 xerr := bf.Close()
180 log.Check(xerr, "closing bloom filter file after truncate error")
181 xerr = os.Remove(bloomPath)
182 log.Check(xerr, "removing bloom filter file after truncate error")
183 return nil, fmt.Errorf("making empty bloom filter: %s", err)
184 }
185 err = bf.Close()
186 log.Check(err, "closing bloomfilter file")
187
188 db, err := newDB(ctx, log, dbPath)
189 if err != nil {
190 xerr := os.Remove(bloomPath)
191 log.Check(xerr, "removing bloom filter file after db init error")
192 xerr = os.Remove(dbPath)
193 log.Check(xerr, "removing database file after db init error")
194 return nil, fmt.Errorf("open database: %s", err)
195 }
196
197 words := map[string]word{} // f.changed is set to new map after training
198 f := &Filter{
199 Params: params,
200 log: log,
201 modified: true, // Ensure ham/spam message count is added for new filter.
202 cache: words,
203 changed: words,
204 dbPath: dbPath,
205 bloomPath: bloomPath,
206 db: db,
207 isNew: true,
208 }
209 return f, nil
210}
211
// bloomK is the k parameter passed to NewBloom and BloomValid for every bloom
// filter in this package. Presumably it is the number of hash functions; see
// the Bloom implementation to confirm.
const bloomK = 10
213
214func openBloom(path string) (*Bloom, error) {
215 buf, err := os.ReadFile(path)
216 if err != nil {
217 return nil, fmt.Errorf("reading bloom file: %w", err)
218 }
219 return NewBloom(buf, bloomK)
220}
221
222func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
223 // Remove any existing files.
224 os.Remove(path)
225
226 defer func() {
227 if rerr != nil {
228 err := os.Remove(path)
229 log.Check(err, "removing db file after init error")
230 }
231 }()
232
233 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
234 db, err := bstore.Open(ctx, path, &opts, DBTypes...)
235 if err != nil {
236 return nil, fmt.Errorf("open new database: %w", err)
237 }
238 return db, nil
239}
240
241func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
242 if _, err := os.Stat(path); err != nil {
243 return nil, fmt.Errorf("stat db file: %w", err)
244 }
245 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
246 return bstore.Open(ctx, path, &opts, DBTypes...)
247}
248
249// Save stores modifications, e.g. from training, to the database and bloom
250// filter files.
251func (f *Filter) Save() error {
252 if f.closed {
253 return errClosed
254 }
255 if !f.modified {
256 return nil
257 }
258
259 if f.bloom != nil && f.bloom.Modified() {
260 if err := f.bloom.Write(f.bloomPath); err != nil {
261 return fmt.Errorf("writing bloom filter: %w", err)
262 }
263 }
264
265 // We need to insert sequentially for reasonable performance.
266 words := make([]string, len(f.changed))
267 i := 0
268 for w := range f.changed {
269 words[i] = w
270 i++
271 }
272 sort.Slice(words, func(i, j int) bool {
273 return words[i] < words[j]
274 })
275
276 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
277 // start := time.Now()
278 if f.isNew {
279 if err := f.db.HintAppend(true, wordscore{}); err != nil {
280 f.log.Errorx("hint appendonly", err)
281 } else {
282 defer func() {
283 err := f.db.HintAppend(false, wordscore{})
284 f.log.Check(err, "restoring append hint")
285 }()
286 }
287 }
288 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
289 update := func(w string, ham, spam uint32) error {
290 if f.isNew {
291 return tx.Insert(&wordscore{w, ham, spam})
292 }
293
294 wc := wordscore{w, 0, 0}
295 err := tx.Get(&wc)
296 if err == bstore.ErrAbsent {
297 return tx.Insert(&wordscore{w, ham, spam})
298 } else if err != nil {
299 return err
300 }
301 return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
302 }
303 if err := update("-", f.hams, f.spams); err != nil {
304 return fmt.Errorf("storing total ham/spam message count: %s", err)
305 }
306
307 for _, w := range words {
308 c := f.changed[w]
309 if err := update(w, c.Ham, c.Spam); err != nil {
310 return fmt.Errorf("updating ham/spam count: %s", err)
311 }
312 }
313 return nil
314 })
315 if err != nil {
316 return fmt.Errorf("updating database: %w", err)
317 }
318
319 f.changed = map[string]word{}
320 f.modified = false
321 f.isNew = false
322 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
323 return nil
324}
325
326func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
327 sort.Slice(l, func(i, j int) bool {
328 return l[i] < l[j]
329 })
330
331 err := db.Read(ctx, func(tx *bstore.Tx) error {
332 for _, w := range l {
333 wc := wordscore{Word: w}
334 if err := tx.Get(&wc); err == nil {
335 dst[w] = word{wc.Ham, wc.Spam}
336 }
337 }
338 return nil
339 })
340 if err != nil {
341 return fmt.Errorf("fetching words: %s", err)
342 }
343 return nil
344}
345
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
//
// Scoring works as follows:
//   - Each word with known counts gets a spam ratio r, based on the fraction of
//     trained spam messages vs. trained ham messages that contained the word.
//   - r is clamped to [MaxPower, 1-MaxPower].
//   - r is pulled toward 0.5 for rare words (see RareWords).
//   - The word is ignored if r is within IgnoreWords of 0.5.
//
// Up to TopWords most hammy and up to TopWords most spammy words are then
// combined into the final probability. The returned nham and nspam are the
// number of ham and spam words used in that combination.
//
// If the bloom filter is loaded, words not in it are treated as unknown without
// a database lookup. Words fetched from the database are added to f.cache.
func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
	if f.closed {
		return 0, 0, 0, errClosed
	}

	// xword is a candidate for the top words, with its spam ratio.
	type xword struct {
		Word string
		R    float64
	}

	// Weakest candidate so far of each kind: highest r among ham candidates,
	// lowest r among spam candidates. Used to prune candidates early.
	var hamHigh float64 = 0
	var spamLow float64 = 1
	var topHam []xword
	var topSpam []xword

	// Find words that should be in the database.
	lookupWords := []string{}
	expect := map[string]struct{}{}
	unknowns := map[string]struct{}{} // Only for debug logging, capped at 50 entries.
	totalUnknown := 0
	for w := range words {
		// Training adds every word to the bloom filter, so a word missing from
		// it was never trained and needs no database lookup.
		if f.bloom != nil && !f.bloom.Has(w) {
			totalUnknown++
			if len(unknowns) < 50 {
				unknowns[w] = struct{}{}
			}
			continue
		}
		if _, ok := f.cache[w]; ok {
			continue
		}
		lookupWords = append(lookupWords, w)
		expect[w] = struct{}{}
	}
	if len(unknowns) > 0 {
		f.log.Debug("unknown words in bloom filter, showing max 50",
			slog.Any("words", unknowns),
			slog.Any("totalunknown", totalUnknown),
			slog.Any("totalwords", len(words)))
	}

	// Fetch words from database.
	fetched := map[string]word{}
	if len(lookupWords) > 0 {
		if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
			return 0, 0, 0, err
		}
		// Afterwards, expect only holds words that were not in the database.
		for w, c := range fetched {
			delete(expect, w)
			f.cache[w] = c
		}
		f.log.Debug("unknown words in db",
			slog.Any("words", expect),
			slog.Any("totalunknown", len(expect)),
			slog.Any("totalwords", len(words)))
	}

	for w := range words {
		c, ok := f.cache[w]
		if !ok {
			continue
		}
		// wS/wH: fraction of trained spam/ham messages containing the word.
		// If both are 0, r is NaN. NaN fails every comparison below, so such
		// a word ends up skipped.
		var wS, wH float64
		if f.spams > 0 {
			wS = float64(c.Spam) / float64(f.spams)
		}
		if f.hams > 0 {
			wH = float64(c.Ham) / float64(f.hams)
		}
		r := wS / (wS + wH)

		// Limit the influence of a single word.
		if r < f.MaxPower {
			r = f.MaxPower
		} else if r >= 1-f.MaxPower {
			r = 1 - f.MaxPower
		}

		if c.Ham+c.Spam <= uint32(f.RareWords) {
			// Reduce the power of rare words.
			r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
		}
		// Words close to neutral are ignored.
		if math.Abs(0.5-r) < f.IgnoreWords {
			continue
		}
		// Collect candidates. Once TopWords candidates of a kind exist, words
		// weaker than the weakest one so far are skipped. The final selection
		// happens after sorting below. An r of exactly 0.5 is skipped.
		if r < 0.5 {
			if len(topHam) >= f.TopWords && r > hamHigh {
				continue
			}
			topHam = append(topHam, xword{w, r})
			if r > hamHigh {
				hamHigh = r
			}
		} else if r > 0.5 {
			if len(topSpam) >= f.TopWords && r < spamLow {
				continue
			}
			topSpam = append(topSpam, xword{w, r})
			if r < spamLow {
				spamLow = r
			}
		}
	}

	// Strongest first: lowest r for ham, highest r for spam. Ties prefer
	// longer words, presumably because they are more specific.
	sort.Slice(topHam, func(i, j int) bool {
		a, b := topHam[i], topHam[j]
		if a.R == b.R {
			return len(a.Word) > len(b.Word)
		}
		return a.R < b.R
	})
	sort.Slice(topSpam, func(i, j int) bool {
		a, b := topSpam[i], topSpam[j]
		if a.R == b.R {
			return len(a.Word) > len(b.Word)
		}
		return a.R > b.R
	})

	// Keep at most TopWords of each kind.
	nham = f.TopWords
	if nham > len(topHam) {
		nham = len(topHam)
	}
	nspam = f.TopWords
	if nspam > len(topSpam) {
		nspam = len(topSpam)
	}
	topHam = topHam[:nham]
	topSpam = topSpam[:nspam]

	// Combine as prob = Π r / (Π r + Π (1-r)). This is computed in log space
	// as 1 / (1 + e^eta), with eta = Σ (ln(1-r) - ln(r)).
	var eta float64
	for _, x := range topHam {
		eta += math.Log(1-x.R) - math.Log(x.R)
	}
	for _, x := range topSpam {
		eta += math.Log(1-x.R) - math.Log(x.R)
	}

	f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))

	prob := 1 / (1 + math.Pow(math.E, eta))
	return prob, len(topHam), len(topSpam), nil
}
489
490// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
491func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
492 if f.closed {
493 return 0, nil, 0, 0, errClosed
494 }
495
496 mf, err := os.Open(path)
497 if err != nil {
498 return 0, nil, 0, 0, err
499 }
500 defer func() {
501 err := mf.Close()
502 f.log.Check(err, "closing file after classify")
503 }()
504 fi, err := mf.Stat()
505 if err != nil {
506 return 0, nil, 0, 0, err
507 }
508 return f.ClassifyMessageReader(ctx, mf, fi.Size())
509}
510
511func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
512 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
513 if err != nil && errors.Is(err, message.ErrBadContentType) {
514 // Invalid content-type header is a sure sign of spam.
515 //f.log.Infox("parsing content", err)
516 return 1, nil, 0, 0, nil
517 }
518 return f.ClassifyMessage(ctx, m)
519}
520
521// ClassifyMessage parses the mail message in r and returns the spam probability
522// (between 0 and 1), along with the tokenized words found in the message, and the
523// number of recognized ham and spam words.
524func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
525 var err error
526 words, err = f.ParseMessage(m)
527 if err != nil {
528 return 0, nil, 0, 0, err
529 }
530
531 probability, nham, nspam, err = f.ClassifyWords(ctx, words)
532 return probability, words, nham, nspam, err
533}
534
535// Train adds the words of a single message to the filter.
536func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
537 if err := f.ensureBloom(); err != nil {
538 return err
539 }
540
541 var lwords []string
542
543 for w := range words {
544 if !f.bloom.Has(w) {
545 f.bloom.Add(w)
546 continue
547 }
548 if _, ok := f.cache[w]; !ok {
549 lwords = append(lwords, w)
550 }
551 }
552
553 if err := f.loadCache(ctx, lwords); err != nil {
554 return err
555 }
556
557 f.modified = true
558 if ham {
559 f.hams++
560 } else {
561 f.spams++
562 }
563
564 for w := range words {
565 c := f.cache[w]
566 if ham {
567 c.Ham++
568 } else {
569 c.Spam++
570 }
571 f.cache[w] = c
572 f.changed[w] = c
573 }
574 return nil
575}
576
577func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
578 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
579 words, err := f.ParseMessage(p)
580 if err != nil {
581 return fmt.Errorf("parsing mail contents: %v", err)
582 }
583 return f.Train(ctx, ham, words)
584}
585
586func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
587 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
588 words, err := f.ParseMessage(p)
589 if err != nil {
590 return fmt.Errorf("parsing mail contents: %v", err)
591 }
592 return f.Untrain(ctx, ham, words)
593}
594
595func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
596 if len(lwords) == 0 {
597 return nil
598 }
599
600 return loadWords(ctx, f.db, lwords, f.cache)
601}
602
603// Untrain adjusts the filter to undo a previous training of the words.
604func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
605 if err := f.ensureBloom(); err != nil {
606 return err
607 }
608
609 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
610 var lwords []string
611 for w := range words {
612 if _, ok := f.cache[w]; !ok {
613 lwords = append(lwords, w)
614 }
615 }
616 if err := f.loadCache(ctx, lwords); err != nil {
617 return err
618 }
619
620 // Modify the message count.
621 f.modified = true
622 if ham {
623 f.hams--
624 } else {
625 f.spams--
626 }
627
628 // Decrease the word counts.
629 for w := range words {
630 c, ok := f.cache[w]
631 if !ok {
632 continue
633 }
634 if ham {
635 c.Ham--
636 } else {
637 c.Spam--
638 }
639 f.cache[w] = c
640 f.changed[w] = c
641 }
642 return nil
643}
644
645// TrainDir parses mail messages from files and trains the filter.
646func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
647 if f.closed {
648 return 0, 0, errClosed
649 }
650 if err := f.ensureBloom(); err != nil {
651 return 0, 0, err
652 }
653
654 for _, name := range files {
655 p := filepath.Join(dir, name)
656 valid, words, err := f.tokenizeMail(p)
657 if err != nil {
658 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
659 malformed++
660 continue
661 }
662 if !valid {
663 continue
664 }
665 n++
666 for w := range words {
667 if !f.bloom.Has(w) {
668 f.bloom.Add(w)
669 continue
670 }
671 c := f.cache[w]
672 f.modified = true
673 if ham {
674 c.Ham++
675 } else {
676 c.Spam++
677 }
678 f.cache[w] = c
679 f.changed[w] = c
680 }
681 }
682 return
683}
684
685// TrainDirs trains and saves a filter with mail messages from different types
686// of directories.
687func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
688 if f.closed {
689 return errClosed
690 }
691
692 var err error
693
694 var start time.Time
695 var hamMalformed, sentMalformed, spamMalformed uint32
696
697 start = time.Now()
698 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
699 if err != nil {
700 return err
701 }
702 tham := time.Since(start)
703
704 var sent uint32
705 start = time.Now()
706 if sentDir != "" {
707 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
708 if err != nil {
709 return err
710 }
711 }
712 tsent := time.Since(start)
713
714 start = time.Now()
715 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
716 if err != nil {
717 return err
718 }
719 tspam := time.Since(start)
720
721 hams := f.hams
722 f.hams += sent
723 if err := f.Save(); err != nil {
724 return fmt.Errorf("saving filter: %s", err)
725 }
726
727 dbSize := f.fileSize(f.dbPath)
728 bloomSize := f.fileSize(f.bloomPath)
729
730 f.log.Print("training done",
731 slog.Any("hams", hams),
732 slog.Any("hamtime", tham),
733 slog.Any("hammalformed", hamMalformed),
734 slog.Any("sent", sent),
735 slog.Any("senttime", tsent),
736 slog.Any("sentmalformed", sentMalformed),
737 slog.Any("spams", f.spams),
738 slog.Any("spamtime", tspam),
739 slog.Any("spammalformed", spamMalformed),
740 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
741 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
742 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
743 )
744
745 return nil
746}
747
748func (f *Filter) fileSize(p string) int {
749 fi, err := os.Stat(p)
750 if err != nil {
751 f.log.Infox("stat", err, slog.Any("path", p))
752 return 0
753 }
754 return int(fi.Size())
755}
756
// DB returns the database, for backups. The database is owned by the filter:
// it is closed by Close/CloseDiscard, and after that f.db is reset to nil.
func (f *Filter) DB() *bstore.DB {
	return f.db
}
761