mirror of https://github.com/usememos/memos.git
refactor(auth): streamline session authentication and cookie handling
This commit is contained in:
parent
87b8c2b2d2
commit
40e129b8af
|
|
@ -38,23 +38,23 @@ func NewAuthenticator(store *store.Store, secret string) *Authenticator {
|
|||
// AuthenticateBySession validates a session cookie and returns the authenticated user.
|
||||
//
|
||||
// Validation steps:
|
||||
// 1. Parse cookie value to extract userID and sessionID
|
||||
// 1. Use session ID to find the user and session details (single DB query)
|
||||
// 2. Verify user exists and is not archived
|
||||
// 3. Verify session exists in user's sessions list
|
||||
// 4. Check session hasn't expired (sliding expiration: 14 days from last access)
|
||||
// 3. Check session hasn't expired (sliding expiration: 14 days from last access)
|
||||
//
|
||||
// Returns the user if authentication succeeds, or an error describing the failure.
|
||||
func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) {
|
||||
if sessionCookieValue == "" {
|
||||
return nil, errors.New("session cookie value not found")
|
||||
func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionID string) (*store.User, error) {
|
||||
if sessionID == "" {
|
||||
return nil, errors.New("session ID not found")
|
||||
}
|
||||
|
||||
userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue)
|
||||
// Find the session and user in a single database query
|
||||
result, err := a.store.GetUserSessionByID(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid session cookie format")
|
||||
return nil, errors.Wrap(err, "session not found")
|
||||
}
|
||||
|
||||
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &result.UserID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
|
|
@ -65,13 +65,12 @@ func (a *Authenticator) AuthenticateBySession(ctx context.Context, sessionCookie
|
|||
return nil, errors.New("user is archived")
|
||||
}
|
||||
|
||||
sessions, err := a.store.GetUserSessions(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user sessions")
|
||||
}
|
||||
|
||||
if !validateSession(sessionID, sessions) {
|
||||
return nil, errors.New("invalid or expired session")
|
||||
// Validate session expiration
|
||||
if result.Session.LastAccessedTime != nil {
|
||||
expiration := result.Session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
|
||||
if expiration.Before(time.Now()) {
|
||||
return nil, errors.New("session expired")
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
|
@ -168,23 +167,6 @@ func (a *Authenticator) AuthorizeAndSetContext(ctx context.Context, procedure st
|
|||
return ctx, nil
|
||||
}
|
||||
|
||||
// validateSession checks if a session exists and is still valid.
|
||||
// Uses sliding expiration: session is valid if last accessed within SessionSlidingDuration.
|
||||
func validateSession(sessionID string, sessions []*storepb.SessionsUserSetting_Session) bool {
|
||||
for _, session := range sessions {
|
||||
if sessionID == session.SessionId {
|
||||
if session.LastAccessedTime != nil {
|
||||
expiration := session.LastAccessedTime.AsTime().Add(SessionSlidingDuration)
|
||||
if expiration.Before(time.Now()) {
|
||||
return false // Session expired
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false // Session not found
|
||||
}
|
||||
|
||||
// validateAccessToken checks if the token exists in the user's access tokens list.
|
||||
// This enables token revocation: deleted tokens are removed from the list.
|
||||
func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
|
||||
|
|
@ -215,15 +197,12 @@ type AuthResult struct {
|
|||
// It tries session cookie first, then JWT token.
|
||||
// Returns nil if no valid credentials are provided.
|
||||
// On successful session auth, it also updates the session sliding expiration.
|
||||
func (a *Authenticator) Authenticate(ctx context.Context, sessionCookie, authHeader string) *AuthResult {
|
||||
func (a *Authenticator) Authenticate(ctx context.Context, sessionID, authHeader string) *AuthResult {
|
||||
// Try session cookie authentication first
|
||||
if sessionCookie != "" {
|
||||
user, err := a.AuthenticateBySession(ctx, sessionCookie)
|
||||
if sessionID != "" {
|
||||
user, err := a.AuthenticateBySession(ctx, sessionID)
|
||||
if err == nil && user != nil {
|
||||
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookie)
|
||||
if parseErr == nil && sessionID != "" {
|
||||
a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
|
||||
}
|
||||
a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
|
||||
return &AuthResult{User: user, SessionID: sessionID}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,11 +11,9 @@ package auth
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
|
@ -40,7 +38,7 @@ const (
|
|||
SessionSlidingDuration = 14 * 24 * time.Hour
|
||||
|
||||
// SessionCookieName is the HTTP cookie name used to store session information.
|
||||
// Cookie value format: {userID}-{sessionID}.
|
||||
// Cookie value is the session ID (UUID).
|
||||
SessionCookieName = "user_session"
|
||||
)
|
||||
|
||||
|
|
@ -108,37 +106,7 @@ func generateToken(username string, userID int32, audience string, expirationTim
|
|||
//
|
||||
// Uses UUID v4 (random) for high entropy and uniqueness.
|
||||
// Session IDs are stored in user settings and used to identify browser sessions.
|
||||
// The session ID is stored directly in the cookie as the cookie value.
|
||||
func GenerateSessionID() string {
|
||||
return util.GenUUID()
|
||||
}
|
||||
|
||||
// BuildSessionCookieValue creates the session cookie value.
|
||||
//
|
||||
// Format: {userID}-{sessionID}
|
||||
// Example: "123-550e8400-e29b-41d4-a716-446655440000"
|
||||
//
|
||||
// This format allows quick extraction of both user ID and session ID
|
||||
// from the cookie without database lookup during authentication.
|
||||
func BuildSessionCookieValue(userID int32, sessionID string) string {
|
||||
return fmt.Sprintf("%d-%s", userID, sessionID)
|
||||
}
|
||||
|
||||
// ParseSessionCookieValue extracts user ID and session ID from cookie value.
|
||||
//
|
||||
// Input format: "{userID}-{sessionID}"
|
||||
// Returns: (userID, sessionID, error)
|
||||
//
|
||||
// Example: "123-550e8400-..." → (123, "550e8400-...", nil).
|
||||
func ParseSessionCookieValue(cookieValue string) (int32, string, error) {
|
||||
parts := strings.SplitN(cookieValue, "-", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, "", errors.New("invalid session cookie format")
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(parts[0])
|
||||
if err != nil {
|
||||
return 0, "", errors.Errorf("invalid user ID in session cookie: %v", err)
|
||||
}
|
||||
|
||||
return userID, parts[1], nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -230,9 +230,8 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim
|
|||
slog.Error("failed to track user session", "error", err)
|
||||
}
|
||||
|
||||
// Set session cookie for web use (format: userID-sessionID)
|
||||
sessionCookieValue := auth.BuildSessionCookieValue(user.ID, sessionID)
|
||||
sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime)
|
||||
// Set session cookie for web use
|
||||
sessionCookie, err := s.buildSessionCookie(ctx, sessionID, expireTime)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
|
@ -54,3 +56,42 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
|
|||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
|
||||
// Query user_setting that contains this sessionID in the sessions array
|
||||
// Use JSON_SEARCH to check if sessionID exists in the array
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + "`key`" + ` = 'SESSIONS'
|
||||
AND JSON_SEARCH(value, 'one', ?, NULL, '$.sessions[*].sessionId') IS NOT NULL
|
||||
`
|
||||
|
||||
var userID int32
|
||||
var sessionsJSON string
|
||||
|
||||
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the entire sessions list using protobuf unmarshaler
|
||||
sessionsUserSetting := &storepb.SessionsUserSetting{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find the specific session by ID
|
||||
for _, session := range sessionsUserSetting.Sessions {
|
||||
if session.SessionId == sessionID {
|
||||
return &store.UserSessionQueryResult{
|
||||
UserID: userID,
|
||||
Session: session,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
|
@ -67,3 +69,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
|
|||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
|
||||
// Query user_setting that contains this sessionID in the sessions array
|
||||
// Use EXISTS with jsonb_array_elements to check array membership
|
||||
query := `
|
||||
SELECT
|
||||
user_setting.user_id,
|
||||
user_setting.value
|
||||
FROM user_setting
|
||||
WHERE user_setting.key = 'SESSIONS'
|
||||
AND EXISTS (
|
||||
SELECT 1
|
||||
FROM jsonb_array_elements(user_setting.value::jsonb->'sessions') AS session
|
||||
WHERE session->>'sessionId' = $1
|
||||
)
|
||||
`
|
||||
|
||||
var userID int32
|
||||
var sessionsJSON string
|
||||
|
||||
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the entire sessions list using protobuf unmarshaler
|
||||
sessionsUserSetting := &storepb.SessionsUserSetting{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find the specific session by ID
|
||||
for _, session := range sessionsUserSetting.Sessions {
|
||||
if session.SessionId == sessionID {
|
||||
return &store.UserSessionQueryResult{
|
||||
UserID: userID,
|
||||
Session: session,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
|
@ -66,3 +68,46 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
|
|||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUserSessionByID(ctx context.Context, sessionID string) (*store.UserSessionQueryResult, error) {
|
||||
// Query user_setting that contains this sessionID in the sessions array
|
||||
// Use EXISTS with json_each to properly check array membership
|
||||
query := `
|
||||
SELECT
|
||||
user_setting.user_id,
|
||||
user_setting.value
|
||||
FROM user_setting
|
||||
WHERE user_setting.key = 'SESSIONS'
|
||||
AND EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(json_extract(user_setting.value, '$.sessions')) AS session
|
||||
WHERE json_extract(session.value, '$.sessionId') = ?
|
||||
)
|
||||
`
|
||||
|
||||
var userID int32
|
||||
var sessionsJSON string
|
||||
|
||||
err := d.db.QueryRowContext(ctx, query, sessionID).Scan(&userID, &sessionsJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the entire sessions list using protobuf unmarshaler
|
||||
sessionsUserSetting := &storepb.SessionsUserSetting{}
|
||||
if err := protojsonUnmarshaler.Unmarshal([]byte(sessionsJSON), sessionsUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find the specific session by ID
|
||||
for _, session := range sessionsUserSetting.Sessions {
|
||||
if session.SessionId == sessionID {
|
||||
return &store.UserSessionQueryResult{
|
||||
UserID: userID,
|
||||
Session: session,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ type Driver interface {
|
|||
// UserSetting model related methods.
|
||||
UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error)
|
||||
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error)
|
||||
GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error)
|
||||
|
||||
// IdentityProvider model related methods.
|
||||
CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,12 @@ type FindUserSetting struct {
|
|||
Key storepb.UserSetting_Key
|
||||
}
|
||||
|
||||
// UserSessionQueryResult contains the result of querying a single session by ID.
|
||||
type UserSessionQueryResult struct {
|
||||
UserID int32
|
||||
Session *storepb.SessionsUserSetting_Session
|
||||
}
|
||||
|
||||
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
||||
userSettingRaw, err := convertUserSettingToRaw(upsert)
|
||||
if err != nil {
|
||||
|
|
@ -241,6 +247,12 @@ func (s *Store) UpdateUserSessionLastAccessed(ctx context.Context, userID int32,
|
|||
return err
|
||||
}
|
||||
|
||||
// GetUserSessionByID returns the session details for the given session ID.
|
||||
// Uses database-specific JSON queries for efficient lookup without loading all sessions.
|
||||
func (s *Store) GetUserSessionByID(ctx context.Context, sessionID string) (*UserSessionQueryResult, error) {
|
||||
return s.driver.GetUserSessionByID(ctx, sessionID)
|
||||
}
|
||||
|
||||
// GetUserWebhooks returns the webhooks of the user.
|
||||
func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) {
|
||||
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{
|
||||
|
|
|
|||
Loading…
Reference in New Issue