// Package junk implements a Bayesian spam filter.
2//
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
8package junk
9
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
// todo: perhaps: check whether the anchor text of links in html differs from the url
12
13import (
14 "context"
15 "errors"
16 "fmt"
17 "io"
18 "log/slog"
19 "math"
20 "os"
21 "path/filepath"
22 "sort"
23 "time"
24
25 "github.com/mjl-/bstore"
26
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
29)
30
// todo: introduce and use errBadContentType = errors.New("bad content-type"), a sure sign of spam.

// errClosed is returned by Filter methods that check for, and find, a filter
// that was already closed.
var errClosed = errors.New("filter is closed")
35
// word holds the ham and spam counters for a single word (or word
// combination), as kept in the in-memory cache of a Filter.
//
// The field order matters: values are created with positional literals.
type word struct {
	Ham, Spam uint32
}
40
// wordscore is the record stored in the database for a word: the word itself
// with its ham and spam counters. The record for word "-" holds the total
// ham/spam message counts.
//
// The field order matters: values are created with positional literals.
type wordscore struct {
	Word      string
	Ham, Spam uint32
}
46
// Params holds the parameters for the filter. Most apply at test-time, i.e.
// when classifying. The first (the n-gram settings) are used during parsing
// and training.
type Params struct {
	// Which word combinations to track.
	Onegrams   bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
	Twograms   bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
	Threegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`

	// Bound on the score of a single word, keeping it in [MaxPower, 1-MaxPower].
	MaxPower float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`

	// Number of top ham words and of top spam words used for the final probability.
	TopWords int `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`

	// Words with a score closer than this to 0.5 are not used.
	IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`

	// Words seen at most this often have their score pulled towards 0.5.
	RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
}
58
59var DBTypes = []any{wordscore{}} // Stored in DB.
60
61type Filter struct {
62 Params
63
64 log mlog.Log // For logging cid.
65 closed bool
66 modified bool // Whether any modifications are pending. Cleared by Save.
67 hams, spams uint32 // Message count, stored in db under word "-".
68 cache map[string]word // Words read from database or during training.
69 changed map[string]word // Words modified during training.
70 dbPath, bloomPath string
71 db *bstore.DB // Always open on a filter.
72 bloom *Bloom // Only opened when writing.
73 isNew bool // Set for new filters until their first sync to disk. For faster writing.
74}
75
76func (f *Filter) ensureBloom() error {
77 if f.bloom != nil {
78 return nil
79 }
80 var err error
81 f.bloom, err = openBloom(f.bloomPath)
82 return err
83}
84
85// CloseDiscard closes the filter, discarding any changes.
86func (f *Filter) CloseDiscard() error {
87 if f.closed {
88 return errClosed
89 }
90 err := f.db.Close()
91 *f = Filter{log: f.log, closed: true}
92 return err
93}
94
95// Close first saves the filter if it has modifications, then closes the database
96// connection and releases the bloom filter.
97func (f *Filter) Close() error {
98 if f.closed {
99 return errClosed
100 }
101 var err error
102 if f.modified {
103 err = f.Save()
104 }
105 if err != nil {
106 f.db.Close()
107 } else {
108 err = f.db.Close()
109 }
110 *f = Filter{log: f.log, closed: true}
111 return err
112}
113
114func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
115 var bloom *Bloom
116 if loadBloom {
117 var err error
118 bloom, err = openBloom(bloomPath)
119 if err != nil {
120 return nil, err
121 }
122 } else if fi, err := os.Stat(bloomPath); err == nil {
123 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
124 return nil, fmt.Errorf("bloom: %s", err)
125 }
126 }
127
128 db, err := openDB(ctx, log, dbPath)
129 if err != nil {
130 return nil, fmt.Errorf("open database: %s", err)
131 }
132
133 f := &Filter{
134 Params: params,
135 log: log,
136 cache: map[string]word{},
137 changed: map[string]word{},
138 dbPath: dbPath,
139 bloomPath: bloomPath,
140 db: db,
141 bloom: bloom,
142 }
143 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
144 wc := wordscore{Word: "-"}
145 err := tx.Get(&wc)
146 f.hams = wc.Ham
147 f.spams = wc.Spam
148 return err
149 })
150 if err != nil {
151 cerr := f.Close()
152 log.Check(cerr, "closing filter after error")
153 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
154 }
155 return f, nil
156}
157
158// NewFilter creates a new filter with empty bloom filter and database files. The
159// filter is marked as new until the first save, will be done automatically if
160// TrainDirs is called. If the bloom and/or database files exist, an error is
161// returned.
162func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
163 var err error
164 if _, err := os.Stat(bloomPath); err == nil {
165 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
166 } else if _, err := os.Stat(dbPath); err == nil {
167 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
168 }
169
170 bloomSizeBytes := 4 * 1024 * 1024
171 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
172 return nil, fmt.Errorf("bloom: %s", err)
173 }
174 bf, err := os.Create(bloomPath)
175 if err != nil {
176 return nil, fmt.Errorf("creating bloom file: %w", err)
177 }
178 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
179 xerr := bf.Close()
180 log.Check(xerr, "closing bloom filter file after truncate error")
181 xerr = os.Remove(bloomPath)
182 log.Check(xerr, "removing bloom filter file after truncate error")
183 return nil, fmt.Errorf("making empty bloom filter: %s", err)
184 }
185 err = bf.Close()
186 log.Check(err, "closing bloomfilter file")
187
188 db, err := newDB(ctx, log, dbPath)
189 if err != nil {
190 xerr := os.Remove(bloomPath)
191 log.Check(xerr, "removing bloom filter file after db init error")
192 xerr = os.Remove(dbPath)
193 log.Check(xerr, "removing database file after db init error")
194 return nil, fmt.Errorf("open database: %s", err)
195 }
196
197 words := map[string]word{} // f.changed is set to new map after training
198 f := &Filter{
199 Params: params,
200 log: log,
201 modified: true, // Ensure ham/spam message count is added for new filter.
202 cache: words,
203 changed: words,
204 dbPath: dbPath,
205 bloomPath: bloomPath,
206 db: db,
207 isNew: true,
208 }
209 return f, nil
210}
211
// bloomK is the k parameter passed to BloomValid and NewBloom for the bloom
// filter of a junk filter. Presumably the number of hash functions per word;
// confirm against the Bloom implementation.
const bloomK = 10
213
214func openBloom(path string) (*Bloom, error) {
215 buf, err := os.ReadFile(path)
216 if err != nil {
217 return nil, fmt.Errorf("reading bloom file: %w", err)
218 }
219 return NewBloom(buf, bloomK)
220}
221
222func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
223 // Remove any existing files.
224 os.Remove(path)
225
226 defer func() {
227 if rerr != nil {
228 err := os.Remove(path)
229 log.Check(err, "removing db file after init error")
230 }
231 }()
232
233 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
234 db, err := bstore.Open(ctx, path, &opts, DBTypes...)
235 if err != nil {
236 return nil, fmt.Errorf("open new database: %w", err)
237 }
238 return db, nil
239}
240
241func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
242 if _, err := os.Stat(path); err != nil {
243 return nil, fmt.Errorf("stat db file: %w", err)
244 }
245 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
246 return bstore.Open(ctx, path, &opts, DBTypes...)
247}
248
249// Save stores modifications, e.g. from training, to the database and bloom
250// filter files.
251func (f *Filter) Save() error {
252 if f.closed {
253 return errClosed
254 }
255 if !f.modified {
256 return nil
257 }
258
259 if f.bloom != nil && f.bloom.Modified() {
260 if err := f.bloom.Write(f.bloomPath); err != nil {
261 return fmt.Errorf("writing bloom filter: %w", err)
262 }
263 }
264
265 // We need to insert sequentially for reasonable performance.
266 words := make([]string, len(f.changed))
267 i := 0
268 for w := range f.changed {
269 words[i] = w
270 i++
271 }
272 sort.Slice(words, func(i, j int) bool {
273 return words[i] < words[j]
274 })
275
276 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
277 // start := time.Now()
278 if f.isNew {
279 if err := f.db.HintAppend(true, wordscore{}); err != nil {
280 f.log.Errorx("hint appendonly", err)
281 } else {
282 defer func() {
283 err := f.db.HintAppend(false, wordscore{})
284 f.log.Check(err, "restoring append hint")
285 }()
286 }
287 }
288 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
289 update := func(w string, ham, spam uint32) error {
290 if f.isNew {
291 return tx.Insert(&wordscore{w, ham, spam})
292 }
293
294 wc := wordscore{w, 0, 0}
295 err := tx.Get(&wc)
296 if err == bstore.ErrAbsent {
297 return tx.Insert(&wordscore{w, ham, spam})
298 } else if err != nil {
299 return err
300 }
301 return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
302 }
303 if err := update("-", f.hams, f.spams); err != nil {
304 return fmt.Errorf("storing total ham/spam message count: %s", err)
305 }
306
307 for _, w := range words {
308 c := f.changed[w]
309 if err := update(w, c.Ham, c.Spam); err != nil {
310 return fmt.Errorf("updating ham/spam count: %s", err)
311 }
312 }
313 return nil
314 })
315 if err != nil {
316 return fmt.Errorf("updating database: %w", err)
317 }
318
319 f.changed = map[string]word{}
320 f.modified = false
321 f.isNew = false
322 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
323 return nil
324}
325
326func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
327 sort.Slice(l, func(i, j int) bool {
328 return l[i] < l[j]
329 })
330
331 err := db.Read(ctx, func(tx *bstore.Tx) error {
332 for _, w := range l {
333 wc := wordscore{Word: w}
334 if err := tx.Get(&wc); err == nil {
335 dst[w] = word{wc.Ham, wc.Spam}
336 }
337 }
338 return nil
339 })
340 if err != nil {
341 return fmt.Errorf("fetching words: %s", err)
342 }
343 return nil
344}
345
// WordScore pairs a word with the score it got during classification, which
// is based on (historic) training.
//
// The field order matters: values are created with positional literals.
type WordScore struct {
	Word string

	// 0 is ham, 1 is spam.
	Score float64
}
352
353// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
354func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, hams, spams []WordScore, rerr error) {
355 if f.closed {
356 return 0, nil, nil, errClosed
357 }
358
359 var hamHigh float64 = 0
360 var spamLow float64 = 1
361 var topHam []WordScore
362 var topSpam []WordScore
363
364 // Find words that should be in the database.
365 lookupWords := []string{}
366 expect := map[string]struct{}{}
367 unknowns := map[string]struct{}{}
368 totalUnknown := 0
369 for w := range words {
370 if f.bloom != nil && !f.bloom.Has(w) {
371 totalUnknown++
372 if len(unknowns) < 50 {
373 unknowns[w] = struct{}{}
374 }
375 continue
376 }
377 if _, ok := f.cache[w]; ok {
378 continue
379 }
380 lookupWords = append(lookupWords, w)
381 expect[w] = struct{}{}
382 }
383 if len(unknowns) > 0 {
384 f.log.Debug("unknown words in bloom filter, showing max 50",
385 slog.Any("words", unknowns),
386 slog.Any("totalunknown", totalUnknown),
387 slog.Any("totalwords", len(words)))
388 }
389
390 // Fetch words from database.
391 fetched := map[string]word{}
392 if len(lookupWords) > 0 {
393 if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
394 return 0, nil, nil, err
395 }
396 for w, c := range fetched {
397 delete(expect, w)
398 f.cache[w] = c
399 }
400 f.log.Debug("unknown words in db",
401 slog.Any("words", expect),
402 slog.Any("totalunknown", len(expect)),
403 slog.Any("totalwords", len(words)))
404 }
405
406 for w := range words {
407 c, ok := f.cache[w]
408 if !ok {
409 continue
410 }
411 var wS, wH float64
412 if f.spams > 0 {
413 wS = float64(c.Spam) / float64(f.spams)
414 }
415 if f.hams > 0 {
416 wH = float64(c.Ham) / float64(f.hams)
417 }
418 r := wS / (wS + wH)
419
420 if r < f.MaxPower {
421 r = f.MaxPower
422 } else if r >= 1-f.MaxPower {
423 r = 1 - f.MaxPower
424 }
425
426 if c.Ham+c.Spam <= uint32(f.RareWords) {
427 // Reduce the power of rare words.
428 r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
429 }
430 if math.Abs(0.5-r) < f.IgnoreWords {
431 continue
432 }
433 if r < 0.5 {
434 if len(topHam) >= f.TopWords && r > hamHigh {
435 continue
436 }
437 topHam = append(topHam, WordScore{w, r})
438 if r > hamHigh {
439 hamHigh = r
440 }
441 } else if r > 0.5 {
442 if len(topSpam) >= f.TopWords && r < spamLow {
443 continue
444 }
445 topSpam = append(topSpam, WordScore{w, r})
446 if r < spamLow {
447 spamLow = r
448 }
449 }
450 }
451
452 sort.Slice(topHam, func(i, j int) bool {
453 a, b := topHam[i], topHam[j]
454 if a.Score == b.Score {
455 return len(a.Word) > len(b.Word)
456 }
457 return a.Score < b.Score
458 })
459 sort.Slice(topSpam, func(i, j int) bool {
460 a, b := topSpam[i], topSpam[j]
461 if a.Score == b.Score {
462 return len(a.Word) > len(b.Word)
463 }
464 return a.Score > b.Score
465 })
466
467 nham := f.TopWords
468 if nham > len(topHam) {
469 nham = len(topHam)
470 }
471 nspam := f.TopWords
472 if nspam > len(topSpam) {
473 nspam = len(topSpam)
474 }
475 topHam = topHam[:nham]
476 topSpam = topSpam[:nspam]
477
478 var eta float64
479 for _, x := range topHam {
480 eta += math.Log(1-x.Score) - math.Log(x.Score)
481 }
482 for _, x := range topSpam {
483 eta += math.Log(1-x.Score) - math.Log(x.Score)
484 }
485
486 f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
487
488 prob := 1 / (1 + math.Pow(math.E, eta))
489 return prob, topHam, topSpam, nil
490}
491
492// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
493func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) {
494 if f.closed {
495 return 0, nil, nil, nil, errClosed
496 }
497
498 mf, err := os.Open(path)
499 if err != nil {
500 return 0, nil, nil, nil, err
501 }
502 defer func() {
503 err := mf.Close()
504 f.log.Check(err, "closing file after classify")
505 }()
506 fi, err := mf.Stat()
507 if err != nil {
508 return 0, nil, nil, nil, err
509 }
510 return f.ClassifyMessageReader(ctx, mf, fi.Size())
511}
512
513func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) {
514 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
515 if err != nil && errors.Is(err, message.ErrBadContentType) {
516 // Invalid content-type header is a sure sign of spam.
517 //f.log.Infox("parsing content", err)
518 return 1, nil, nil, nil, nil
519 }
520 return f.ClassifyMessage(ctx, m)
521}
522
523// ClassifyMessage parses the mail message in r and returns the spam probability
524// (between 0 and 1), along with the tokenized words found in the message, and the
525// ham and spam words and their scores used.
526func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) {
527 var err error
528 words, err = f.ParseMessage(m)
529 if err != nil {
530 return 0, nil, nil, nil, err
531 }
532
533 probability, hams, spams, err = f.ClassifyWords(ctx, words)
534 return probability, words, hams, spams, err
535}
536
537// Train adds the words of a single message to the filter.
538func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
539 if err := f.ensureBloom(); err != nil {
540 return err
541 }
542
543 var lwords []string
544
545 for w := range words {
546 if !f.bloom.Has(w) {
547 f.bloom.Add(w)
548 continue
549 }
550 if _, ok := f.cache[w]; !ok {
551 lwords = append(lwords, w)
552 }
553 }
554
555 if err := f.loadCache(ctx, lwords); err != nil {
556 return err
557 }
558
559 f.modified = true
560 if ham {
561 f.hams++
562 } else {
563 f.spams++
564 }
565
566 for w := range words {
567 c := f.cache[w]
568 if ham {
569 c.Ham++
570 } else {
571 c.Spam++
572 }
573 f.cache[w] = c
574 f.changed[w] = c
575 }
576 return nil
577}
578
579func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
580 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
581 words, err := f.ParseMessage(p)
582 if err != nil {
583 return fmt.Errorf("parsing mail contents: %v", err)
584 }
585 return f.Train(ctx, ham, words)
586}
587
588func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
589 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
590 words, err := f.ParseMessage(p)
591 if err != nil {
592 return fmt.Errorf("parsing mail contents: %v", err)
593 }
594 return f.Untrain(ctx, ham, words)
595}
596
597func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
598 if len(lwords) == 0 {
599 return nil
600 }
601
602 return loadWords(ctx, f.db, lwords, f.cache)
603}
604
605// Untrain adjusts the filter to undo a previous training of the words.
606func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
607 if err := f.ensureBloom(); err != nil {
608 return err
609 }
610
611 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
612 var lwords []string
613 for w := range words {
614 if _, ok := f.cache[w]; !ok {
615 lwords = append(lwords, w)
616 }
617 }
618 if err := f.loadCache(ctx, lwords); err != nil {
619 return err
620 }
621
622 // Modify the message count.
623 f.modified = true
624 if ham {
625 f.hams--
626 } else {
627 f.spams--
628 }
629
630 // Decrease the word counts.
631 for w := range words {
632 c, ok := f.cache[w]
633 if !ok {
634 continue
635 }
636 if ham {
637 c.Ham--
638 } else {
639 c.Spam--
640 }
641 f.cache[w] = c
642 f.changed[w] = c
643 }
644 return nil
645}
646
647// TrainDir parses mail messages from files and trains the filter.
648func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
649 if f.closed {
650 return 0, 0, errClosed
651 }
652 if err := f.ensureBloom(); err != nil {
653 return 0, 0, err
654 }
655
656 for _, name := range files {
657 p := filepath.Join(dir, name)
658 valid, words, err := f.tokenizeMail(p)
659 if err != nil {
660 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
661 malformed++
662 continue
663 }
664 if !valid {
665 continue
666 }
667 n++
668 for w := range words {
669 if !f.bloom.Has(w) {
670 f.bloom.Add(w)
671 continue
672 }
673 c := f.cache[w]
674 f.modified = true
675 if ham {
676 c.Ham++
677 } else {
678 c.Spam++
679 }
680 f.cache[w] = c
681 f.changed[w] = c
682 }
683 }
684 return
685}
686
687// TrainDirs trains and saves a filter with mail messages from different types
688// of directories.
689func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
690 if f.closed {
691 return errClosed
692 }
693
694 var err error
695
696 var start time.Time
697 var hamMalformed, sentMalformed, spamMalformed uint32
698
699 start = time.Now()
700 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
701 if err != nil {
702 return err
703 }
704 tham := time.Since(start)
705
706 var sent uint32
707 start = time.Now()
708 if sentDir != "" {
709 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
710 if err != nil {
711 return err
712 }
713 }
714 tsent := time.Since(start)
715
716 start = time.Now()
717 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
718 if err != nil {
719 return err
720 }
721 tspam := time.Since(start)
722
723 hams := f.hams
724 f.hams += sent
725 if err := f.Save(); err != nil {
726 return fmt.Errorf("saving filter: %s", err)
727 }
728
729 dbSize := f.fileSize(f.dbPath)
730 bloomSize := f.fileSize(f.bloomPath)
731
732 f.log.Print("training done",
733 slog.Any("hams", hams),
734 slog.Any("hamtime", tham),
735 slog.Any("hammalformed", hamMalformed),
736 slog.Any("sent", sent),
737 slog.Any("senttime", tsent),
738 slog.Any("sentmalformed", sentMalformed),
739 slog.Any("spams", f.spams),
740 slog.Any("spamtime", tspam),
741 slog.Any("spammalformed", spamMalformed),
742 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
743 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
744 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
745 )
746
747 return nil
748}
749
750func (f *Filter) fileSize(p string) int {
751 fi, err := os.Stat(p)
752 if err != nil {
753 f.log.Infox("stat", err, slog.Any("path", p))
754 return 0
755 }
756 return int(fi.Size())
757}
758
759// DB returns the database, for backups.
760func (f *Filter) DB() *bstore.DB {
761 return f.db
762}
763