1package store
2
3import (
4 "context"
5 cryptorand "crypto/rand"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "log/slog"
10 "runtime/debug"
11 "sync"
12 "time"
13
14 "github.com/mjl-/bstore"
15
16 "github.com/mjl-/mox/metrics"
17 "github.com/mjl-/mox/mlog"
18 "github.com/mjl-/mox/mox-"
19)
20
const (
	// sessionsPerAccount is the maximum number of sessions kept per account. When
	// a session is added while this many exist, expired sessions are removed
	// first, and if that is not enough, the session expiring soonest.
	sessionsPerAccount = 100

	// sessionLifetime is how long a session stays valid after its last use; each
	// successful use extends it.
	sessionLifetime = 24 * time.Hour

	// sessionWriteDelay is how long changed sessions of an account are kept in
	// memory before being written to the database, so many changes coalesce
	// into one write.
	sessionWriteDelay = 5 * time.Minute
)
24
25var sessions = struct {
26 sync.Mutex
27
28 // For each account, we keep all sessions (with fixed maximum number) in memory. If
29 // the map for an account is nil, it is initialized from the database on first use.
30 accounts map[string]map[SessionToken]LoginSession
31
32 // We flush sessions with extended expiration timestamp to disk with a delay, to
33 // coalesce potentially many changes. The delay is short enough that we don't have
34 // to care about flushing to disk on shutdown.
35 pendingFlushes map[string]map[SessionToken]struct{}
36}{
37 accounts: map[string]map[SessionToken]LoginSession{},
38 pendingFlushes: map[string]map[SessionToken]struct{}{},
39}
40
41// Ensure sessions for account are initialized from database. If openAccount is
42// set, an account is returned on success.
43//
44// must be called with sessions lock held.
45func ensureAccountSessions(ctx context.Context, log mlog.Log, accountName string, openAccount bool) (acc *Account, rerr error) {
46 defer func() {
47 if !openAccount && acc != nil {
48 if err := acc.Close(); err != nil && rerr == nil {
49 rerr = fmt.Errorf("closing account: %v", err)
50 }
51 acc = nil
52 }
53 }()
54 accSessions := sessions.accounts[accountName]
55 if accSessions == nil {
56 var err error
57 acc, err = OpenAccount(log, accountName, openAccount)
58 if err != nil {
59 return nil, err
60 }
61
62 // We still hold the lock, not great...
63
64 accSessions = map[SessionToken]LoginSession{}
65 err = bstore.QueryDB[LoginSession](ctx, acc.DB).ForEach(func(ls LoginSession) error {
66 // We keep strings around for easy comparison.
67 ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
68 ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
69
70 accSessions[ls.sessionToken] = ls
71 return nil
72 })
73 if err != nil {
74 return nil, err
75 }
76
77 sessions.accounts[accountName] = accSessions
78 }
79 if acc == nil && openAccount {
80 acc, rerr = OpenAccount(log, accountName, true)
81 }
82 return
83}
84
85// SessionUse checks if a session is valid. If csrfToken is the empty string, no
86// CSRF check is done. Otherwise it must be the csrf token associated with the
87// session token.
88func SessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
89 sessions.Lock()
90 defer sessions.Unlock()
91
92 if _, err := ensureAccountSessions(ctx, log, accountName, false); err != nil {
93 return LoginSession{}, err
94 }
95
96 return sessionUse(ctx, log, accountName, sessionToken, csrfToken)
97}
98
99// must be called with sessions lock held.
100func sessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
101 // Check if valid.
102 ls, ok := sessions.accounts[accountName][sessionToken]
103 if !ok {
104 return LoginSession{}, fmt.Errorf("unknown session token")
105 } else if time.Until(ls.Expires) < 0 {
106 return LoginSession{}, fmt.Errorf("session expired (after 24 hours inactivity)")
107 } else if csrfToken != "" && csrfToken != ls.csrfToken {
108 return LoginSession{}, fmt.Errorf("mismatch between csrf and session tokens")
109 }
110
111 // Extend lifetime.
112 ls.Expires = time.Now().Add(sessionLifetime)
113 sessions.accounts[accountName][sessionToken] = ls
114
115 // If we haven't scheduled a flush to database yet, schedule one now.
116 if sessions.pendingFlushes[accountName] == nil {
117 sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
118 go func() {
119 pkglog := mlog.New("store", nil)
120
121 defer func() {
122 x := recover()
123 if x != nil {
124 pkglog.Error("recover from panic", slog.Any("panic", x))
125 debug.PrintStack()
126 metrics.PanicInc(metrics.Store)
127 }
128 }()
129
130 time.Sleep(sessionWriteDelay)
131 sessionsDelayedFlush(pkglog, accountName)
132 }()
133 }
134 sessions.pendingFlushes[accountName][ls.sessionToken] = struct{}{}
135
136 return ls, nil
137}
138
139// wait, then flush all changed sessions for an account.
140func sessionsDelayedFlush(log mlog.Log, accountName string) {
141 sessions.Lock()
142 defer sessions.Unlock()
143
144 sessionTokens := sessions.pendingFlushes[accountName]
145 delete(sessions.pendingFlushes, accountName)
146
147 _, ok := sessions.accounts[accountName]
148 if !ok {
149 // Account may have been removed. Nothing to flush.
150 return
151 }
152
153 acc, err := OpenAccount(log, accountName, false)
154 if err != nil && errors.Is(err, ErrAccountUnknown) {
155 // Account may have been removed. Nothing to flush.
156 log.Infox("flushing sessions for account", err, slog.String("account", accountName))
157 return
158 }
159 if err != nil {
160 log.Errorx("open account for flushing changed session tokens", err, slog.String("account", accountName))
161 return
162 }
163 defer func() {
164 err := acc.Close()
165 log.Check(err, "closing account")
166 }()
167
168 err = acc.DB.Write(mox.Context, func(tx *bstore.Tx) error {
169 for sessionToken := range sessionTokens {
170 ls, ok := sessions.accounts[accountName][sessionToken]
171 if !ok {
172 return fmt.Errorf("unknown session token to flush")
173 }
174 if err := tx.Update(&ls); err != nil {
175 return err
176 }
177 }
178 return nil
179 })
180 log.Check(err, "flushing changed sessions for account", slog.String("account", accountName))
181}
182
183// SessionAddTokens adds a prepared or pre-existing LoginSession to the database and
184// cache. Can be used to restore a session token that was used to reset a password.
185func SessionAddToken(ctx context.Context, log mlog.Log, ls *LoginSession) error {
186 sessions.Lock()
187 defer sessions.Unlock()
188
189 acc, err := ensureAccountSessions(ctx, log, ls.AccountName, true)
190 if err != nil {
191 return err
192 }
193 defer func() {
194 err := acc.Close()
195 log.Check(err, "closing account after adding session token")
196 }()
197
198 return sessionAddToken(ctx, log, acc, ls)
199}
200
201// caller must hold sessions lock.
202func sessionAddToken(ctx context.Context, log mlog.Log, acc *Account, ls *LoginSession) error {
203 ls.ID = 0
204
205 err := acc.DB.Write(ctx, func(tx *bstore.Tx) error {
206 // Remove sessions if we have too many, starting with expired sessions, and
207 // removing the oldest if needed.
208 if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
209 var oldest LoginSession
210 for _, ols := range sessions.accounts[ls.AccountName] {
211 if time.Until(ols.Expires) < 0 {
212 if err := tx.Delete(&ols); err != nil {
213 return err
214 }
215 delete(sessions.accounts[ls.AccountName], ols.sessionToken)
216 continue
217 }
218 if oldest.ID == 0 || ols.Expires.Before(oldest.Expires) {
219 oldest = ols
220 }
221 }
222 if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
223 if err := tx.Delete(&oldest); err != nil {
224 return err
225 }
226 delete(sessions.accounts[ls.AccountName], oldest.sessionToken)
227 }
228 }
229
230 if err := tx.Insert(ls); err != nil {
231 return fmt.Errorf("insert: %v", err)
232 }
233 return nil
234 })
235 if err != nil {
236 return err
237 }
238 sessions.accounts[ls.AccountName][ls.sessionToken] = *ls
239 return nil
240}
241
242// SessionAdd creates a new session token, with csrf token, and adds it to the
243// database and in-memory session cache. If there are too many sessions, the oldest
244// is removed.
245func SessionAdd(ctx context.Context, log mlog.Log, accountName, loginAddress string) (session SessionToken, csrf CSRFToken, rerr error) {
246 // Prepare new LoginSession.
247 ls := LoginSession{0, time.Time{}, time.Now().Add(sessionLifetime), [16]byte{}, [16]byte{}, accountName, loginAddress, "", ""}
248 if _, err := cryptorand.Read(ls.SessionTokenBinary[:]); err != nil {
249 return "", "", err
250 }
251 if _, err := cryptorand.Read(ls.CSRFTokenBinary[:]); err != nil {
252 return "", "", err
253 }
254 ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
255 ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
256
257 sessions.Lock()
258 defer sessions.Unlock()
259
260 acc, err := ensureAccountSessions(ctx, log, accountName, true)
261 if err != nil {
262 return "", "", err
263 }
264 defer func() {
265 err := acc.Close()
266 log.Check(err, "closing account")
267 }()
268
269 if err := sessionAddToken(ctx, log, acc, &ls); err != nil {
270 return "", "", err
271 }
272
273 return ls.sessionToken, ls.csrfToken, nil
274}
275
276// SessionRemove removes a session from the database and in-memory cache. Future
277// operations using the session token will fail.
278func SessionRemove(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken) error {
279 sessions.Lock()
280 defer sessions.Unlock()
281
282 acc, err := ensureAccountSessions(ctx, log, accountName, true)
283 if err != nil {
284 return err
285 }
286 defer func() {
287 err := acc.Close()
288 log.Check(err, "closing account")
289 }()
290
291 ls, ok := sessions.accounts[accountName][sessionToken]
292 if !ok {
293 return fmt.Errorf("unknown session token")
294 }
295
296 if err := acc.DB.Delete(ctx, &ls); err != nil {
297 return err
298 }
299
300 delete(sessions.accounts[accountName], sessionToken)
301 pf := sessions.pendingFlushes[accountName]
302 if pf != nil {
303 delete(pf, sessionToken)
304 }
305
306 return nil
307}
308
309// sessionRemoveAll removes all session tokens for an account. Useful after a password reset.
310func sessionRemoveAll(ctx context.Context, log mlog.Log, tx *bstore.Tx, accountName string) error {
311 sessions.Lock()
312 defer sessions.Unlock()
313
314 if _, err := bstore.QueryTx[LoginSession](tx).Delete(); err != nil {
315 return err
316 }
317
318 sessions.accounts[accountName] = map[SessionToken]LoginSession{}
319 if sessions.pendingFlushes[accountName] != nil {
320 sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
321 }
322
323 return nil
324}
325