1// Package junk implements a bayesian spam filter.
3// A message can be parsed into words. Words (or pairs or triplets) can be used
4// to train the filter or to classify the message as ham or spam. Training
5// records the words in the database as ham/spam. Classifying consists of
6// calculating the ham/spam probability by combining the words in the message
7// with their ham/spam status.
10// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
11// todo: perhaps: whether anchor text in links in html are different from the url
25 "github.com/mjl-/bstore"
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
32 // errBadContentType = errors.New("bad content-type") // sure sign of spam, todo: use this error
33 errClosed = errors.New("filter is closed")
41type wordscore struct {
47// Params holds parameters for the filter. Most are at test-time. The first are
48// used during parsing and training.
50 Onegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
51 Twograms bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
52 Threegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
53 MaxPower float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
54 TopWords int `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
55 IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
56 RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
59var DBTypes = []any{wordscore{}} // Stored in DB.
64 log mlog.Log // For logging cid.
66 modified bool // Whether any modifications are pending. Cleared by Save.
67 hams, spams uint32 // Message count, stored in db under word "-".
68 cache map[string]word // Words read from database or during training.
69 changed map[string]word // Words modified during training.
70 dbPath, bloomPath string
71 db *bstore.DB // Always open on a filter.
72 bloom *Bloom // Only opened when writing.
73 isNew bool // Set for new filters until their first sync to disk. For faster writing.
76func (f *Filter) ensureBloom() error {
81 f.bloom, err = openBloom(f.bloomPath)
85// CloseDiscard closes the filter, discarding any changes.
86func (f *Filter) CloseDiscard() error {
91 *f = Filter{log: f.log, closed: true}
95// Close first saves the filter if it has modifications, then closes the database
96// connection and releases the bloom filter.
97func (f *Filter) Close() error {
110 *f = Filter{log: f.log, closed: true}
114func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
118 bloom, err = openBloom(bloomPath)
122 } else if fi, err := os.Stat(bloomPath); err == nil {
123 if err := BloomValid(int(fi.Size()), bloomK); err != nil {
124 return nil, fmt.Errorf("bloom: %s", err)
128 db, err := openDB(ctx, dbPath)
130 return nil, fmt.Errorf("open database: %s", err)
136 cache: map[string]word{},
137 changed: map[string]word{},
139 bloomPath: bloomPath,
143 err = f.db.Read(ctx, func(tx *bstore.Tx) error {
144 wc := wordscore{Word: "-"}
152 log.Check(cerr, "closing filter after error")
153 return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
158// NewFilter creates a new filter with empty bloom filter and database files. The
159// filter is marked as new until the first save, will be done automatically if
160// TrainDirs is called. If the bloom and/or database files exist, an error is
162func NewFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
164 if _, err := os.Stat(bloomPath); err == nil {
165 return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
166 } else if _, err := os.Stat(dbPath); err == nil {
167 return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
170 bloomSizeBytes := 4 * 1024 * 1024
171 if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
172 return nil, fmt.Errorf("bloom: %s", err)
174 bf, err := os.Create(bloomPath)
176 return nil, fmt.Errorf("creating bloom file: %w", err)
178 if err := bf.Truncate(4 * 1024 * 1024); err != nil {
180 log.Check(xerr, "closing bloom filter file after truncate error")
181 xerr = os.Remove(bloomPath)
182 log.Check(xerr, "removing bloom filter file after truncate error")
183 return nil, fmt.Errorf("making empty bloom filter: %s", err)
186 log.Check(err, "closing bloomfilter file")
188 db, err := newDB(ctx, log, dbPath)
190 xerr := os.Remove(bloomPath)
191 log.Check(xerr, "removing bloom filter file after db init error")
192 xerr = os.Remove(dbPath)
193 log.Check(xerr, "removing database file after db init error")
194 return nil, fmt.Errorf("open database: %s", err)
197 words := map[string]word{} // f.changed is set to new map after training
201 modified: true, // Ensure ham/spam message count is added for new filter.
205 bloomPath: bloomPath,
214func openBloom(path string) (*Bloom, error) {
215 buf, err := os.ReadFile(path)
217 return nil, fmt.Errorf("reading bloom file: %w", err)
219 return NewBloom(buf, bloomK)
222func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr error) {
223 // Remove any existing files.
228 err := os.Remove(path)
229 log.Check(err, "removing db file after init error")
233 db, err := bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
235 return nil, fmt.Errorf("open new database: %w", err)
240func openDB(ctx context.Context, path string) (*bstore.DB, error) {
241 if _, err := os.Stat(path); err != nil {
242 return nil, fmt.Errorf("stat db file: %w", err)
244 return bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
247// Save stores modifications, e.g. from training, to the database and bloom
249func (f *Filter) Save() error {
257 if f.bloom != nil && f.bloom.Modified() {
258 if err := f.bloom.Write(f.bloomPath); err != nil {
259 return fmt.Errorf("writing bloom filter: %w", err)
263 // We need to insert sequentially for reasonable performance.
264 words := make([]string, len(f.changed))
266 for w := range f.changed {
270 sort.Slice(words, func(i, j int) bool {
271 return words[i] < words[j]
274 f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed)))
275 // start := time.Now()
277 if err := f.db.HintAppend(true, wordscore{}); err != nil {
278 f.log.Errorx("hint appendonly", err)
281 err := f.db.HintAppend(false, wordscore{})
282 f.log.Check(err, "restoring append hint")
286 err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
287 update := func(w string, ham, spam uint32) error {
289 return tx.Insert(&wordscore{w, ham, spam})
292 wc := wordscore{w, 0, 0}
294 if err == bstore.ErrAbsent {
295 return tx.Insert(&wordscore{w, ham, spam})
296 } else if err != nil {
299 return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
301 if err := update("-", f.hams, f.spams); err != nil {
302 return fmt.Errorf("storing total ham/spam message count: %s", err)
305 for _, w := range words {
307 if err := update(w, c.Ham, c.Spam); err != nil {
308 return fmt.Errorf("updating ham/spam count: %s", err)
314 return fmt.Errorf("updating database: %w", err)
317 f.changed = map[string]word{}
320 // f.log.Info("wrote filter to db", slog.Any("duration", time.Since(start)))
324func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
325 sort.Slice(l, func(i, j int) bool {
329 err := db.Read(ctx, func(tx *bstore.Tx) error {
330 for _, w := range l {
331 wc := wordscore{Word: w}
332 if err := tx.Get(&wc); err == nil {
333 dst[w] = word{wc.Ham, wc.Spam}
339 return fmt.Errorf("fetching words: %s", err)
344// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
345func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
347 return 0, 0, 0, errClosed
355 var hamHigh float64 = 0
356 var spamLow float64 = 1
360 // Find words that should be in the database.
361 lookupWords := []string{}
362 expect := map[string]struct{}{}
363 unknowns := map[string]struct{}{}
365 for w := range words {
366 if f.bloom != nil && !f.bloom.Has(w) {
368 if len(unknowns) < 50 {
369 unknowns[w] = struct{}{}
373 if _, ok := f.cache[w]; ok {
376 lookupWords = append(lookupWords, w)
377 expect[w] = struct{}{}
379 if len(unknowns) > 0 {
380 f.log.Debug("unknown words in bloom filter, showing max 50",
381 slog.Any("words", unknowns),
382 slog.Any("totalunknown", totalUnknown),
383 slog.Any("totalwords", len(words)))
386 // Fetch words from database.
387 fetched := map[string]word{}
388 if len(lookupWords) > 0 {
389 if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
392 for w, c := range fetched {
396 f.log.Debug("unknown words in db",
397 slog.Any("words", expect),
398 slog.Any("totalunknown", len(expect)),
399 slog.Any("totalwords", len(words)))
402 for w := range words {
409 wS = float64(c.Spam) / float64(f.spams)
412 wH = float64(c.Ham) / float64(f.hams)
418 } else if r >= 1-f.MaxPower {
422 if c.Ham+c.Spam <= uint32(f.RareWords) {
423 // Reduce the power of rare words.
424 r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
426 if math.Abs(0.5-r) < f.IgnoreWords {
430 if len(topHam) >= f.TopWords && r > hamHigh {
433 topHam = append(topHam, xword{w, r})
438 if len(topSpam) >= f.TopWords && r < spamLow {
441 topSpam = append(topSpam, xword{w, r})
448 sort.Slice(topHam, func(i, j int) bool {
449 a, b := topHam[i], topHam[j]
451 return len(a.Word) > len(b.Word)
455 sort.Slice(topSpam, func(i, j int) bool {
456 a, b := topSpam[i], topSpam[j]
458 return len(a.Word) > len(b.Word)
464 if nham > len(topHam) {
468 if nspam > len(topSpam) {
471 topHam = topHam[:nham]
472 topSpam = topSpam[:nspam]
475 for _, x := range topHam {
476 eta += math.Log(1-x.R) - math.Log(x.R)
478 for _, x := range topSpam {
479 eta += math.Log(1-x.R) - math.Log(x.R)
482 f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
484 prob := 1 / (1 + math.Pow(math.E, eta))
485 return prob, len(topHam), len(topSpam), nil
488// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
489func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
491 return 0, nil, 0, 0, errClosed
494 mf, err := os.Open(path)
496 return 0, nil, 0, 0, err
500 f.log.Check(err, "closing file after classify")
504 return 0, nil, 0, 0, err
506 return f.ClassifyMessageReader(ctx, mf, fi.Size())
509func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
510 m, err := message.EnsurePart(f.log.Logger, false, mf, size)
511 if err != nil && errors.Is(err, message.ErrBadContentType) {
512 // Invalid content-type header is a sure sign of spam.
513 //f.log.Infox("parsing content", err)
514 return 1, nil, 0, 0, nil
516 return f.ClassifyMessage(ctx, m)
519// ClassifyMessage parses the mail message in r and returns the spam probability
520// (between 0 and 1), along with the tokenized words found in the message, and the
521// number of recognized ham and spam words.
522func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
524 words, err = f.ParseMessage(m)
526 return 0, nil, 0, 0, err
529 probability, nham, nspam, err = f.ClassifyWords(ctx, words)
530 return probability, words, nham, nspam, err
533// Train adds the words of a single message to the filter.
534func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
535 if err := f.ensureBloom(); err != nil {
541 for w := range words {
546 if _, ok := f.cache[w]; !ok {
547 lwords = append(lwords, w)
551 if err := f.loadCache(ctx, lwords); err != nil {
562 for w := range words {
575func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
576 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
577 words, err := f.ParseMessage(p)
579 return fmt.Errorf("parsing mail contents: %v", err)
581 return f.Train(ctx, ham, words)
584func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
585 p, _ := message.EnsurePart(f.log.Logger, false, r, size)
586 words, err := f.ParseMessage(p)
588 return fmt.Errorf("parsing mail contents: %v", err)
590 return f.Untrain(ctx, ham, words)
593func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
594 if len(lwords) == 0 {
598 return loadWords(ctx, f.db, lwords, f.cache)
601// Untrain adjusts the filter to undo a previous training of the words.
602func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
603 if err := f.ensureBloom(); err != nil {
607 // Lookup any words from the db that aren't in the cache and put them in the cache for modification.
609 for w := range words {
610 if _, ok := f.cache[w]; !ok {
611 lwords = append(lwords, w)
614 if err := f.loadCache(ctx, lwords); err != nil {
618 // Modify the message count.
626 // Decrease the word counts.
627 for w := range words {
643// TrainDir parses mail messages from files and trains the filter.
644func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
646 return 0, 0, errClosed
648 if err := f.ensureBloom(); err != nil {
652 for _, name := range files {
653 p := filepath.Join(dir, name)
654 valid, words, err := f.tokenizeMail(p)
656 // f.log.Infox("tokenizing mail", err, slog.Any("path", p))
664 for w := range words {
683// TrainDirs trains and saves a filter with mail messages from different types
685func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
693 var hamMalformed, sentMalformed, spamMalformed uint32
696 f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
700 tham := time.Since(start)
705 sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
710 tsent := time.Since(start)
713 f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
717 tspam := time.Since(start)
721 if err := f.Save(); err != nil {
722 return fmt.Errorf("saving filter: %s", err)
725 dbSize := f.fileSize(f.dbPath)
726 bloomSize := f.fileSize(f.bloomPath)
728 f.log.Print("training done",
729 slog.Any("hams", hams),
730 slog.Any("hamtime", tham),
731 slog.Any("hammalformed", hamMalformed),
732 slog.Any("sent", sent),
733 slog.Any("senttime", tsent),
734 slog.Any("sentmalformed", sentMalformed),
735 slog.Any("spams", f.spams),
736 slog.Any("spamtime", tspam),
737 slog.Any("spammalformed", spamMalformed),
738 slog.Any("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
739 slog.Any("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
740 slog.Any("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
746func (f *Filter) fileSize(p string) int {
747 fi, err := os.Stat(p)
749 f.log.Infox("stat", err, slog.Any("path", p))
752 return int(fi.Size())
755// DB returns the database, for backups.
756func (f *Filter) DB() *bstore.DB {