refactor(auth): streamline session authentication and cookie handling

This commit is contained in:
Johnny 2025-12-16 22:23:59 +08:00
parent 87b8c2b2d2
commit 40e129b8af
8 changed files with 167 additions and 77 deletions

View File

@ -38,23 +38,23 @@ func NewAuthenticator(store *store.Store, secret string) *Authenticator {
// AuthenticateBySession validates a session cookie and returns the authenticated user.
//
// Validation steps:
// 1. Parse cookie value to extract userID and sessionID
// 1. Use session ID to find the user and session details (single DB query)
// 2. Verify user exists and is not archived
// 3. Verify session exists in user's sessions list
// 4. Check session hasn't expired (sliding expiration: 14 days from last access)
// 3. Check session hasn't expired (sliding expiration: 14 days from last access)
//
// Returns the user if authentication succeeds, or an error describing the failure.
func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) {
if sessionCookieValue == "" {
return nil, errors.New("session cookie value not found")
func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionID string) (*store.User, error) {
if sessionID == "" {
return nil, errors.New("session ID not found")
}
userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue)
// Find the session and user in a single database query
result, err := a.store.GetUserSessionByID(ctx, sessionID)
if err != nil {
return nil, errors.Wrap(err, "invalid session cookie format")
return nil, errors.Wrap(err, "session not found")
}
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &userID})
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &result.UserID})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
@ -65,13 +65,12 @@ func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookie
return nil, errors.New("user is archived")
}
sessions, err := a.store.GetUserSessions(ctx, user.ID)
if err != nil {
return nil, errors.Wrap(err, "failed to get user sessions")
}
if !validateSession(sessionID, sessions) {
return nil, errors.New("invalid or expired session")
// Validate session expiration
if result.Session.LastAccessedTime != nil {
expiration := result.Session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
if expiration.Before(time.Now()) {
return nil, errors.New("session expired")
}
}
return user, nil
@ -168,23 +167,6 @@ func (a *Authenticator) AuthorizeAndSetContext(ctx context.Context, procedure st
return ctx, nil
}
// validateSession checks if a session exists and is still valid.
// Uses sliding expiration: session is valid if last accessed within SessionSlidingDuration.
func validateSession(sessionID string, sessions []*storepb.SessionsUserSetting_Session) bool {
for _, session := range sessions {
if sessionID == session.SessionId {
if session.LastAccessedTime != nil {
expiration := session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
if expiration.Before(time.Now()) {
return false // Session expired
}
}
return true
}
}
return false // Session not found
}
// validateAccessToken checks if the token exists in the user's access tokens list.
// This enables token revocation: deleted tokens are removed from the list.
func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
@ -215,15 +197,12 @@ type AuthResult struct {
// It tries session cookie first, then JWT token.
// Returns nil if no valid credentials are provided.
// On successful session auth, it also updates the session sliding expiration.
func (a *Authenticator) Authenticate(ctx context.Context, sessionCookie, authHeader string) *AuthResult {
func (a *Authenticator) Authenticate(ctx context.Context, sessionID, authHeader string) *AuthResult {
// Try session cookie authentication first
if sessionCookie != "" {
user, err := a.AuthenticateBySession(ctx, sessionCookie)
if sessionID != "" {
user, err := a.AuthenticateBySession(ctx, sessionID)
if err == nil && user != nil {
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookie)
if parseErr == nil && sessionID != "" {
a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
}
a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
return &AuthResult{User: user, SessionID: sessionID}
}
}

View File

@ -11,11 +11,9 @@ package auth
import (
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
@ -40,7 +38,7 @@ const (
SessionSlidingDuration = 14 * 24 * time.Hour
// SessionCookieName is the HTTP cookie name used to store session information.
// Cookie value format: {userID}-{sessionID}.
// Cookie value is the session ID (UUID).
SessionCookieName = "user_session"
)
@ -108,37 +106,7 @@ func generateToken(username string, userID int32, audience string, expirationTim
//
// Uses UUID v4 (random) for high entropy and uniqueness.
// Session IDs are stored in user settings and used to identify browser sessions.
// The session ID is stored directly in the cookie as the cookie value.
func GenerateSessionID() string {
return util.GenUUID()
}
// BuildSessionCookieValue creates the session cookie value.
//
// Format: {userID}-{sessionID}
// Example: "123-550e8400-e29b-41d4-a716-446655440000"
//
// This format allows quick extraction of both user ID and session ID
// from the cookie without database lookup during authentication.
func BuildSessionCookieValue(userID int32, sessionID string) string {
return fmt.Sprintf("%d-%s", userID, sessionID)
}
// ParseSessionCookieValue extracts user ID and session ID from cookie value.
//
// Input format: "{userID}-{sessionID}"
// Returns: (userID, sessionID, error)
//
// Example: "123-550e8400-..." → (123, "550e8400-...", nil).
func ParseSessionCookieValue(cookieValue string) (int32, string, error) {
parts := strings.SplitN(cookieValue, "-", 2)
if len(parts) != 2 {
return 0, "", errors.New("invalid session cookie format")
}
userID, err := util.ConvertStringToInt32(parts[0])
if err != nil {
return 0, "", errors.Errorf("invalid user ID in session cookie: %v", err)
}
return userID, parts[1], nil
}

View File

@ -230,9 +230,8 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim
slog.Error("failed to track user session", "error", err)
}
// Set session cookie for web use (format: userID-sessionID)
sessionCookieValue := auth.BuildSessionCookieValue(user.ID, sessionID)
sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime)
// Set session cookie for web use
sessionCookie, err := s.buildSessionCookie(ctx, sessionID, expireTime)
if err != nil {
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
}

View File

@ -4,6 +4,8 @@ import (
"context"
"strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -54,3 +56,42 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil
}
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
// Query user_setting that contains this sessionID in the sessions array
// Use JSON_SEARCH to check if sessionID exists in the array
query := `
SELECT
user_id,
value
FROM user_setting
WHERE ` + "`key`" + ` = 'SESSIONS'
AND JSON_SEARCH(value, 'one', ?, NULL, '$.sessions[*].sessionId') IS NOT NULL
`
var userID int32
var sessionsJSON string
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
if err != nil {
return nil, err
}
// Parse the entire sessions list using protobuf unmarshaler
sessionsUserSetting := &storepb.SessionsUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
return nil, err
}
// Find the specific session by ID
for _, session := range sessionsUserSetting.Sessions {
if session.SessionId == sessionID {
return &store.UserSessionQueryResult{
UserID: userID,
Session: session,
}, nil
}
}
return nil, errors.New("session not found")
}

View File

@ -4,6 +4,8 @@ import (
"context"
"strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -67,3 +69,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil
}
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
// Query user_setting that contains this sessionID in the sessions array
// Use EXISTS with jsonb_array_elements to check array membership
query := `
SELECT
user_setting.user_id,
user_setting.value
FROM user_setting
WHERE user_setting.key = 'SESSIONS'
AND EXISTS (
SELECT 1
FROM jsonb_array_elements(user_setting.value::jsonb->'sessions') AS session
WHERE session->>'sessionId' = $1
)
`
var userID int32
var sessionsJSON string
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
if err != nil {
return nil, err
}
// Parse the entire sessions list using protobuf unmarshaler
sessionsUserSetting := &storepb.SessionsUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
return nil, err
}
// Find the specific session by ID
for _, session := range sessionsUserSetting.Sessions {
if session.SessionId == sessionID {
return &store.UserSessionQueryResult{
UserID: userID,
Session: session,
}, nil
}
}
return nil, errors.New("session not found")
}

