// Package junk implements a Bayesian spam filter.
//
// A message can be parsed into words. Words (or pairs or triplets) can be used
// to train the filter or to classify the message as ham or spam. Training
// records the words in the database as ham/spam. Classifying consists of
// calculating the ham/spam probability by combining the words in the message
// with their ham/spam status.
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
11// todo: perhaps: whether anchor text in links in html are different from the url
25 "github.com/mjl-/bstore"
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
29 "github.com/mjl-/mox/moxvar"
34 // errBadContentType = errors.New("bad content-type") // sure sign of spam, todo: use this error
35 errClosed = errors.New("filter is closed")
// Wordscore is the database record for a word (or word combination) together
// with its ham and spam counts as recorded during training. The record for the
// special word "-" holds the total ham/spam message counts of the filter.
type Wordscore struct {
	Word string `bstore:"typename wordscore"`
	Ham  uint32
	Spam uint32
}
// Params holds parameters for the filter. Most of them are used at test-time,
// i.e. when classifying. The first ones (the n-gram settings) are used during
// parsing and training.
type Params struct {
	Onegrams    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
	Twograms    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
	Threegrams  bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
	MaxPower    float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
	TopWords    int     `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
	IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
	RareWords   int     `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
}
61var DBTypes = []any{Wordscore{}} // Stored in DB.
66 log mlog.Log // For logging cid.
68 modified bool // Whether any modifications are pending. Cleared by Save.
69 hams, spams uint32 // Message count, stored in db under word "-".
70 cache map[string]word // Words read from database or during training.
71 changed map[string]word // Words modified during training.
72 dbPath, bloomPath string
73 db *bstore.DB // Always open on a filter.
74 bloom *Bloom // Only opened when writing.
75 isNew bool // Set for new filters until their first sync to disk. For faster writing.
78func (f *Filter) ensureBloom() error {
83 f.bloom, err = openBloom(f.log, f.bloomPath)
87// CloseDiscard closes the filter, discarding any changes.
88func (f *Filter) CloseDiscard() error {
93 *f = Filter{log: f.log, closed: true}
97// Close first saves the filter if it has modifications, then closes the database
98// connection and releases the bloom filter.
99func (f *Filter) Close() error {
109 f.log.Check(xerr, "closing junk filter after error")
113 *f = Filter{log: f.log, closed: true}
117func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
121 bloom, err = openBloom(log, bloomPath)
125 } else if fi, err := os.Stat(bloomPath); err == nil {
126 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
127 return nil, fmt.Errorf("bloom: %s", err)
131 db, err := openDB(ctx, log, dbPath)
133 return nil, fmt.Errorf("open database: %s", err)
139 cache: map[string]word{},
140 changed: map[string]word{},
142 bloomPath: bloomPath,
146 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
147 wc := Wordscore{Word: "-"}
155 log.Check(cerr, "closing filter after error")
156 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
161// NewFilter creates a new filter with empty bloom filter and database files. The
162// filter is marked as new until the first save, will be done automatically if
163// TrainDirs is called. If the bloom and/or database files exist, an error is
165func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
167 if _, err := os.Stat(bloomPath); err == nil {
168 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
169 } else if _, err := os.Stat(dbPath); err == nil {
170 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
173 bloomSizeBytes := 4 * 1024 * 1024
174 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
175 return nil, fmt.Errorf("bloom: %s", err)
177 bf, err := os.Create(bloomPath)
179 return nil, fmt.Errorf("creating bloom file: %w", err)
181 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
183 log.Check(xerr, "closing bloom filter file after truncate error")
184 xerr = os.Remove(bloomPath)
185 log.Check(xerr, "removing bloom filter file after truncate error")
186 return nil, fmt.Errorf("making empty bloom filter: %s", err)
189 log.Check(err, "closing bloomfilter file")
191 db, err := newDB(ctx, log, dbPath)
193 xerr := os.Remove(bloomPath)
194 log.Check(xerr, "removing bloom filter file after db init error")
195 xerr = os.Remove(dbPath)
196 log.Check(xerr, "removing database file after db init error")
197 return nil, fmt.Errorf("open database: %s", err)
200 words := map[string]word{} // f.changed is set to new map after training
204 modified: true, // Ensure ham/spam message count is added for new filter.
208 bloomPath: bloomPath,
217func openBloom(log mlog.Log, path string) (*Bloom, error) {
218 buf, err := os.ReadFile(path)
220 return nil, fmt.Errorf("reading bloom file: %w", err)
222 return NewBloom(log, buf, bloomK)
225func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
226 // Remove any existing files.
231 err := os.Remove(path)
232 log.Check(err, "removing db file after init error")
236 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: moxvar.RegisterLogger(path, log.Logger)}
237 db, err := bstore.Open(ctx, path, &opts, DBTypes...)
239 return nil, fmt.Errorf("open new database: %w", err)
244func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
245 if _, err := os.Stat(path); err != nil {
246 return nil, fmt.Errorf("stat db file: %w", err)
248 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
249 return bstore.Open(ctx, path, &opts, DBTypes...)
252// Save stores modifications, e.g. from training, to the database and bloom
254func (f *Filter) Save() error {
262 if f.bloom != nil && f.bloom.Modified() {
263 if err := f.bloom.Write(f.bloomPath); err != nil {
264 return fmt.Errorf("writing bloom filter: %w", err)
268 // We need to insert sequentially for reasonable performance.
269 words := make([]string, len(f.changed))
271 for w := range f.changed {
277 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
278 // start := time.Now()
280 if err := f.db.HintAppend(true, Wordscore{}); err != nil {
281 f.log.Errorx("hint appendonly", err)
284 err := f.db.HintAppend(false, Wordscore{})
285 f.log.Check(err, "restoring append hint")
289 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
290 update := func(w string, ham, spam uint32) error {
291 zeroword := w != "-" && ham == 0 && spam == 0
297 return tx.Insert(&Wordscore{w, ham, spam})
300 wc := Wordscore{w, 0, 0}
302 if err == bstore.ErrAbsent {
306 return tx.Insert(&Wordscore{w, ham, spam})
307 } else if err != nil {
311 return tx.Delete(&Wordscore{w, 0, 0})
313 return tx.Update(&Wordscore{w, ham, spam})
315 if err := update("-", f.hams, f.spams); err != nil {
316 return fmt.Errorf("storing total ham/spam message count: %s", err)
319 for _, w := range words {
321 if err := update(w, c.Ham, c.Spam); err != nil {
322 return fmt.Errorf("updating ham/spam count: %s", err)
328 return fmt.Errorf("updating database: %w", err)
331 f.changed = map[string]word{}
334 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
338func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
341 err := db.Read(ctx, func(tx *bstore.Tx) error {
342 for _, w := range l {
343 wc := Wordscore{Word: w}
344 if err := tx.Get(&wc); err == nil {
345 dst[w] = word{wc.Ham, wc.Spam}
346 } else if err != bstore.ErrAbsent {
347 return fmt.Errorf("get word: %v", err)
353 return fmt.Errorf("fetching words: %s", err)
358// WordScore is a word with its score as used in classifications, based on
359// (historic) training.
360type WordScore struct {
362 Score float64 // 0 is ham, 1 is spam.
365// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
366func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (Result, error) {
368 return Result{}, errClosed
371 var hamHigh float64 = 0
372 var spamLow float64 = 1
373 var topHam []WordScore
374 var topSpam []WordScore
376 // Find words that should be in the database.
377 lookupWords := []string{}
378 expect := map[string]struct{}{}
379 unknowns := map[string]struct{}{}
381 for w := range words {
382 if f.bloom != nil && !f.bloom.Has(w) {
384 if len(unknowns) < 50 {
385 unknowns[w] = struct{}{}
389 if _, ok := f.cache[w]; ok {
392 lookupWords = append(lookupWords, w)
393 expect[w] = struct{}{}
395 if len(unknowns) > 0 {
396 f.log.Debug("unknown words in bloom filter, showing max 50",
397 slog.Any("words", unknowns),
398 slog.Any("totalunknown", totalUnknown),
399 slog.Any("totalwords", len(words)))
402 // Fetch words from database.
403 fetched := map[string]word{}
404 if len(lookupWords) > 0 {
405 if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
408 for w, c := range fetched {
412 f.log.Debug("unknown words in db",
413 slog.Any("words", expect),
414 slog.Any("totalunknown", len(expect)),
415 slog.Any("totalwords", len(words)))
418 for w := range words {
425 wS = float64(c.Spam) / float64(f.spams)
428 wH = float64(c.Ham) / float64(f.hams)
434 } else if r >= 1-f.MaxPower {
438 if c.Ham+c.Spam <= uint32(f.RareWords) {
439 // Reduce the power of rare words.
440 r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
442 if math.Abs(0.5-r) < f.IgnoreWords {
446 if len(topHam) >= f.TopWords && r > hamHigh {
449 topHam = append(topHam, WordScore{w, r})
454 if len(topSpam) >= f.TopWords && r < spamLow {
457 topSpam = append(topSpam, WordScore{w, r})
464 sort.Slice(topHam, func(i, j int) bool {
465 a, b := topHam[i], topHam[j]
466 if a.Score == b.Score {
467 return len(a.Word) > len(b.Word)
469 return a.Score < b.Score
471 sort.Slice(topSpam, func(i, j int) bool {
472 a, b := topSpam[i], topSpam[j]
473 if a.Score == b.Score {
474 return len(a.Word) > len(b.Word)
476 return a.Score > b.Score
479 nham := min(f.TopWords, len(topHam))
480 nspam := min(f.TopWords, len(topSpam))
481 topHam = topHam[:nham]
482 topSpam = topSpam[:nspam]
485 for _, x := range topHam {
486 eta += math.Log(1-x.Score) - math.Log(x.Score)
488 for _, x := range topSpam {
489 eta += math.Log(1-x.Score) - math.Log(x.Score)
492 f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
494 prob := 1 / (1 + math.Pow(math.E, eta))
496 // We want at least some positive signals, otherwise a few negative signals can
497 // mark incoming messages as spam too easily. If we have no negative signals, more
498 // messages will be classified as ham and accepted. This is fine, the user will
499 // classify it such, and retrain the filter. We mostly want to avoid rejecting too
500 // much when there isn't enough signal.
501 significant := f.hams >= 50
503 return Result{prob, significant, words, topHam, topSpam}, nil
506// Result is a successful classification, whether positive or negative.
508 Probability float64 // Between 0 (ham) and 1 (spam).
509 Significant bool // If true, enough classified words are available to base decisions on.
510 Words map[string]struct{}
511 Hams, Spams []WordScore
514// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
515func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (Result, error) {
517 return Result{}, errClosed
520 mf, err := os.Open(path)
526 f.log.Check(err, "closing file after classify")
532 return f.ClassifyMessageReader(ctx, mf, fi.Size())
535func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (Result, error) {
536 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
537 if err != nil && errors.Is(err, message.ErrBadContentType) {
538 // Invalid content-type header is a sure sign of spam.
539 //f.log.Infox("parsing content", err)
540 return Result{Probability: 1, Significant: true}, nil
542 return f.ClassifyMessage(ctx, m)
545// ClassifyMessage parses the mail message in r and returns the spam probability
546// (between 0 and 1), along with the tokenized words found in the message, and the
547// ham and spam words and their scores used.
548func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (Result, error) {
549 words, err := f.ParseMessage(m)
553 return f.ClassifyWords(ctx, words)
556// Train adds the words of a single message to the filter.
557func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
558 if err := f.ensureBloom(); err != nil {
564 for w := range words {
569 if _, ok := f.cache[w]; !ok {
570 lwords = append(lwords, w)
574 if err := f.loadCache(ctx, lwords); err != nil {
585 for w := range words {
598func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
599 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
600 words, err := f.ParseMessage(p)
602 return fmt.Errorf("parsing mail contents: %v", err)
604 return f.Train(ctx, ham, words)
607func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
608 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
609 words, err := f.ParseMessage(p)
611 return fmt.Errorf("parsing mail contents: %v", err)
613 return f.Untrain(ctx, ham, words)
616func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
617 if len(lwords) == 0 {
621 return loadWords(ctx, f.db, lwords, f.cache)
624// Untrain adjusts the filter to undo a previous training of the words.
625func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
626 if err := f.ensureBloom(); err != nil {
630 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
632 for w := range words {
633 if _, ok := f.cache[w]; !ok {
634 lwords = append(lwords, w)
637 if err := f.loadCache(ctx, lwords); err != nil {
641 // Modify the message count.
650 f.log.Error("attempt to decrease ham/spam message count while already zero", slog.Bool("ham", ham))
655 // Decrease the word counts.
656 for w := range words {
668 f.log.Error("attempt to decrease ham/spam word count while already zero", slog.String("word", w), slog.Bool("ham", ham))
678// TrainDir parses mail messages from files and trains the filter.
679func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
681 return 0, 0, errClosed
683 if err := f.ensureBloom(); err != nil {
687 for _, name := range files {
688 p := filepath.Join(dir, name)
689 valid, words, err := f.tokenizeMail(p)
691 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
699 for w := range words {
718// TrainDirs trains and saves a filter with mail messages from different types
720func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
728 var hamMalformed, sentMalformed, spamMalformed uint32
731 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
735 tham := time.Since(start)
740 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
745 tsent := time.Since(start)
748 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
752 tspam := time.Since(start)
756 if err := f.Save(); err != nil {
757 return fmt.Errorf("saving filter: %s", err)
760 dbSize := f.fileSize(f.dbPath)
761 bloomSize := f.fileSize(f.bloomPath)
763 f.log.Print("training done",
764 slog.Any("hams", hams),
765 slog.Any("hamtime", tham),
766 slog.Any("hammalformed", hamMalformed),
767 slog.Any("sent", sent),
768 slog.Any("senttime", tsent),
769 slog.Any("sentmalformed", sentMalformed),
770 slog.Any("spams", f.spams),
771 slog.Any("spamtime", tspam),
772 slog.Any("spammalformed", spamMalformed),
773 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
774 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
775 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
781func (f *Filter) fileSize(p string) int {
782 fi, err := os.Stat(p)
784 f.log.Infox("stat", err, slog.Any("path", p))
787 return int(fi.Size())
790// DB returns the database, for backups.
791func (f *Filter) DB() *bstore.DB {