1// Package junk implements a bayesian spam filter.
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
11// todo: perhaps: check whether the anchor text of links in html differs from the link url
25 "github.com/mjl-/bstore"
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
32 // errBadContentType = errors.New("bad content-type") // sure sign of spam, todo: use this error
33 errClosed = errors.New("filter is closed")
41type wordscore struct {
47// Params holds parameters for the filter. Most are used at test time. The first
48// three (the n-gram options) are used during parsing and training.
50 Onegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
51 Twograms bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
52 Threegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
53 MaxPower float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
54 TopWords int `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
55 IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
56 RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
59var DBTypes = []any{wordscore{}} // Stored in DB.
64 log mlog.Log // For logging cid.
66 modified bool // Whether any modifications are pending. Cleared by Save.
67 hams, spams uint32 // Message count, stored in db under word "-".
68 cache map[string]word // Words read from database or during training.
69 changed map[string]word // Words modified during training.
70 dbPath, bloomPath string
71 db *bstore.DB // Always open on a filter.
72 bloom *Bloom // Only opened when writing.
73 isNew bool // Set for new filters until their first sync to disk. For faster writing.
76func (f *Filter) ensureBloom() error {
81 f.bloom, err = openBloom(f.bloomPath)
85// CloseDiscard closes the filter, discarding any changes.
86func (f *Filter) CloseDiscard() error {
91 *f = Filter{log: f.log, closed: true}
95// Close first saves the filter if it has modifications, then closes the database
96// connection and releases the bloom filter.
97func (f *Filter) Close() error {
110 *f = Filter{log: f.log, closed: true}
114func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
118 bloom, err = openBloom(bloomPath)
122 } else if fi, err := os.Stat(bloomPath); err == nil {
123 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
124 return nil, fmt.Errorf("bloom: %s", err)
128 db, err := openDB(ctx, log, dbPath)
130 return nil, fmt.Errorf("open database: %s", err)
136 cache: map[string]word{},
137 changed: map[string]word{},
139 bloomPath: bloomPath,
143 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
144 wc := wordscore{Word: "-"}
152 log.Check(cerr, "closing filter after error")
153 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
158// NewFilter creates a new filter with empty bloom filter and database files. The
159// filter is marked as new until the first save, which is done automatically if
160// TrainDirs is called. If the bloom and/or database files exist, an error is
162func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
164 if _, err := os.Stat(bloomPath); err == nil {
165 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
166 } else if _, err := os.Stat(dbPath); err == nil {
167 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
170 bloomSizeBytes := 4 * 1024 * 1024
171 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
172 return nil, fmt.Errorf("bloom: %s", err)
174 bf, err := os.Create(bloomPath)
176 return nil, fmt.Errorf("creating bloom file: %w", err)
178 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
180 log.Check(xerr, "closing bloom filter file after truncate error")
181 xerr = os.Remove(bloomPath)
182 log.Check(xerr, "removing bloom filter file after truncate error")
183 return nil, fmt.Errorf("making empty bloom filter: %s", err)
186 log.Check(err, "closing bloomfilter file")
188 db, err := newDB(ctx, log, dbPath)
190 xerr := os.Remove(bloomPath)
191 log.Check(xerr, "removing bloom filter file after db init error")
192 xerr = os.Remove(dbPath)
193 log.Check(xerr, "removing database file after db init error")
194 return nil, fmt.Errorf("open database: %s", err)
197 words := map[string]word{} // f.changed is set to new map after training
201 modified: true, // Ensure ham/spam message count is added for new filter.
205 bloomPath: bloomPath,
214func openBloom(path string) (*Bloom, error) {
215 buf, err := os.ReadFile(path)
217 return nil, fmt.Errorf("reading bloom file: %w", err)
219 return NewBloom(buf, bloomK)
222func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
223 // Remove any existing files.
228 err := os.Remove(path)
229 log.Check(err, "removing db file after init error")
233 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
234 db, err := bstore.Open(ctx, path, &opts, DBTypes...)
236 return nil, fmt.Errorf("open new database: %w", err)
241func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
242 if _, err := os.Stat(path); err != nil {
243 return nil, fmt.Errorf("stat db file: %w", err)
245 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
246 return bstore.Open(ctx, path, &opts, DBTypes...)
249// Save stores modifications, e.g. from training, to the database and bloom
251func (f *Filter) Save() error {
259 if f.bloom != nil && f.bloom.Modified() {
260 if err := f.bloom.Write(f.bloomPath); err != nil {
261 return fmt.Errorf("writing bloom filter: %w", err)
265 // We need to insert sequentially for reasonable performance.
266 words := make([]string, len(f.changed))
268 for w := range f.changed {
272 sort.Slice(words, func(i, j int) bool {
273 return words[i] < words[j]
276 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
277 // start := time.Now()
279 if err := f.db.HintAppend(true, wordscore{}); err != nil {
280 f.log.Errorx("hint appendonly", err)
283 err := f.db.HintAppend(false, wordscore{})
284 f.log.Check(err, "restoring append hint")
288 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
289 update := func(w string, ham, spam uint32) error {
291 return tx.Insert(&wordscore{w, ham, spam})
294 wc := wordscore{w, 0, 0}
296 if err == bstore.ErrAbsent {
297 return tx.Insert(&wordscore{w, ham, spam})
298 } else if err != nil {
301 return tx.Update(&wordscore{w, ham, spam})
303 if err := update("-", f.hams, f.spams); err != nil {
304 return fmt.Errorf("storing total ham/spam message count: %s", err)
307 for _, w := range words {
309 if err := update(w, c.Ham, c.Spam); err != nil {
310 return fmt.Errorf("updating ham/spam count: %s", err)
316 return fmt.Errorf("updating database: %w", err)
319 f.changed = map[string]word{}
322 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
326func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
327 sort.Slice(l, func(i, j int) bool {
331 err := db.Read(ctx, func(tx *bstore.Tx) error {
332 for _, w := range l {
333 wc := wordscore{Word: w}
334 if err := tx.Get(&wc); err == nil {
335 dst[w] = word{wc.Ham, wc.Spam}
341 return fmt.Errorf("fetching words: %s", err)
346// WordScore is a word with its score as used in classifications, based on
347// (historic) training.
348type WordScore struct {
350 Score float64 // 0 is ham, 1 is spam.
353// ClassifyWords returns the spam probability for the given words, along with the top ham and spam words and their scores that were used to calculate it.
354func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, hams, spams []WordScore, rerr error) {
356 return 0, nil, nil, errClosed
359 var hamHigh float64 = 0
360 var spamLow float64 = 1
361 var topHam []WordScore
362 var topSpam []WordScore
364 // Find words that should be in the database.
365 lookupWords := []string{}
366 expect := map[string]struct{}{}
367 unknowns := map[string]struct{}{}
369 for w := range words {
370 if f.bloom != nil && !f.bloom.Has(w) {
372 if len(unknowns) < 50 {
373 unknowns[w] = struct{}{}
377 if _, ok := f.cache[w]; ok {
380 lookupWords = append(lookupWords, w)
381 expect[w] = struct{}{}
383 if len(unknowns) > 0 {
384 f.log.Debug("unknown words in bloom filter, showing max 50",
385 slog.Any("words", unknowns),
386 slog.Any("totalunknown", totalUnknown),
387 slog.Any("totalwords", len(words)))
390 // Fetch words from database.
391 fetched := map[string]word{}
392 if len(lookupWords) > 0 {
393 if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
394 return 0, nil, nil, err
396 for w, c := range fetched {
400 f.log.Debug("unknown words in db",
401 slog.Any("words", expect),
402 slog.Any("totalunknown", len(expect)),
403 slog.Any("totalwords", len(words)))
406 for w := range words {
413 wS = float64(c.Spam) / float64(f.spams)
416 wH = float64(c.Ham) / float64(f.hams)
422 } else if r >= 1-f.MaxPower {
426 if c.Ham+c.Spam <= uint32(f.RareWords) {
427 // Reduce the power of rare words.
428 r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
430 if math.Abs(0.5-r) < f.IgnoreWords {
434 if len(topHam) >= f.TopWords && r > hamHigh {
437 topHam = append(topHam, WordScore{w, r})
442 if len(topSpam) >= f.TopWords && r < spamLow {
445 topSpam = append(topSpam, WordScore{w, r})
452 sort.Slice(topHam, func(i, j int) bool {
453 a, b := topHam[i], topHam[j]
454 if a.Score == b.Score {
455 return len(a.Word) > len(b.Word)
457 return a.Score < b.Score
459 sort.Slice(topSpam, func(i, j int) bool {
460 a, b := topSpam[i], topSpam[j]
461 if a.Score == b.Score {
462 return len(a.Word) > len(b.Word)
464 return a.Score > b.Score
468 if nham > len(topHam) {
472 if nspam > len(topSpam) {
475 topHam = topHam[:nham]
476 topSpam = topSpam[:nspam]
479 for _, x := range topHam {
480 eta += math.Log(1-x.Score) - math.Log(x.Score)
482 for _, x := range topSpam {
483 eta += math.Log(1-x.Score) - math.Log(x.Score)
486 f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
488 prob := 1 / (1 + math.Pow(math.E, eta))
489 return prob, topHam, topSpam, nil
492// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
493func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) {
495 return 0, nil, nil, nil, errClosed
498 mf, err := os.Open(path)
500 return 0, nil, nil, nil, err
504 f.log.Check(err, "closing file after classify")
508 return 0, nil, nil, nil, err
510 return f.ClassifyMessageReader(ctx, mf, fi.Size())
513func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) {
514 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
515 if err != nil && errors.Is(err, message.ErrBadContentType) {
516 // Invalid content-type header is a sure sign of spam.
517 //f.log.Infox("parsing content", err)
518 return 1, nil, nil, nil, nil
520 return f.ClassifyMessage(ctx, m)
523// ClassifyMessage parses the mail message part m and returns the spam probability
524// (between 0 and 1), along with the tokenized words found in the message, and the
525// ham and spam words and their scores used.
526func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) {
528 words, err = f.ParseMessage(m)
530 return 0, nil, nil, nil, err
533 probability, hams, spams, err = f.ClassifyWords(ctx, words)
534 return probability, words, hams, spams, err
537// Train adds the words of a single message to the filter.
538func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
539 if err := f.ensureBloom(); err != nil {
545 for w := range words {
550 if _, ok := f.cache[w]; !ok {
551 lwords = append(lwords, w)
555 if err := f.loadCache(ctx, lwords); err != nil {
566 for w := range words {
579func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
580 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
581 words, err := f.ParseMessage(p)
583 return fmt.Errorf("parsing mail contents: %v", err)
585 return f.Train(ctx, ham, words)
588func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
589 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
590 words, err := f.ParseMessage(p)
592 return fmt.Errorf("parsing mail contents: %v", err)
594 return f.Untrain(ctx, ham, words)
597func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
598 if len(lwords) == 0 {
602 return loadWords(ctx, f.db, lwords, f.cache)
605// Untrain adjusts the filter to undo a previous training of the words.
606func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
607 if err := f.ensureBloom(); err != nil {
611 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
613 for w := range words {
614 if _, ok := f.cache[w]; !ok {
615 lwords = append(lwords, w)
618 if err := f.loadCache(ctx, lwords); err != nil {
622 // Modify the message count.
631 f.log.Error("attempt to decrease ham/spam message count while already zero", slog.Bool("ham", ham))
636 // Decrease the word counts.
637 for w := range words {
649 f.log.Error("attempt to decrease ham/spam word count while already zero", slog.String("word", w), slog.Bool("ham", ham))
659// TrainDir parses mail messages from files and trains the filter.
660func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
662 return 0, 0, errClosed
664 if err := f.ensureBloom(); err != nil {
668 for _, name := range files {
669 p := filepath.Join(dir, name)
670 valid, words, err := f.tokenizeMail(p)
672 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
680 for w := range words {
699// TrainDirs trains and saves a filter with mail messages from different types
701func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
709 var hamMalformed, sentMalformed, spamMalformed uint32
712 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
716 tham := time.Since(start)
721 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
726 tsent := time.Since(start)
729 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
733 tspam := time.Since(start)
737 if err := f.Save(); err != nil {
738 return fmt.Errorf("saving filter: %s", err)
741 dbSize := f.fileSize(f.dbPath)
742 bloomSize := f.fileSize(f.bloomPath)
744 f.log.Print("training done",
745 slog.Any("hams", hams),
746 slog.Any("hamtime", tham),
747 slog.Any("hammalformed", hamMalformed),
748 slog.Any("sent", sent),
749 slog.Any("senttime", tsent),
750 slog.Any("sentmalformed", sentMalformed),
751 slog.Any("spams", f.spams),
752 slog.Any("spamtime", tspam),
753 slog.Any("spammalformed", spamMalformed),
754 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
755 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
756 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
762func (f *Filter) fileSize(p string) int {
763 fi, err := os.Stat(p)
765 f.log.Infox("stat", err, slog.Any("path", p))
768 return int(fi.Size())
771// DB returns the database, for backups.
772func (f *Filter) DB() *bstore.DB {