View File

@ -4,6 +4,8 @@ import (
"context"
"strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -66,3 +68,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil
}
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
// Query user_setting that contains this sessionID in the sessions array
// Use EXISTS with json_each to properly check array membership
query := `
SELECT
user_setting.user_id,
user_setting.value
FROM user_setting
WHERE user_setting.key = 'SESSIONS'
AND EXISTS (
SELECT 1
FROM json_each(json_extract(user_setting.value, '$.sessions')) AS session
WHERE json_extract(session.value, '$.sessionId') = ?
)
`
var userID int32
var sessionsJSON string
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
if err != nil {
return nil, err
}
// Parse the entire sessions list using protobuf unmarshaler
sessionsUserSetting := &storepb.SessionsUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
return nil, err
}
// Find the specific session by ID
for _, session := range sessionsUserSetting.Sessions {
if session.SessionId == sessionID {
return &store.UserSessionQueryResult{
UserID: userID,
Session: session,
}, nil
}
}
return nil, errors.New("session not found")
}

View File

@ -48,6 +48,7 @@ type Driver interface {
// UserSetting model related methods.
UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error)
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error)
GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error)
// IdentityProvider model related methods.
CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error)

View File

@ -21,6 +21,12 @@ type FindUserSetting struct {
Key storepb.UserSetting_Key
}
// UserSessionQueryResult contains the result of querying a single session by ID.
type UserSessionQueryResult struct {
UserID int32
Session *storepb.SessionsUserSetting_Session
}
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
userSettingRaw, err := convertUserSettingToRaw(upsert)
if err != nil {
@ -241,6 +247,12 @@ func (s *Store) UpdateUserSessionLastAccessed(ctx context.Context, userID int32,
return err
}
// GetUserSessionByID returns the session details for the given session ID.
// Uses database-specific JSON queries for efficient lookup without loading all sessions.
func (s *Store) GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error) {
return s.driver.GetUserSessionByID(ctx, sessionID)
}
// GetUserWebhooks returns the webhooks of the user.
func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) {
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{