1package junk
2
3import (
4 "context"
5 "fmt"
6 "math"
7 "os"
8 "path/filepath"
9 "testing"
10
11 "github.com/mjl-/mox/mlog"
12)
13
// ctxbg is a background context, passed to the filter calls in the tests below.
var ctxbg = context.Background()
15
// tcheck aborts the test with msg and the error text when err is non-nil. It
// does nothing for a nil err. Failures are attributed to the caller, because
// tcheck marks itself as a test helper.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
22
23func tlistdir(t *testing.T, name string) []string {
24 t.Helper()
25 l, err := os.ReadDir(name)
26 tcheck(t, err, "readdir")
27 names := make([]string, len(l))
28 for i, e := range l {
29 names[i] = e.Name()
30 }
31 return names
32}
33
34func TestFilter(t *testing.T) {
35 log := mlog.New("junk", nil)
36 params := Params{
37 Onegrams: true,
38 Twograms: true,
39 Threegrams: false,
40 MaxPower: 0.1,
41 TopWords: 10,
42 IgnoreWords: 0.1,
43 RareWords: 1,
44 }
45 dbPath := filepath.FromSlash("../testdata/junk/filter.db")
46 bloomPath := filepath.FromSlash("../testdata/junk/filter.bloom")
47 os.Remove(dbPath)
48 os.Remove(bloomPath)
49 f, err := NewFilter(ctxbg, log, params, dbPath, bloomPath)
50 tcheck(t, err, "new filter")
51 err = f.Close()
52 tcheck(t, err, "close filter")
53
54 f, err = OpenFilter(ctxbg, log, params, dbPath, bloomPath, true)
55 tcheck(t, err, "open filter")
56
57 // Ensure these dirs exist. Developers should bring their own ham/spam example
58 // emails.
59 os.MkdirAll("../testdata/train/ham", 0770)
60 os.MkdirAll("../testdata/train/spam", 0770)
61
62 hamdir := filepath.FromSlash("../testdata/train/ham")
63 spamdir := filepath.FromSlash("../testdata/train/spam")
64 hamfiles := tlistdir(t, hamdir)
65 if len(hamfiles) > 100 {
66 hamfiles = hamfiles[:100]
67 }
68 spamfiles := tlistdir(t, spamdir)
69 if len(spamfiles) > 100 {
70 spamfiles = spamfiles[:100]
71 }
72
73 err = f.TrainDirs(hamdir, "", spamdir, hamfiles, nil, spamfiles)
74 tcheck(t, err, "train dirs")
75
76 if len(hamfiles) == 0 || len(spamfiles) == 0 {
77 fmt.Println("not training, no ham and/or spam messages, add them to testdata/train/ham and testdata/train/spam")
78 return
79 }
80
81 prob, _, _, _, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0]))
82 tcheck(t, err, "classify ham message")
83 if prob > 0.1 {
84 t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
85 }
86
87 prob, _, _, _, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0]))
88 tcheck(t, err, "classify spam message")
89 if prob < 0.9 {
90 t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
91 }
92
93 err = f.Close()
94 tcheck(t, err, "close filter")
95
96 // Start again with empty filter. We'll train a few messages and check they are
97 // classified as ham/spam. Then we untrain to see they are no longer classified.
98 os.Remove(dbPath)
99 os.Remove(bloomPath)
100 f, err = NewFilter(ctxbg, log, params, dbPath, bloomPath)
101 tcheck(t, err, "open filter")
102
103 hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
104 tcheck(t, err, "open hamfile")
105 defer hamf.Close()
106 hamstat, err := hamf.Stat()
107 tcheck(t, err, "stat hamfile")
108 hamsize := hamstat.Size()
109
110 spamf, err := os.Open(filepath.Join(spamdir, spamfiles[0]))
111 tcheck(t, err, "open spamfile")
112 defer spamf.Close()
113 spamstat, err := spamf.Stat()
114 tcheck(t, err, "stat spamfile")
115 spamsize := spamstat.Size()
116
117 // Train each message twice, to prevent single occurrences from being ignored.
118 err = f.TrainMessage(ctxbg, hamf, hamsize, true)
119 tcheck(t, err, "train ham message")
120 _, err = hamf.Seek(0, 0)
121 tcheck(t, err, "seek ham message")
122 err = f.TrainMessage(ctxbg, hamf, hamsize, true)
123 tcheck(t, err, "train ham message")
124
125 err = f.TrainMessage(ctxbg, spamf, spamsize, false)
126 tcheck(t, err, "train spam message")
127 _, err = spamf.Seek(0, 0)
128 tcheck(t, err, "seek spam message")
129 err = f.TrainMessage(ctxbg, spamf, spamsize, true)
130 tcheck(t, err, "train spam message")
131
132 if !f.modified {
133 t.Fatalf("filter not modified after training")
134 }
135 if !f.bloom.Modified() {
136 t.Fatalf("bloom filter not modified after training")
137 }
138
139 err = f.Save()
140 tcheck(t, err, "save filter")
141 if f.modified || f.bloom.Modified() {
142 t.Fatalf("filter or bloom filter still modified after save")
143 }
144
145 // Classify and verify.
146 _, err = hamf.Seek(0, 0)
147 tcheck(t, err, "seek ham message")
148 prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
149 tcheck(t, err, "classify ham")
150 if prob > 0.1 {
151 t.Fatalf("got prob %v, expected <= 0.1", prob)
152 }
153
154 _, err = spamf.Seek(0, 0)
155 tcheck(t, err, "seek spam message")
156 prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
157 tcheck(t, err, "classify spam")
158 if prob < 0.9 {
159 t.Fatalf("got prob %v, expected >= 0.9", prob)
160 }
161
162 // Untrain ham & spam.
163 _, err = hamf.Seek(0, 0)
164 tcheck(t, err, "seek ham message")
165 err = f.UntrainMessage(ctxbg, hamf, hamsize, true)
166 tcheck(t, err, "untrain ham message")
167 _, err = hamf.Seek(0, 0)
168 tcheck(t, err, "seek ham message")
169 err = f.UntrainMessage(ctxbg, hamf, spamsize, true)
170 tcheck(t, err, "untrain ham message")
171
172 _, err = spamf.Seek(0, 0)
173 tcheck(t, err, "seek spam message")
174 err = f.UntrainMessage(ctxbg, spamf, spamsize, true)
175 tcheck(t, err, "untrain spam message")
176 _, err = spamf.Seek(0, 0)
177 tcheck(t, err, "seek spam message")
178 err = f.UntrainMessage(ctxbg, spamf, spamsize, true)
179 tcheck(t, err, "untrain spam message")
180
181 if !f.modified {
182 t.Fatalf("filter not modified after untraining")
183 }
184
185 // Classify again, should be unknown.
186 _, err = hamf.Seek(0, 0)
187 tcheck(t, err, "seek ham message")
188 prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
189 tcheck(t, err, "classify ham")
190 if math.Abs(prob-0.5) > 0.1 {
191 t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
192 }
193
194 _, err = spamf.Seek(0, 0)
195 tcheck(t, err, "seek spam message")
196 prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
197 tcheck(t, err, "classify spam")
198 if math.Abs(prob-0.5) > 0.1 {
199 t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
200 }
201
202 err = f.Close()
203 tcheck(t, err, "close filter")
204}
205