refactor(auth): streamline session authentication and cookie handling

This commit is contained in:
Johnny 2025-12-16 22:23:59 +08:00
parent 87b8c2b2d2
commit 40e129b8af
8 changed files with 167 additions and 77 deletions

View File

@ -38,23 +38,23 @@ func NewAuthenticator(store *store.Store, secret string) *Authenticator {
// AuthenticateBySession validates a session cookie and returns the authenticated user. // AuthenticateBySession validates a session cookie and returns the authenticated user.
// //
// Validation steps: // Validation steps:
// 1. Parse cookie value to extract userID and sessionID // 1. Use session ID to find the user and session details (single DB query)
// 2. Verify user exists and is not archived // 2. Verify user exists and is not archived
// 3. Verify session exists in user's sessions list // 3. Check session hasn't expired (sliding expiration: 14 days from last access)
// 4. Check session hasn't expired (sliding expiration: 14 days from last access)
// //
// Returns the user if authentication succeeds, or an error describing the failure. // Returns the user if authentication succeeds, or an error describing the failure.
func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) { func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionID string) (*store.User, error) {
if sessionCookieValue == "" { if sessionID == "" {
return nil, errors.New("session cookie value not found") return nil, errors.New("session ID not found")
} }
userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue) // Find the session and user in a single database query
result, err := a.store.GetUserSessionByID(ctx, sessionID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "invalid session cookie format") return nil, errors.Wrap(err, "session not found")
} }
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &userID}) user, err := a.store.GetUser(ctx, &store.FindUser{ID: &result.UserID})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get user") return nil, errors.Wrap(err, "failed to get user")
} }
@ -65,13 +65,12 @@ func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookie
return nil, errors.New("user is archived") return nil, errors.New("user is archived")
} }
sessions, err := a.store.GetUserSessions(ctx, user.ID) // Validate session expiration
if err != nil { if result.Session.LastAccessedTime != nil {
return nil, errors.Wrap(err, "failed to get user sessions") expiration := result.Session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
} if expiration.Before(time.Now()) {
return nil, errors.New("session expired")
if !validateSession(sessionID, sessions) { }
return nil, errors.New("invalid or expired session")
} }
return user, nil return user, nil
@ -168,23 +167,6 @@ func (a *Authenticator) AuthorizeAndSetContext(ctx context.Context, procedure st
return ctx, nil return ctx, nil
} }
// validateSession checks if a session exists and is still valid.
// Uses sliding expiration: session is valid if last accessed within SessionSlidingDuration.
func validateSession(sessionID string, sessions []*storepb.SessionsUserSetting_Session) bool {
for _, session := range sessions {
if sessionID == session.SessionId {
if session.LastAccessedTime != nil {
expiration := session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
if expiration.Before(time.Now()) {
return false // Session expired
}
}
return true
}
}
return false // Session not found
}
// validateAccessToken checks if the token exists in the user's access tokens list. // validateAccessToken checks if the token exists in the user's access tokens list.
// This enables token revocation: deleted tokens are removed from the list. // This enables token revocation: deleted tokens are removed from the list.
func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting_AccessToken) bool { func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
@ -215,15 +197,12 @@ type AuthResult struct {
// It tries session cookie first, then JWT token. // It tries session cookie first, then JWT token.
// Returns nil if no valid credentials are provided. // Returns nil if no valid credentials are provided.
// On successful session auth, it also updates the session sliding expiration. // On successful session auth, it also updates the session sliding expiration.
func (a *Authenticator) Authenticate(ctx context.Context, sessionCookie, authHeader string) *AuthResult { func (a *Authenticator) Authenticate(ctx context.Context, sessionID, authHeader string) *AuthResult {
// Try session cookie authentication first // Try session cookie authentication first
if sessionCookie != "" { if sessionID != "" {
user, err := a.AuthenticateBySession(ctx, sessionCookie) user, err := a.AuthenticateBySession(ctx, sessionID)
if err == nil && user != nil { if err == nil && user != nil {
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookie) a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
if parseErr == nil && sessionID != "" {
a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
}
return &AuthResult{User: user, SessionID: sessionID} return &AuthResult{User: user, SessionID: sessionID}
} }
} }

View File

@ -11,11 +11,9 @@ package auth
import ( import (
"fmt" "fmt"
"strings"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util" "github.com/usememos/memos/internal/util"
) )
@ -40,7 +38,7 @@ const (
SessionSlidingDuration = 14 * 24 * time.Hour SessionSlidingDuration = 14 * 24 * time.Hour
// SessionCookieName is the HTTP cookie name used to store session information. // SessionCookieName is the HTTP cookie name used to store session information.
// Cookie value format: {userID}-{sessionID}. // Cookie value is the session ID (UUID).
SessionCookieName = "user_session" SessionCookieName = "user_session"
) )
@ -108,37 +106,7 @@ func generateToken(username string, userID int32, audience string, expirationTim
// //
// Uses UUID v4 (random) for high entropy and uniqueness. // Uses UUID v4 (random) for high entropy and uniqueness.
// Session IDs are stored in user settings and used to identify browser sessions. // Session IDs are stored in user settings and used to identify browser sessions.
// The session ID is stored directly in the cookie as the cookie value.
func GenerateSessionID() string { func GenerateSessionID() string {
return util.GenUUID() return util.GenUUID()
} }
// BuildSessionCookieValue creates the session cookie value.
//
// Format: {userID}-{sessionID}
// Example: "123-550e8400-e29b-41d4-a716-446655440000"
//
// This format allows quick extraction of both user ID and session ID
// from the cookie without database lookup during authentication.
func BuildSessionCookieValue(userID int32, sessionID string) string {
return fmt.Sprintf("%d-%s", userID, sessionID)
}
// ParseSessionCookieValue extracts user ID and session ID from cookie value.
//
// Input format: "{userID}-{sessionID}"
// Returns: (userID, sessionID, error)
//
// Example: "123-550e8400-..." → (123, "550e8400-...", nil).
func ParseSessionCookieValue(cookieValue string) (int32, string, error) {
parts := strings.SplitN(cookieValue, "-", 2)
if len(parts) != 2 {
return 0, "", errors.New("invalid session cookie format")
}
userID, err := util.ConvertStringToInt32(parts[0])
if err != nil {
return 0, "", errors.Errorf("invalid user ID in session cookie: %v", err)
}
return userID, parts[1], nil
}

View File

@ -230,9 +230,8 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim
slog.Error("failed to track user session", "error", err) slog.Error("failed to track user session", "error", err)
} }
// Set session cookie for web use (format: userID-sessionID) // Set session cookie for web use
sessionCookieValue := auth.BuildSessionCookieValue(user.ID, sessionID) sessionCookie, err := s.buildSessionCookie(ctx, sessionID, expireTime)
sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err) return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
} }

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"strings" "strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -54,3 +56,42 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
// Query user_setting that contains this sessionID in the sessions array
// Use JSON_SEARCH to check if sessionID exists in the array
query := `
SELECT
user_id,
value
FROM user_setting
WHERE ` + "`key`" + ` = 'SESSIONS'
AND JSON_SEARCH(value, 'one', ?, NULL, '$.sessions[*].sessionId') IS NOT NULL
`
var userID int32
var sessionsJSON string
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
if err != nil {
return nil, err
}
// Parse the entire sessions list using protobuf unmarshaler
sessionsUserSetting := &storepb.SessionsUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
return nil, err
}
// Find the specific session by ID
for _, session := range sessionsUserSetting.Sessions {
if session.SessionId == sessionID {
return &store.UserSessionQueryResult{
UserID: userID,
Session: session,
}, nil
}
}
return nil, errors.New("session not found")
}

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"strings" "strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -67,3 +69,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
// Query user_setting that contains this sessionID in the sessions array
// Use EXISTS with jsonb_array_elements to check array membership
query := `
SELECT
user_setting.user_id,
user_setting.value
FROM user_setting
WHERE user_setting.key = 'SESSIONS'
AND EXISTS (
SELECT 1
FROM jsonb_array_elements(user_setting.value::jsonb->'sessions') AS session
WHERE session->>'sessionId' = $1
)
`
var userID int32
var sessionsJSON string
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
if err != nil {
return nil, err
}
// Parse the entire sessions list using protobuf unmarshaler
sessionsUserSetting := &storepb.SessionsUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
return nil, err
}
// Find the specific session by ID
for _, session := range sessionsUserSetting.Sessions {
if session.SessionId == sessionID {
return &store.UserSessionQueryResult{
UserID: userID,
Session: session,
}, nil
}
}
return nil, errors.New("session not found")
}

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"strings" "strings"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -66,3 +68,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
// Query user_setting that contains this sessionID in the sessions array
// Use EXISTS with json_each to properly check array membership
query := `
SELECT
user_setting.user_id,
user_setting.value
FROM user_setting
WHERE user_setting.key = 'SESSIONS'
AND EXISTS (
SELECT 1
FROM json_each(json_extract(user_setting.value, '$.sessions')) AS session
WHERE json_extract(session.value, '$.sessionId') = ?
)
`
var userID int32
var sessionsJSON string
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
if err != nil {
return nil, err
}
// Parse the entire sessions list using protobuf unmarshaler
sessionsUserSetting := &storepb.SessionsUserSetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
return nil, err
}
// Find the specific session by ID
for _, session := range sessionsUserSetting.Sessions {
if session.SessionId == sessionID {
return &store.UserSessionQueryResult{
UserID: userID,
Session: session,
}, nil
}
}
return nil, errors.New("session not found")
}

View File

@ -48,6 +48,7 @@ type Driver interface {
// UserSetting model related methods. // UserSetting model related methods.
UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error)
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error)
GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error)
// IdentityProvider model related methods. // IdentityProvider model related methods.
CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error)

View File

@ -21,6 +21,12 @@ type FindUserSetting struct {
Key storepb.UserSetting_Key Key storepb.UserSetting_Key
} }
// UserSessionQueryResult contains the result of querying a single session by ID.
type UserSessionQueryResult struct {
UserID int32
Session *storepb.SessionsUserSetting_Session
}
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
userSettingRaw, err := convertUserSettingToRaw(upsert) userSettingRaw, err := convertUserSettingToRaw(upsert)
if err != nil { if err != nil {
@ -241,6 +247,12 @@ func (s *Store) UpdateUserSessionLastAccessed(ctx context.Context, userID int32,
return err return err
} }
// GetUserSessionByID returns the session details for the given session ID.
// Uses database-specific JSON queries for efficient lookup without loading all sessions.
func (s *Store) GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error) {
return s.driver.GetUserSessionByID(ctx, sessionID)
}
// GetUserWebhooks returns the webhooks of the user. // GetUserWebhooks returns the webhooks of the user.
func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) { func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) {
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{ userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{