5 cryptorand "crypto/rand"
14 "github.com/mjl-/bstore"
16 "github.com/mjl-/mox/metrics"
17 "github.com/mjl-/mox/mlog"
18 "github.com/mjl-/mox/mox-"
// Limits and timings for the in-memory login session cache.
21const sessionsPerAccount = 100 // Max sessions kept per account. Adding one to a full account first removes expired sessions, then the one expiring soonest.
22const sessionLifetime = 24 * time.Hour // Extended automatically by use: each successful use sets expiry to now plus this duration.
23const sessionWriteDelay = 5 * time.Minute // Per account, for coalescing writes: extended expirations are written to the database after this delay.
// sessions is the in-memory cache of login sessions for all accounts, keyed by
// account name.
// NOTE(review): presumably embeds a sync.Mutex; functions in this file call
// sessions.Unlock() and document "must be called with sessions lock held".
25var sessions = struct {
28 // For each account, we keep all sessions (with fixed maximum number) in memory. If
29 // the map for an account is nil, it is initialized from the database on first use.
// The inner map is keyed by the session token string.
30 accounts map[string]map[SessionToken]LoginSession
32 // We flush sessions with extended expiration timestamp to disk with a delay, to
33 // coalesce potentially many changes. The delay is short enough that we don't have
34 // to care about flushing to disk on shutdown.
// The inner map is the set of tokens with unwritten expiration changes. A non-nil
// entry for an account means a delayed flush is already scheduled for it.
35 pendingFlushes map[string]map[SessionToken]struct{}
37 accounts: map[string]map[SessionToken]LoginSession{},
38 pendingFlushes: map[string]map[SessionToken]struct{}{},
41// Ensure sessions for account are initialized from database. If openAccount is
42// set, an account is returned on success.
//
// When openAccount is false, an account opened only to load sessions is closed
// again before returning. A close error is reported only if no other error occurred.
44// must be called with sessions lock held.
45func ensureAccountSessions(ctx context.Context, log mlog.Log, accountName string, openAccount bool) (acc *Account, rerr error) {
// Cleanup for the openAccount == false case. It assigns the named result rerr,
// so presumably it runs in a deferred function.
47 if !openAccount && acc != nil {
48 if err := acc.Close(); err != nil && rerr == nil {
49 rerr = fmt.Errorf("closing account: %v", err)
// A nil map means this account's sessions have not been loaded yet.
54 accSessions := sessions.accounts[accountName]
55 if accSessions == nil {
57 acc, err = OpenAccount(log, accountName, openAccount)
62 // We still hold the lock, not great...
64 accSessions = map[SessionToken]LoginSession{}
65 err = bstore.QueryDB[LoginSession](ctx, acc.DB).ForEach(func(ls LoginSession) error {
66 // We keep strings around for easy comparison.
// Both tokens become unpadded URL-safe base64. The cache is keyed by the session token string.
67 ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
68 ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
70 accSessions[ls.sessionToken] = ls
77 sessions.accounts[accountName] = accSessions
// Sessions were already cached, so no account was opened above. Open it now if
// the caller wants one.
79 if acc == nil && openAccount {
80 acc, rerr = OpenAccount(log, accountName, true)
85// SessionUse checks if a session is valid. If csrfToken is the empty string, no
86// CSRF check is done. Otherwise it must be the csrf token associated with the
// session token. A successful check also extends the session's expiration.
//
// It holds the sessions lock for the whole call. On first use it loads the
// account's sessions from the database, without keeping the account open.
88func SessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
90 defer sessions.Unlock()
92 if _, err := ensureAccountSessions(ctx, log, accountName, false); err != nil {
93 return LoginSession{}, err
96 return sessionUse(ctx, log, accountName, sessionToken, csrfToken)
// sessionUse validates sessionToken for accountName against the in-memory cache.
// The token must exist and must not be expired. If csrfToken is non-empty, it
// must match the session's CSRF token. On success the expiration is moved to
// now+sessionLifetime and the change is queued for a delayed database write.
//
99// must be called with sessions lock held.
100func sessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
102 ls, ok := sessions.accounts[accountName][sessionToken]
104 return LoginSession{}, fmt.Errorf("unknown session token")
105 } else if time.Until(ls.Expires) < 0 {
// NOTE(review): this message hardcodes "24 hours"; keep it in sync with sessionLifetime.
106 return LoginSession{}, fmt.Errorf("session expired (after 24 hours inactivity)")
107 } else if csrfToken != "" && csrfToken != ls.csrfToken {
108 return LoginSession{}, fmt.Errorf("mismatch between csrf and session tokens")
// Sliding expiration: each use extends the session. Only the cache is updated
// here. The database write happens later.
112 ls.Expires = time.Now().Add(sessionLifetime)
113 sessions.accounts[accountName][sessionToken] = ls
115 // If we haven't scheduled a flush to database yet, schedule one now.
116 if sessions.pendingFlushes[accountName] == nil {
117 sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
// The flush runs in the background (presumably a goroutine, since it sleeps for
// sessionWriteDelay). A panic there is recovered, logged and counted in metrics.
119 pkglog := mlog.New("store", nil)
124 pkglog.Error("recover from panic", slog.Any("panic", x))
126 metrics.PanicInc(metrics.Store)
130 time.Sleep(sessionWriteDelay)
131 sessionsDelayedFlush(pkglog, accountName)
// Mark this token's new expiration as needing a write.
134 sessions.pendingFlushes[accountName][ls.sessionToken] = struct{}{}
139// sessionsDelayedFlush writes all changed sessions for an account to its database.
// It does not wait itself. The background flush scheduled by sessionUse sleeps
// sessionWriteDelay before calling it. This function holds the sessions lock
// while it runs.
140func sessionsDelayedFlush(log mlog.Log, accountName string) {
142 defer sessions.Unlock()
// Take this account's pending set. Removing the entry means the next sessionUse
// for the account schedules a new flush.
144 sessionTokens := sessions.pendingFlushes[accountName]
145 delete(sessions.pendingFlushes, accountName)
147 _, ok := sessions.accounts[accountName]
149 // Account may have been removed. Nothing to flush.
153 acc, err := OpenAccount(log, accountName, false)
154 if err != nil && errors.Is(err, ErrAccountUnknown) {
155 // Account may have been removed. Nothing to flush.
156 log.Infox("flushing sessions for account", err, slog.String("account", accountName))
160 log.Errorx("open account for flushing changed session tokens", err, slog.String("account", accountName))
165 log.Check(err, "closing account")
// All pending updates are written in a single transaction.
// NOTE(review): a token missing from the cache aborts the whole write, dropping
// every other pending update for the account. sessionAddToken evicts sessions
// from sessions.accounts without removing them from sessions.pendingFlushes, so
// this happens when a recently used session is evicted before the flush runs.
168 err = acc.DB.Write(mox.Context, func(tx *bstore.Tx) error {
169 for sessionToken := range sessionTokens {
170 ls, ok := sessions.accounts[accountName][sessionToken]
172 return fmt.Errorf("unknown session token to flush")
174 if err := tx.Update(&ls); err != nil {
// Failures are only logged. The pending set was already removed above, so these
// expiration updates are not retried.
180 log.Check(err, "flushing changed sessions for account", slog.String("account", accountName))
183// SessionAddToken adds a prepared or pre-existing LoginSession to the database and
184// cache. Can be used to restore a session token that was used to reset a password.
//
// ls.AccountName selects the account, which is opened for the call and closed
// afterwards. The cache entry is keyed by ls.sessionToken, which is not computed
// here. NOTE(review): presumably the caller must have set it; verify at call sites.
185func SessionAddToken(ctx context.Context, log mlog.Log, ls *LoginSession) error {
187 defer sessions.Unlock()
189 acc, err := ensureAccountSessions(ctx, log, ls.AccountName, true)
195 log.Check(err, "closing account after adding session token")
198 return sessionAddToken(ctx, log, acc, ls)
// sessionAddToken inserts ls into acc's database and the in-memory cache for
// ls.AccountName. If the account already has sessionsPerAccount sessions, expired
// ones are deleted first. If it is still full, the session expiring soonest
// (the least recently used, since use extends expiry) is deleted.
//
201// caller must hold sessions lock.
202func sessionAddToken(ctx context.Context, log mlog.Log, acc *Account, ls *LoginSession) error {
205 err := acc.DB.Write(ctx, func(tx *bstore.Tx) error {
206 // Remove sessions if we have too many, starting with expired sessions, and
207 // removing the oldest if needed.
208 if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
209 var oldest LoginSession
// Deleting map entries while ranging over the map is permitted in Go.
210 for _, ols := range sessions.accounts[ls.AccountName] {
211 if time.Until(ols.Expires) < 0 {
212 if err := tx.Delete(&ols); err != nil {
// NOTE(review): evicted tokens are removed from sessions.accounts but not from
// sessions.pendingFlushes. If one was pending, sessionsDelayedFlush later fails
// its whole write with "unknown session token to flush". These map deletions
// also happen inside the transaction, and nothing restores them if it fails.
215 delete(sessions.accounts[ls.AccountName], ols.sessionToken)
// Track the session with the earliest expiration as the eviction candidate.
218 if oldest.ID == 0 || ols.Expires.Before(oldest.Expires) {
// Still full after dropping expired sessions: evict the oldest.
222 if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
223 if err := tx.Delete(&oldest); err != nil {
226 delete(sessions.accounts[ls.AccountName], oldest.sessionToken)
230 if err := tx.Insert(ls); err != nil {
231 return fmt.Errorf("insert: %v", err)
// Cache the new session under ls.sessionToken, which this function does not set.
238 sessions.accounts[ls.AccountName][ls.sessionToken] = *ls
242// SessionAdd creates a new session token, with csrf token, and adds it to the
243// database and in-memory session cache. If there are too many sessions, the oldest
// is removed. It returns the new session token and CSRF token. Each is 16 random
// bytes from crypto/rand, encoded as unpadded URL-safe base64.
245func SessionAdd(ctx context.Context, log mlog.Log, accountName, loginAddress string) (session SessionToken, csrf CSRFToken, rerr error) {
246 // Prepare new LoginSession.
// NOTE(review): positional composite literal. It depends on LoginSession's field
// order, and swapping same-typed fields (the two [16]byte tokens, or the string
// fields) would still compile. Keyed fields would be more robust.
247 ls := LoginSession{0, time.Time{}, time.Now().Add(sessionLifetime), [16]byte{}, [16]byte{}, accountName, loginAddress, "", ""}
// The random token bytes and their string forms are prepared before the sessions
// lock is taken.
248 if _, err := cryptorand.Read(ls.SessionTokenBinary[:]); err != nil {
251 if _, err := cryptorand.Read(ls.CSRFTokenBinary[:]); err != nil {
254 ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
255 ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
258 defer sessions.Unlock()
260 acc, err := ensureAccountSessions(ctx, log, accountName, true)
266 log.Check(err, "closing account")
269 if err := sessionAddToken(ctx, log, acc, &ls); err != nil {
273 return ls.sessionToken, ls.csrfToken, nil
276// SessionRemove removes a session from the database and in-memory cache. Future
277// operations using the session token will fail.
//
// An unknown token is reported as an error. The database row is deleted before
// the cache entry, and any queued expiration write for the token is dropped.
278func SessionRemove(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken) error {
280 defer sessions.Unlock()
282 acc, err := ensureAccountSessions(ctx, log, accountName, true)
288 log.Check(err, "closing account")
291 ls, ok := sessions.accounts[accountName][sessionToken]
293 return fmt.Errorf("unknown session token")
296 if err := acc.DB.Delete(ctx, &ls); err != nil {
300 delete(sessions.accounts[accountName], sessionToken)
// Drop a pending expiration write too, so the delayed flush does not look up a
// session that no longer exists. delete on a nil map is a no-op.
301 pf := sessions.pendingFlushes[accountName]
303 delete(pf, sessionToken)
309// sessionRemoveAll removes all session tokens for an account. Useful after a password reset.
310func sessionRemoveAll(ctx context.Context, log mlog.Log, tx *bstore.Tx, accountName string) error {
312 defer sessions.Unlock()
314 if _, err := bstore.QueryTx[LoginSession](tx).Delete(); err != nil {
318 sessions.accounts[accountName] = map[SessionToken]LoginSession{}
319 if sessions.pendingFlushes[accountName] != nil {
320 sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}