1// Package junk implements a Bayesian spam filter.
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
11// todo: perhaps: check whether anchor text in links in html is different from the url
25 "github.com/mjl-/bstore"
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
29 "github.com/mjl-/mox/moxvar"
33 // errBadContentType = errors.New("bad content-type") // sure sign of spam, todo: use this error
34 errClosed = errors.New("filter is closed")
42type wordscore struct {
48// Params holds parameters for the filter. Most are at test-time. The first are
49// used during parsing and training.
51 Onegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
52 Twograms bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
53 Threegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
54 MaxPower float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
55 TopWords int `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
56 IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
57 RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
60var DBTypes = []any{wordscore{}} // Stored in DB.
65 log mlog.Log // For logging cid.
67 modified bool // Whether any modifications are pending. Cleared by Save.
68 hams, spams uint32 // Message count, stored in db under word "-".
69 cache map[string]word // Words read from database or during training.
70 changed map[string]word // Words modified during training.
71 dbPath, bloomPath string
72 db *bstore.DB // Always open on a filter.
73 bloom *Bloom // Only opened when writing.
74 isNew bool // Set for new filters until their first sync to disk. For faster writing.
77func (f *Filter) ensureBloom() error {
82 f.bloom, err = openBloom(f.bloomPath)
86// CloseDiscard closes the filter, discarding any changes.
87func (f *Filter) CloseDiscard() error {
92 *f = Filter{log: f.log, closed: true}
96// Close first saves the filter if it has modifications, then closes the database
97// connection and releases the bloom filter.
98func (f *Filter) Close() error {
111 *f = Filter{log: f.log, closed: true}
115func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
119 bloom, err = openBloom(bloomPath)
123 } else if fi, err := os.Stat(bloomPath); err == nil {
124 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
125 return nil, fmt.Errorf("bloom: %s", err)
129 db, err := openDB(ctx, log, dbPath)
131 return nil, fmt.Errorf("open database: %s", err)
137 cache: map[string]word{},
138 changed: map[string]word{},
140 bloomPath: bloomPath,
144 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
145 wc := wordscore{Word: "-"}
153 log.Check(cerr, "closing filter after error")
154 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
159// NewFilter creates a new filter with empty bloom filter and database files. The
160// filter is marked as new until the first save, which is done automatically if
161// TrainDirs is called. If the bloom and/or database files exist, an error is
163func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
165 if _, err := os.Stat(bloomPath); err == nil {
166 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
167 } else if _, err := os.Stat(dbPath); err == nil {
168 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
171 bloomSizeBytes := 4 * 1024 * 1024
172 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
173 return nil, fmt.Errorf("bloom: %s", err)
175 bf, err := os.Create(bloomPath)
177 return nil, fmt.Errorf("creating bloom file: %w", err)
179 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
181 log.Check(xerr, "closing bloom filter file after truncate error")
182 xerr = os.Remove(bloomPath)
183 log.Check(xerr, "removing bloom filter file after truncate error")
184 return nil, fmt.Errorf("making empty bloom filter: %s", err)
187 log.Check(err, "closing bloomfilter file")
189 db, err := newDB(ctx, log, dbPath)
191 xerr := os.Remove(bloomPath)
192 log.Check(xerr, "removing bloom filter file after db init error")
193 xerr = os.Remove(dbPath)
194 log.Check(xerr, "removing database file after db init error")
195 return nil, fmt.Errorf("open database: %s", err)
198 words := map[string]word{} // f.changed is set to new map after training
202 modified: true, // Ensure ham/spam message count is added for new filter.
206 bloomPath: bloomPath,
215func openBloom(path string) (*Bloom, error) {
216 buf, err := os.ReadFile(path)
218 return nil, fmt.Errorf("reading bloom file: %w", err)
220 return NewBloom(buf, bloomK)
223func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
224 // Remove any existing files.
229 err := os.Remove(path)
230 log.Check(err, "removing db file after init error")
234 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: moxvar.RegisterLogger(path, log.Logger)}
235 db, err := bstore.Open(ctx, path, &opts, DBTypes...)
237 return nil, fmt.Errorf("open new database: %w", err)
242func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
243 if _, err := os.Stat(path); err != nil {
244 return nil, fmt.Errorf("stat db file: %w", err)
246 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
247 return bstore.Open(ctx, path, &opts, DBTypes...)
250// Save stores modifications, e.g. from training, to the database and bloom
252func (f *Filter) Save() error {
260 if f.bloom != nil && f.bloom.Modified() {
261 if err := f.bloom.Write(f.bloomPath); err != nil {
262 return fmt.Errorf("writing bloom filter: %w", err)
266 // We need to insert sequentially for reasonable performance.
267 words := make([]string, len(f.changed))
269 for w := range f.changed {
273 sort.Slice(words, func(i, j int) bool {
274 return words[i] < words[j]
277 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
278 // start := time.Now()
280 if err := f.db.HintAppend(true, wordscore{}); err != nil {
281 f.log.Errorx("hint appendonly", err)
284 err := f.db.HintAppend(false, wordscore{})
285 f.log.Check(err, "restoring append hint")
289 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
290 update := func(w string, ham, spam uint32) error {
292 return tx.Insert(&wordscore{w, ham, spam})
295 wc := wordscore{w, 0, 0}
297 if err == bstore.ErrAbsent {
298 return tx.Insert(&wordscore{w, ham, spam})
299 } else if err != nil {
302 return tx.Update(&wordscore{w, ham, spam})
304 if err := update("-", f.hams, f.spams); err != nil {
305 return fmt.Errorf("storing total ham/spam message count: %s", err)
308 for _, w := range words {
310 if err := update(w, c.Ham, c.Spam); err != nil {
311 return fmt.Errorf("updating ham/spam count: %s", err)
317 return fmt.Errorf("updating database: %w", err)
320 f.changed = map[string]word{}
323 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
327func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
328 sort.Slice(l, func(i, j int) bool {
332 err := db.Read(ctx, func(tx *bstore.Tx) error {
333 for _, w := range l {
334 wc := wordscore{Word: w}
335 if err := tx.Get(&wc); err == nil {
336 dst[w] = word{wc.Ham, wc.Spam}
342 return fmt.Errorf("fetching words: %s", err)
347// WordScore is a word with its score as used in classifications, based on
348// (historic) training.
349type WordScore struct {
351 Score float64 // 0 is ham, 1 is spam.
354// ClassifyWords returns the spam probability for the given words, along with the ham and spam words and their scores used.
355func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (Result, error) {
357 return Result{}, errClosed
360 var hamHigh float64 = 0
361 var spamLow float64 = 1
362 var topHam []WordScore
363 var topSpam []WordScore
365 // Find words that should be in the database.
366 lookupWords := []string{}
367 expect := map[string]struct{}{}
368 unknowns := map[string]struct{}{}
370 for w := range words {
371 if f.bloom != nil && !f.bloom.Has(w) {
373 if len(unknowns) < 50 {
374 unknowns[w] = struct{}{}
378 if _, ok := f.cache[w]; ok {
381 lookupWords = append(lookupWords, w)
382 expect[w] = struct{}{}
384 if len(unknowns) > 0 {
385 f.log.Debug("unknown words in bloom filter, showing max 50",
386 slog.Any("words", unknowns),
387 slog.Any("totalunknown", totalUnknown),
388 slog.Any("totalwords", len(words)))
391 // Fetch words from database.
392 fetched := map[string]word{}
393 if len(lookupWords) > 0 {
394 if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
397 for w, c := range fetched {
401 f.log.Debug("unknown words in db",
402 slog.Any("words", expect),
403 slog.Any("totalunknown", len(expect)),
404 slog.Any("totalwords", len(words)))
407 for w := range words {
414 wS = float64(c.Spam) / float64(f.spams)
417 wH = float64(c.Ham) / float64(f.hams)
423 } else if r >= 1-f.MaxPower {
427 if c.Ham+c.Spam <= uint32(f.RareWords) {
428 // Reduce the power of rare words.
429 r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
431 if math.Abs(0.5-r) < f.IgnoreWords {
435 if len(topHam) >= f.TopWords && r > hamHigh {
438 topHam = append(topHam, WordScore{w, r})
443 if len(topSpam) >= f.TopWords && r < spamLow {
446 topSpam = append(topSpam, WordScore{w, r})
453 sort.Slice(topHam, func(i, j int) bool {
454 a, b := topHam[i], topHam[j]
455 if a.Score == b.Score {
456 return len(a.Word) > len(b.Word)
458 return a.Score < b.Score
460 sort.Slice(topSpam, func(i, j int) bool {
461 a, b := topSpam[i], topSpam[j]
462 if a.Score == b.Score {
463 return len(a.Word) > len(b.Word)
465 return a.Score > b.Score
469 if nham > len(topHam) {
473 if nspam > len(topSpam) {
476 topHam = topHam[:nham]
477 topSpam = topSpam[:nspam]
480 for _, x := range topHam {
481 eta += math.Log(1-x.Score) - math.Log(x.Score)
483 for _, x := range topSpam {
484 eta += math.Log(1-x.Score) - math.Log(x.Score)
487 f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
489 prob := 1 / (1 + math.Pow(math.E, eta))
491 // We want at least some positive signals, otherwise a few negative signals can
492 // mark incoming messages as spam too easily. If we have no negative signals, more
493 // messages will be classified as ham and accepted. This is fine, the user will
494 // classify it such, and retrain the filter. We mostly want to avoid rejecting too
495 // much when there isn't enough signal.
496 significant := f.hams >= 50
498 return Result{prob, significant, words, topHam, topSpam}, nil
501// Result is a successful classification, whether positive or negative.
503 Probability float64 // Between 0 (ham) and 1 (spam).
504 Significant bool // If true, enough classified words are available to base decisions on.
505 Words map[string]struct{}
506 Hams, Spams []WordScore
509// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
510func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (Result, error) {
512 return Result{}, errClosed
515 mf, err := os.Open(path)
521 f.log.Check(err, "closing file after classify")
527 return f.ClassifyMessageReader(ctx, mf, fi.Size())
530func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (Result, error) {
531 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
532 if err != nil && errors.Is(err, message.ErrBadContentType) {
533 // Invalid content-type header is a sure sign of spam.
534 //f.log.Infox("parsing content", err)
535 return Result{Probability: 1, Significant: true}, nil
537 return f.ClassifyMessage(ctx, m)
540// ClassifyMessage parses the mail message in m and returns the spam probability
541// (between 0 and 1), along with the tokenized words found in the message, and the
542// ham and spam words and their scores used.
543func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (Result, error) {
544 words, err := f.ParseMessage(m)
548 return f.ClassifyWords(ctx, words)
551// Train adds the words of a single message to the filter.
552func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
553 if err := f.ensureBloom(); err != nil {
559 for w := range words {
564 if _, ok := f.cache[w]; !ok {
565 lwords = append(lwords, w)
569 if err := f.loadCache(ctx, lwords); err != nil {
580 for w := range words {
593func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
594 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
595 words, err := f.ParseMessage(p)
597 return fmt.Errorf("parsing mail contents: %v", err)
599 return f.Train(ctx, ham, words)
602func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
603 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
604 words, err := f.ParseMessage(p)
606 return fmt.Errorf("parsing mail contents: %v", err)
608 return f.Untrain(ctx, ham, words)
611func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
612 if len(lwords) == 0 {
616 return loadWords(ctx, f.db, lwords, f.cache)
619// Untrain adjusts the filter to undo a previous training of the words.
620func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
621 if err := f.ensureBloom(); err != nil {
625 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
627 for w := range words {
628 if _, ok := f.cache[w]; !ok {
629 lwords = append(lwords, w)
632 if err := f.loadCache(ctx, lwords); err != nil {
636 // Modify the message count.
645 f.log.Error("attempt to decrease ham/spam message count while already zero", slog.Bool("ham", ham))
650 // Decrease the word counts.
651 for w := range words {
663 f.log.Error("attempt to decrease ham/spam word count while already zero", slog.String("word", w), slog.Bool("ham", ham))
673// TrainDir parses mail messages from files and trains the filter.
674func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
676 return 0, 0, errClosed
678 if err := f.ensureBloom(); err != nil {
682 for _, name := range files {
683 p := filepath.Join(dir, name)
684 valid, words, err := f.tokenizeMail(p)
686 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
694 for w := range words {
713// TrainDirs trains and saves a filter with mail messages from different types
715func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
723 var hamMalformed, sentMalformed, spamMalformed uint32
726 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
730 tham := time.Since(start)
735 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
740 tsent := time.Since(start)
743 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
747 tspam := time.Since(start)
751 if err := f.Save(); err != nil {
752 return fmt.Errorf("saving filter: %s", err)
755 dbSize := f.fileSize(f.dbPath)
756 bloomSize := f.fileSize(f.bloomPath)
758 f.log.Print("training done",
759 slog.Any("hams", hams),
760 slog.Any("hamtime", tham),
761 slog.Any("hammalformed", hamMalformed),
762 slog.Any("sent", sent),
763 slog.Any("senttime", tsent),
764 slog.Any("sentmalformed", sentMalformed),
765 slog.Any("spams", f.spams),
766 slog.Any("spamtime", tspam),
767 slog.Any("spammalformed", spamMalformed),
768 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
769 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
770 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
776func (f *Filter) fileSize(p string) int {
777 fi, err := os.Stat(p)
779 f.log.Infox("stat", err, slog.Any("path", p))
782 return int(fi.Size())
785// DB returns the database, for backups.
786func (f *Filter) DB() *bstore.DB {