1// Package junk implements a Bayesian spam filter.
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
11// todo: perhaps: check whether anchor text in html links differs from the url
25 "github.com/mjl-/bstore"
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
32 // errBadContentType = errors.New("bad content-type") // sure sign of spam, todo: use this error
33 errClosed = errors.New("filter is closed")
41type wordscore struct {
47// Params holds parameters for the filter. Most apply when classifying. The first
48// are used during parsing and training.
50 Onegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
51 Twograms bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
52 Threegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
53 MaxPower float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
54 TopWords int `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
55 IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
56 RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
59var DBTypes = []any{wordscore{}} // Stored in DB.
64 log mlog.Log // For logging cid.
66 modified bool // Whether any modifications are pending. Cleared by Save.
67 hams, spams uint32 // Message count, stored in db under word "-".
68 cache map[string]word // Words read from database or during training.
69 changed map[string]word // Words modified during training.
70 dbPath, bloomPath string
71 db *bstore.DB // Always open on a filter.
72 bloom *Bloom // Only opened when writing.
73 isNew bool // Set for new filters until their first sync to disk. For faster writing.
76func (f *Filter) ensureBloom() error {
81 f.bloom, err = openBloom(f.bloomPath)
85// CloseDiscard closes the filter, discarding any changes.
86func (f *Filter) CloseDiscard() error {
91 *f = Filter{log: f.log, closed: true}
95// Close first saves the filter if it has modifications, then closes the database
96// connection and releases the bloom filter.
97func (f *Filter) Close() error {
110 *f = Filter{log: f.log, closed: true}
114func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
118 bloom, err = openBloom(bloomPath)
122 } else if fi, err := os.Stat(bloomPath); err == nil {
123 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
124 return nil, fmt.Errorf("bloom: %s", err)
128 db, err := openDB(ctx, log, dbPath)
130 return nil, fmt.Errorf("open database: %s", err)
136 cache: map[string]word{},
137 changed: map[string]word{},
139 bloomPath: bloomPath,
143 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
144 wc := wordscore{Word: "-"}
152 log.Check(cerr, "closing filter after error")
153 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
158// NewFilter creates a new filter with empty bloom filter and database files. The
159// filter is marked as new until the first save, which is done automatically if
160// TrainDirs is called. If the bloom and/or database files exist, an error is
162func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
164 if _, err := os.Stat(bloomPath); err == nil {
165 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
166 } else if _, err := os.Stat(dbPath); err == nil {
167 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
170 bloomSizeBytes := 4 * 1024 * 1024
171 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
172 return nil, fmt.Errorf("bloom: %s", err)
174 bf, err := os.Create(bloomPath)
176 return nil, fmt.Errorf("creating bloom file: %w", err)
178 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
180 log.Check(xerr, "closing bloom filter file after truncate error")
181 xerr = os.Remove(bloomPath)
182 log.Check(xerr, "removing bloom filter file after truncate error")
183 return nil, fmt.Errorf("making empty bloom filter: %s", err)
186 log.Check(err, "closing bloomfilter file")
188 db, err := newDB(ctx, log, dbPath)
190 xerr := os.Remove(bloomPath)
191 log.Check(xerr, "removing bloom filter file after db init error")
192 xerr = os.Remove(dbPath)
193 log.Check(xerr, "removing database file after db init error")
194 return nil, fmt.Errorf("open database: %s", err)
197 words := map[string]word{} // f.changed is set to new map after training
201 modified: true, // Ensure ham/spam message count is added for new filter.
205 bloomPath: bloomPath,
214func openBloom(path string) (*Bloom, error) {
215 buf, err := os.ReadFile(path)
217 return nil, fmt.Errorf("reading bloom file: %w", err)
219 return NewBloom(buf, bloomK)
222func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
223 // Remove any existing files.
228 err := os.Remove(path)
229 log.Check(err, "removing db file after init error")
233 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
234 db, err := bstore.Open(ctx, path, &opts, DBTypes...)
236 return nil, fmt.Errorf("open new database: %w", err)
241func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
242 if _, err := os.Stat(path); err != nil {
243 return nil, fmt.Errorf("stat db file: %w", err)
245 opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
246 return bstore.Open(ctx, path, &opts, DBTypes...)
249// Save stores modifications, e.g. from training, to the database and bloom
251func (f *Filter) Save() error {
259 if f.bloom != nil && f.bloom.Modified() {
260 if err := f.bloom.Write(f.bloomPath); err != nil {
261 return fmt.Errorf("writing bloom filter: %w", err)
265 // We need to insert sequentially for reasonable performance.
266 words := make([]string, len(f.changed))
268 for w := range f.changed {
272 sort.Slice(words, func(i, j int) bool {
273 return words[i] < words[j]
276 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
277 // start := time.Now()
279 if err := f.db.HintAppend(true, wordscore{}); err != nil {
280 f.log.Errorx("hint appendonly", err)
283 err := f.db.HintAppend(false, wordscore{})
284 f.log.Check(err, "restoring append hint")
288 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
289 update := func(w string, ham, spam uint32) error {
291 return tx.Insert(&wordscore{w, ham, spam})
294 wc := wordscore{w, 0, 0}
296 if err == bstore.ErrAbsent {
297 return tx.Insert(&wordscore{w, ham, spam})
298 } else if err != nil {
301 return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
303 if err := update("-", f.hams, f.spams); err != nil {
304 return fmt.Errorf("storing total ham/spam message count: %s", err)
307 for _, w := range words {
309 if err := update(w, c.Ham, c.Spam); err != nil {
310 return fmt.Errorf("updating ham/spam count: %s", err)
316 return fmt.Errorf("updating database: %w", err)
319 f.changed = map[string]word{}
322 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
326func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
327 sort.Slice(l, func(i, j int) bool {
331 err := db.Read(ctx, func(tx *bstore.Tx) error {
332 for _, w := range l {
333 wc := wordscore{Word: w}
334 if err := tx.Get(&wc); err == nil {
335 dst[w] = word{wc.Ham, wc.Spam}
341 return fmt.Errorf("fetching words: %s", err)
346// ClassifyWords returns the spam probability for the given words, and the number of recognized ham and spam words.
347func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
349 return 0, 0, 0, errClosed
357 var hamHigh float64 = 0
358 var spamLow float64 = 1
362 // Find words that should be in the database.
363 lookupWords := []string{}
364 expect := map[string]struct{}{}
365 unknowns := map[string]struct{}{}
367 for w := range words {
368 if f.bloom != nil && !f.bloom.Has(w) {
370 if len(unknowns) < 50 {
371 unknowns[w] = struct{}{}
375 if _, ok := f.cache[w]; ok {
378 lookupWords = append(lookupWords, w)
379 expect[w] = struct{}{}
381 if len(unknowns) > 0 {
382 f.log.Debug("unknown words in bloom filter, showing max 50",
383 slog.Any("words", unknowns),
384 slog.Any("totalunknown", totalUnknown),
385 slog.Any("totalwords", len(words)))
388 // Fetch words from database.
389 fetched := map[string]word{}
390 if len(lookupWords) > 0 {
391 if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
394 for w, c := range fetched {
398 f.log.Debug("unknown words in db",
399 slog.Any("words", expect),
400 slog.Any("totalunknown", len(expect)),
401 slog.Any("totalwords", len(words)))
404 for w := range words {
411 wS = float64(c.Spam) / float64(f.spams)
414 wH = float64(c.Ham) / float64(f.hams)
420 } else if r >= 1-f.MaxPower {
424 if c.Ham+c.Spam <= uint32(f.RareWords) {
425 // Reduce the power of rare words.
426 r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
428 if math.Abs(0.5-r) < f.IgnoreWords {
432 if len(topHam) >= f.TopWords && r > hamHigh {
435 topHam = append(topHam, xword{w, r})
440 if len(topSpam) >= f.TopWords && r < spamLow {
443 topSpam = append(topSpam, xword{w, r})
450 sort.Slice(topHam, func(i, j int) bool {
451 a, b := topHam[i], topHam[j]
453 return len(a.Word) > len(b.Word)
457 sort.Slice(topSpam, func(i, j int) bool {
458 a, b := topSpam[i], topSpam[j]
460 return len(a.Word) > len(b.Word)
466 if nham > len(topHam) {
470 if nspam > len(topSpam) {
473 topHam = topHam[:nham]
474 topSpam = topSpam[:nspam]
477 for _, x := range topHam {
478 eta += math.Log(1-x.R) - math.Log(x.R)
480 for _, x := range topSpam {
481 eta += math.Log(1-x.R) - math.Log(x.R)
484 f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
486 prob := 1 / (1 + math.Pow(math.E, eta))
487 return prob, len(topHam), len(topSpam), nil
490// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
491func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
493 return 0, nil, 0, 0, errClosed
496 mf, err := os.Open(path)
498 return 0, nil, 0, 0, err
502 f.log.Check(err, "closing file after classify")
506 return 0, nil, 0, 0, err
508 return f.ClassifyMessageReader(ctx, mf, fi.Size())
511func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
512 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
513 if err != nil && errors.Is(err, message.ErrBadContentType) {
514 // Invalid content-type header is a sure sign of spam.
515 //f.log.Infox("parsing content", err)
516 return 1, nil, 0, 0, nil
518 return f.ClassifyMessage(ctx, m)
521// ClassifyMessage parses the mail message m and returns the spam probability
522// (between 0 and 1), along with the tokenized words found in the message, and the
523// number of recognized ham and spam words.
524func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
526 words, err = f.ParseMessage(m)
528 return 0, nil, 0, 0, err
531 probability, nham, nspam, err = f.ClassifyWords(ctx, words)
532 return probability, words, nham, nspam, err
535// Train adds the words of a single message to the filter.
536func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
537 if err := f.ensureBloom(); err != nil {
543 for w := range words {
548 if _, ok := f.cache[w]; !ok {
549 lwords = append(lwords, w)
553 if err := f.loadCache(ctx, lwords); err != nil {
564 for w := range words {
577func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
578 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
579 words, err := f.ParseMessage(p)
581 return fmt.Errorf("parsing mail contents: %v", err)
583 return f.Train(ctx, ham, words)
586func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
587 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
588 words, err := f.ParseMessage(p)
590 return fmt.Errorf("parsing mail contents: %v", err)
592 return f.Untrain(ctx, ham, words)
595func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
596 if len(lwords) == 0 {
600 return loadWords(ctx, f.db, lwords, f.cache)
603// Untrain adjusts the filter to undo a previous training of the words.
604func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
605 if err := f.ensureBloom(); err != nil {
609 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
611 for w := range words {
612 if _, ok := f.cache[w]; !ok {
613 lwords = append(lwords, w)
616 if err := f.loadCache(ctx, lwords); err != nil {
620 // Modify the message count.
628 // Decrease the word counts.
629 for w := range words {
645// TrainDir parses mail messages from files and trains the filter.
646func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
648 return 0, 0, errClosed
650 if err := f.ensureBloom(); err != nil {
654 for _, name := range files {
655 p := filepath.Join(dir, name)
656 valid, words, err := f.tokenizeMail(p)
658 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
666 for w := range words {
685// TrainDirs trains and saves a filter with mail messages from different types
687func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
695 var hamMalformed, sentMalformed, spamMalformed uint32
698 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
702 tham := time.Since(start)
707 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
712 tsent := time.Since(start)
715 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
719 tspam := time.Since(start)
723 if err := f.Save(); err != nil {
724 return fmt.Errorf("saving filter: %s", err)
727 dbSize := f.fileSize(f.dbPath)
728 bloomSize := f.fileSize(f.bloomPath)
730 f.log.Print("training done",
731 slog.Any("hams", hams),
732 slog.Any("hamtime", tham),
733 slog.Any("hammalformed", hamMalformed),
734 slog.Any("sent", sent),
735 slog.Any("senttime", tsent),
736 slog.Any("sentmalformed", sentMalformed),
737 slog.Any("spams", f.spams),
738 slog.Any("spamtime", tspam),
739 slog.Any("spammalformed", spamMalformed),
740 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
741 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
742 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
748func (f *Filter) fileSize(p string) int {
749 fi, err := os.Stat(p)
751 f.log.Infox("stat", err, slog.Any("path", p))
754 return int(fi.Size())
757// DB returns the database, for backups.
758func (f *Filter) DB() *bstore.DB {