diff --git a/server/auth/authenticator.go b/server/auth/authenticator.go index aa163dcc1..e6bab9f36 100644 --- a/server/auth/authenticator.go +++ b/server/auth/authenticator.go @@ -38,23 +38,23 @@ func NewAuthenticator(store *store.Store, secret string) *Authenticator { // AuthenticateBySession validates a session cookie and returns the authenticated user. // // Validation steps: -// 1. Parse cookie value to extract userID and sessionID +// 1. Use session ID to find the user and session details (single DB query) // 2. Verify user exists and is not archived -// 3. Verify session exists in user's sessions list -// 4. Check session hasn't expired (sliding expiration: 14 days from last access) +// 3. Check session hasn't expired (sliding expiration: 14 days from last access) // // Returns the user if authentication succeeds, or an error describing the failure. -func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) { - if sessionCookieValue == "" { - return nil, errors.New("session cookie value not found") +func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionID string) (*store.User, error) { + if sessionID == "" { + return nil, errors.New("session ID not found") } - userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue) + // Find the session and user in a single database query + result, err := a.store.GetUserSessionByID(ctx, sessionID) if err != nil { - return nil, errors.Wrap(err, "invalid session cookie format") + return nil, errors.Wrap(err, "session not found") } - user, err := a.store.GetUser(ctx, &store.FindUser{ID: &userID}) + user, err := a.store.GetUser(ctx, &store.FindUser{ID: &result.UserID}) if err != nil { return nil, errors.Wrap(err, "failed to get user") } @@ -65,13 +65,12 @@ func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookie return nil, errors.New("user is archived") } - sessions, err := a.store.GetUserSessions(ctx, user.ID) - if err != nil { - return nil, errors.Wrap(err, "failed to get user sessions") - } - - if !validateSession(sessionID, sessions) { - return nil, errors.New("invalid or expired session") + // Validate session expiration + if result.Session.LastAccessedTime != nil { + expiration := result.Session.LastAccessedTime.AsTime().Add(SessionSlidingDuration) + if expiration.Before(time.Now()) { + return nil, errors.New("session expired") + } } return user, nil @@ -168,23 +167,6 @@ func (a *Authenticator) AuthorizeAndSetContext(ctx context.Context, procedure st return ctx, nil } -// validateSession checks if a session exists and is still valid. -// Uses sliding expiration: session is valid if last accessed within SessionSlidingDuration. -func validateSession(sessionID string, sessions []*storepb.SessionsUserSetting_Session) bool { - for _, session := range sessions { - if sessionID == session.SessionId { - if session.LastAccessedTime != nil { - expiration := session.LastAccessedTime.AsTime().Add(SessionSlidingDuration) - if expiration.Before(time.Now()) { - return false // Session expired - } - } - return true - } - } - return false // Session not found -} - // validateAccessToken checks if the token exists in the user's access tokens list. // This enables token revocation: deleted tokens are removed from the list. func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting_AccessToken) bool { @@ -215,15 +197,12 @@ type AuthResult struct { // It tries session cookie first, then JWT token. // Returns nil if no valid credentials are provided. // On successful session auth, it also updates the session sliding expiration. -func (a *Authenticator) Authenticate(ctx context.Context, sessionCookie, authHeader string) *AuthResult { +func (a *Authenticator) Authenticate(ctx context.Context, sessionID, authHeader string) *AuthResult { // Try session cookie authentication first - if sessionCookie != "" { - user, err := a.AuthenticateBySession(ctx, sessionCookie) + if sessionID != "" { + user, err := a.AuthenticateBySession(ctx, sessionID) if err == nil && user != nil { - _, sessionID, parseErr := ParseSessionCookieValue(sessionCookie) - if parseErr == nil && sessionID != "" { - a.UpdateSessionLastAccessed(ctx, user.ID, sessionID) - } + a.UpdateSessionLastAccessed(ctx, user.ID, sessionID) return &AuthResult{User: user, SessionID: sessionID} } } diff --git a/server/auth/token.go b/server/auth/token.go index 813e4824b..5d0ff69a2 100644 --- a/server/auth/token.go +++ b/server/auth/token.go @@ -11,11 +11,9 @@ package auth import ( "fmt" - "strings" "time" "github.com/golang-jwt/jwt/v5" - "github.com/pkg/errors" "github.com/usememos/memos/internal/util" ) @@ -40,7 +38,7 @@ const ( SessionSlidingDuration = 14 * 24 * time.Hour // SessionCookieName is the HTTP cookie name used to store session information. - // Cookie value format: {userID}-{sessionID}. + // Cookie value is the session ID (UUID). SessionCookieName = "user_session" ) @@ -108,37 +106,7 @@ func generateToken(username string, userID int32, audience string, expirationTim // // Uses UUID v4 (random) for high entropy and uniqueness. // Session IDs are stored in user settings and used to identify browser sessions. +// The session ID is stored directly in the cookie as the cookie value. func GenerateSessionID() string { return util.GenUUID() } - -// BuildSessionCookieValue creates the session cookie value. -// -// Format: {userID}-{sessionID} -// Example: "123-550e8400-e29b-41d4-a716-446655440000" -// -// This format allows quick extraction of both user ID and session ID -// from the cookie without database lookup during authentication. -func BuildSessionCookieValue(userID int32, sessionID string) string { - return fmt.Sprintf("%d-%s", userID, sessionID) -} - -// ParseSessionCookieValue extracts user ID and session ID from cookie value. -// -// Input format: "{userID}-{sessionID}" -// Returns: (userID, sessionID, error) -// -// Example: "123-550e8400-..." → (123, "550e8400-...", nil). -func ParseSessionCookieValue(cookieValue string) (int32, string, error) { - parts := strings.SplitN(cookieValue, "-", 2) - if len(parts) != 2 { - return 0, "", errors.New("invalid session cookie format") - } - - userID, err := util.ConvertStringToInt32(parts[0]) - if err != nil { - return 0, "", errors.Errorf("invalid user ID in session cookie: %v", err) - } - - return userID, parts[1], nil -} diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 3db5296b4..d78e709de 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -230,9 +230,8 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim slog.Error("failed to track user session", "error", err) } - // Set session cookie for web use (format: userID-sessionID) - sessionCookieValue := auth.BuildSessionCookieValue(user.ID, sessionID) - sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime) + // Set session cookie for web use + sessionCookie, err := s.buildSessionCookie(ctx, sessionID, expireTime) if err != nil { return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err) } diff --git a/store/db/mysql/user_setting.go b/store/db/mysql/user_setting.go index d51228ba6..16545f0fc 100644 --- a/store/db/mysql/user_setting.go +++ b/store/db/mysql/user_setting.go @@ -4,6 +4,8 @@ import ( "context" "strings" + "github.com/pkg/errors" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -54,3 +56,42 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) return userSettingList, nil } + +func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) { + // Query user_setting that contains this sessionID in the sessions array + // Use JSON_SEARCH to check if sessionID exists in the array + query := ` + SELECT + user_id, + value + FROM user_setting + WHERE ` + "`key`" + ` = 'SESSIONS' + AND JSON_SEARCH(value, 'one', ?, NULL, '$.sessions[*].sessionId') IS NOT NULL + ` + + var userID int32 + var sessionsJSON string + + err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON) + if err != nil { + return nil, err + } + + // Parse the entire sessions list using protobuf unmarshaler + sessionsUserSetting := &storepb.SessionsUserSetting{} + if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil { + return nil, err + } + + // Find the specific session by ID + for _, session := range sessionsUserSetting.Sessions { + if session.SessionId == sessionID { + return &store.UserSessionQueryResult{ + UserID: userID, + Session: session, + }, nil + } + } + + return nil, errors.New("session not found") +} diff --git a/store/db/postgres/user_setting.go b/store/db/postgres/user_setting.go index 04aec63a0..478bfe00d 100644 --- a/store/db/postgres/user_setting.go +++ b/store/db/postgres/user_setting.go @@ -4,6 +4,8 @@ import ( "context" "strings" + "github.com/pkg/errors" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -67,3 +69,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) return userSettingList, nil } + +func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) { + // Query user_setting that contains this sessionID in the sessions array + // Use EXISTS with jsonb_array_elements to check array membership + query := ` + SELECT + user_setting.user_id, + user_setting.value + FROM user_setting + WHERE user_setting.key = 'SESSIONS' + AND EXISTS ( + SELECT 1 + FROM jsonb_array_elements(user_setting.value::jsonb->'sessions') AS session + WHERE session->>'sessionId' = $1 + ) + ` + + var userID int32 + var sessionsJSON string + + err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON) + if err != nil { + return nil, err + } + + // Parse the entire sessions list using protobuf unmarshaler + sessionsUserSetting := &storepb.SessionsUserSetting{} + if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil { + return nil, err + } + + // Find the specific session by ID + for _, session := range sessionsUserSetting.Sessions { + if session.SessionId == sessionID { + return &store.UserSessionQueryResult{ + UserID: userID, + Session: session, + }, nil + } + } + + return nil, errors.New("session not found") +} diff --git a/store/db/sqlite/user_setting.go b/store/db/sqlite/user_setting.go index c6ca6f3f0..ae2d9805c 100644 --- a/store/db/sqlite/user_setting.go +++ b/store/db/sqlite/user_setting.go @@ -4,6 +4,8 @@ import ( "context" "strings" + "github.com/pkg/errors" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -66,3 +68,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) return userSettingList, nil } + +func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) { + // Query user_setting that contains this sessionID in the sessions array + // Use EXISTS with json_each to properly check array membership + query := ` + SELECT + user_setting.user_id, + user_setting.value + FROM user_setting + WHERE user_setting.key = 'SESSIONS' + AND EXISTS ( + SELECT 1 + FROM json_each(json_extract(user_setting.value, '$.sessions')) AS session + WHERE json_extract(session.value, '$.sessionId') = ? + ) + ` + + var userID int32 + var sessionsJSON string + + err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON) + if err != nil { + return nil, err + } + + // Parse the entire sessions list using protobuf unmarshaler + sessionsUserSetting := &storepb.SessionsUserSetting{} + if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil { + return nil, err + } + + // Find the specific session by ID + for _, session := range sessionsUserSetting.Sessions { + if session.SessionId == sessionID { + return &store.UserSessionQueryResult{ + UserID: userID, + Session: session, + }, nil + } + } + + return nil, errors.New("session not found") +} diff --git a/store/driver.go b/store/driver.go index f13a23ee9..e06209cae 100644 --- a/store/driver.go +++ b/store/driver.go @@ -48,6 +48,7 @@ type Driver interface { // UserSetting model related methods. UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) + GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error) // IdentityProvider model related methods. CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) diff --git a/store/user_setting.go b/store/user_setting.go index 3c48b971c..fcdb0a951 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -21,6 +21,12 @@ type FindUserSetting struct { Key storepb.UserSetting_Key } +// UserSessionQueryResult contains the result of querying a single session by ID. +type UserSessionQueryResult struct { + UserID int32 + Session *storepb.SessionsUserSetting_Session +} + func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { userSettingRaw, err := convertUserSettingToRaw(upsert) if err != nil { @@ -241,6 +247,12 @@ func (s *Store) UpdateUserSessionLastAccessed(ctx context.Context, userID int32, return err } +// GetUserSessionByID returns the session details for the given session ID. +// Uses database-specific JSON queries for efficient lookup without loading all sessions. +func (s *Store) GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error) { + return s.driver.GetUserSessionByID(ctx, sessionID) +} + // GetUserWebhooks returns the webhooks of the user. func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) { userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{