1package store
2
3import (
4 "context"
5 cryptorand "crypto/rand"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "log/slog"
10 "runtime/debug"
11 "sync"
12 "time"
13
14 "github.com/mjl-/bstore"
15
16 "github.com/mjl-/mox/metrics"
17 "github.com/mjl-/mox/mlog"
18 "github.com/mjl-/mox/mox-"
19)
20
const (
	// sessionsPerAccount is the maximum number of sessions kept per account. When a
	// session is added while the account already has this many, expired sessions are
	// removed first, and if still needed the session that expires first.
	sessionsPerAccount = 100

	// sessionLifetime is how long a session remains valid. Each successful use
	// resets the expiration to this long from the moment of use.
	sessionLifetime = 24 * time.Hour

	// sessionWriteDelay is how long extended expiration times of an account's
	// sessions are collected before being written to the database, to coalesce
	// writes.
	sessionWriteDelay = 5 * time.Minute
)
24
25var sessions = struct {
26 sync.Mutex
27
28 // For each account, we keep all sessions (with fixed maximum number) in memory. If
29 // the map for an account is nil, it is initialized from the database on first use.
30 accounts map[string]map[SessionToken]LoginSession
31
32 // We flush sessions with extended expiration timestamp to disk with a delay, to
33 // coalesce potentially many changes. The delay is short enough that we don't have
34 // to care about flushing to disk on shutdown.
35 pendingFlushes map[string]map[SessionToken]struct{}
36}{
37 accounts: map[string]map[SessionToken]LoginSession{},
38 pendingFlushes: map[string]map[SessionToken]struct{}{},
39}
40
41// Ensure sessions for account are initialized from database. If the sessions were
42// initialized from the database, or when alwaysOpenAccount is true, an open
43// account is returned (assuming no error occurred).
44//
45// must be called with sessions lock held.
46func ensureAccountSessions(ctx context.Context, log mlog.Log, accountName string, alwaysOpenAccount bool) (*Account, error) {
47 var acc *Account
48 accSessions := sessions.accounts[accountName]
49 if accSessions == nil {
50 var err error
51 acc, err = OpenAccount(log, accountName)
52 if err != nil {
53 return nil, err
54 }
55
56 // We still hold the lock, not great...
57
58 accSessions = map[SessionToken]LoginSession{}
59 err = bstore.QueryDB[LoginSession](ctx, acc.DB).ForEach(func(ls LoginSession) error {
60 // We keep strings around for easy comparison.
61 ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
62 ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
63
64 accSessions[ls.sessionToken] = ls
65 return nil
66 })
67 if err != nil {
68 return nil, err
69 }
70
71 sessions.accounts[accountName] = accSessions
72 }
73 if acc == nil && alwaysOpenAccount {
74 return OpenAccount(log, accountName)
75 }
76 return acc, nil
77}
78
79// SessionUse checks if a session is valid. If csrfToken is the empty string, no
80// CSRF check is done. Otherwise it must be the csrf token associated with the
81// session token.
82func SessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
83 sessions.Lock()
84 defer sessions.Unlock()
85
86 acc, err := ensureAccountSessions(ctx, log, accountName, false)
87 if err != nil {
88 return LoginSession{}, err
89 } else if acc != nil {
90 if err := acc.Close(); err != nil {
91 return LoginSession{}, fmt.Errorf("closing account: %w", err)
92 }
93 }
94
95 return sessionUse(ctx, log, accountName, sessionToken, csrfToken)
96}
97
98// must be called with sessions lock held.
99func sessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
100 // Check if valid.
101 ls, ok := sessions.accounts[accountName][sessionToken]
102 if !ok {
103 return LoginSession{}, fmt.Errorf("unknown session token")
104 } else if time.Until(ls.Expires) < 0 {
105 return LoginSession{}, fmt.Errorf("session expired")
106 } else if csrfToken != "" && csrfToken != ls.csrfToken {
107 return LoginSession{}, fmt.Errorf("mismatch between csrf and session tokens")
108 }
109
110 // Extend lifetime.
111 ls.Expires = time.Now().Add(sessionLifetime)
112 sessions.accounts[accountName][sessionToken] = ls
113
114 // If we haven't scheduled a flush to database yet, schedule one now.
115 if sessions.pendingFlushes[accountName] == nil {
116 sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
117 go func() {
118 pkglog := mlog.New("store", nil)
119
120 defer func() {
121 x := recover()
122 if x != nil {
123 pkglog.Error("recover from panic", slog.Any("panic", x))
124 debug.PrintStack()
125 metrics.PanicInc(metrics.Store)
126 }
127 }()
128
129 time.Sleep(sessionWriteDelay)
130 sessionsDelayedFlush(pkglog, accountName)
131 }()
132 }
133 sessions.pendingFlushes[accountName][ls.sessionToken] = struct{}{}
134
135 return ls, nil
136}
137
138// wait, then flush all changed sessions for an account.
139func sessionsDelayedFlush(log mlog.Log, accountName string) {
140 sessions.Lock()
141 defer sessions.Unlock()
142
143 sessionTokens := sessions.pendingFlushes[accountName]
144 delete(sessions.pendingFlushes, accountName)
145
146 _, ok := sessions.accounts[accountName]
147 if !ok {
148 // Account may have been removed. Nothing to flush.
149 return
150 }
151
152 acc, err := OpenAccount(log, accountName)
153 if err != nil && errors.Is(err, ErrAccountUnknown) {
154 // Account may have been removed. Nothing to flush.
155 log.Infox("flushing sessions for account", err, slog.String("account", accountName))
156 return
157 }
158 if err != nil {
159 log.Errorx("open account for flushing changed session tokens", err, slog.String("account", accountName))
160 return
161 }
162 defer func() {
163 err := acc.Close()
164 log.Check(err, "closing account")
165 }()
166
167 err = acc.DB.Write(mox.Context, func(tx *bstore.Tx) error {
168 for sessionToken := range sessionTokens {
169 ls, ok := sessions.accounts[accountName][sessionToken]
170 if !ok {
171 return fmt.Errorf("unknown session token to flush")
172 }
173 if err := tx.Update(&ls); err != nil {
174 return err
175 }
176 }
177 return nil
178 })
179 log.Check(err, "flushing changed sessions for account", slog.String("account", accountName))
180}
181
182// SessionAddTokens adds a prepared or pre-existing LoginSession to the database and
183// cache. Can be used to restore a session token that was used to reset a password.
184func SessionAddToken(ctx context.Context, log mlog.Log, ls *LoginSession) error {
185 sessions.Lock()
186 defer sessions.Unlock()
187
188 acc, err := ensureAccountSessions(ctx, log, ls.AccountName, true)
189 if err != nil {
190 return err
191 }
192 defer func() {
193 err := acc.Close()
194 log.Check(err, "closing account after adding session token")
195 }()
196
197 return sessionAddToken(ctx, log, acc, ls)
198}
199
200// caller must hold sessions lock.
201func sessionAddToken(ctx context.Context, log mlog.Log, acc *Account, ls *LoginSession) error {
202 ls.ID = 0
203
204 err := acc.DB.Write(ctx, func(tx *bstore.Tx) error {
205 // Remove sessions if we have too many, starting with expired sessions, and
206 // removing the oldest if needed.
207 if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
208 var oldest LoginSession
209 for _, ols := range sessions.accounts[ls.AccountName] {
210 if time.Until(ols.Expires) < 0 {
211 if err := tx.Delete(&ols); err != nil {
212 return err
213 }
214 delete(sessions.accounts[ls.AccountName], ols.sessionToken)
215 continue
216 }
217 if oldest.ID == 0 || ols.Expires.Before(oldest.Expires) {
218 oldest = ols
219 }
220 }
221 if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
222 if err := tx.Delete(&oldest); err != nil {
223 return err
224 }
225 delete(sessions.accounts[ls.AccountName], oldest.sessionToken)
226 }
227 }
228
229 if err := tx.Insert(ls); err != nil {
230 return fmt.Errorf("insert: %v", err)
231 }
232 return nil
233 })
234 if err != nil {
235 return err
236 }
237 sessions.accounts[ls.AccountName][ls.sessionToken] = *ls
238 return nil
239}
240
241// SessionAdd creates a new session token, with csrf token, and adds it to the
242// database and in-memory session cache. If there are too many sessions, the oldest
243// is removed.
244func SessionAdd(ctx context.Context, log mlog.Log, accountName, loginAddress string) (session SessionToken, csrf CSRFToken, rerr error) {
245 // Prepare new LoginSession.
246 ls := LoginSession{0, time.Time{}, time.Now().Add(sessionLifetime), [16]byte{}, [16]byte{}, accountName, loginAddress, "", ""}
247 if _, err := cryptorand.Read(ls.SessionTokenBinary[:]); err != nil {
248 return "", "", err
249 }
250 if _, err := cryptorand.Read(ls.CSRFTokenBinary[:]); err != nil {
251 return "", "", err
252 }
253 ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
254 ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
255
256 sessions.Lock()
257 defer sessions.Unlock()
258
259 acc, err := ensureAccountSessions(ctx, log, accountName, true)
260 if err != nil {
261 return "", "", err
262 }
263 defer func() {
264 err := acc.Close()
265 log.Check(err, "closing account")
266 }()
267
268 if err := sessionAddToken(ctx, log, acc, &ls); err != nil {
269 return "", "", err
270 }
271
272 return ls.sessionToken, ls.csrfToken, nil
273}
274
275// SessionRemove removes a session from the database and in-memory cache. Future
276// operations using the session token will fail.
277func SessionRemove(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken) error {
278 sessions.Lock()
279 defer sessions.Unlock()
280
281 acc, err := ensureAccountSessions(ctx, log, accountName, true)
282 if err != nil {
283 return err
284 }
285 defer acc.Close()
286
287 ls, ok := sessions.accounts[accountName][sessionToken]
288 if !ok {
289 return fmt.Errorf("unknown session token")
290 }
291
292 if err := acc.DB.Delete(ctx, &ls); err != nil {
293 return err
294 }
295
296 delete(sessions.accounts[accountName], sessionToken)
297 pf := sessions.pendingFlushes[accountName]
298 if pf != nil {
299 delete(pf, sessionToken)
300 }
301
302 return nil
303}
304
305// sessionRemoveAll removes all session tokens for an account. Useful after a password reset.
306func sessionRemoveAll(ctx context.Context, log mlog.Log, tx *bstore.Tx, accountName string) error {
307 sessions.Lock()
308 defer sessions.Unlock()
309
310 if _, err := bstore.QueryTx[LoginSession](tx).Delete(); err != nil {
311 return err
312 }
313
314 sessions.accounts[accountName] = map[SessionToken]LoginSession{}
315 if sessions.pendingFlushes[accountName] != nil {
316 sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
317 }
318
319 return nil
320}
321