From 26d10212c6e854987a235a790220cd8540cc158c Mon Sep 17 00:00:00 2001 From: Steven Date: Tue, 24 Feb 2026 23:08:16 +0800 Subject: [PATCH] refactor: consolidate duplicated auth logic into auth package Add ApplyToContext and AuthenticateToUser helpers to the auth package, then remove the duplicated auth code spread across the MCP middleware, file server, Connect interceptor, and gRPC-Gateway middleware. - auth.ApplyToContext: single place to set claims/user into context after Authenticate() - auth.AuthenticateToUser: resolves any credential (bearer token or refresh cookie) to a *store.User - MCP middleware: replaced manual PAT DB lookup + expiry check with Authenticator.AuthenticateByPAT - File server: replaced authenticateByBearerToken/authenticateByRefreshToken with AuthenticateToUser - Connect interceptor + Gateway middleware: replaced duplicated context-setting block with ApplyToContext - MCPService now accepts secret to construct its own Authenticator --- server/auth/authenticator.go | 34 ++++++++++++ server/auth/context.go | 16 ++++++ server/router/api/v1/connect_interceptors.go | 12 +---- server/router/api/v1/v1.go | 11 +--- server/router/fileserver/fileserver.go | 55 ++------------------ server/router/mcp/auth_middleware.go | 16 +++--- server/router/mcp/mcp.go | 9 ++-- server/server.go | 2 +- 8 files changed, 68 insertions(+), 87 deletions(-) diff --git a/server/auth/authenticator.go b/server/auth/authenticator.go index 3876406a0..d66961b1a 100644 --- a/server/auth/authenticator.go +++ b/server/auth/authenticator.go @@ -130,6 +130,40 @@ type AuthResult struct { AccessToken string // Non-empty if authenticated via JWT } +// AuthenticateToUser resolves the current request to a *store.User, checking the +// Authorization header first (access token or PAT), then falling back to the +// refresh token cookie. Returns (nil, nil) when no credentials are present. +func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cookieHeader string) (*store.User, error) { + // Try Bearer token first. + if authHeader != "" { + token := ExtractBearerToken(authHeader) + if token != "" { + if !strings.HasPrefix(token, PersonalAccessTokenPrefix) { + claims, err := a.AuthenticateByAccessTokenV2(token) + if err == nil && claims != nil { + return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) + } + } else { + user, _, err := a.AuthenticateByPAT(ctx, token) + if err == nil { + return user, nil + } + } + } + } + + // Fallback: refresh token cookie. + if cookieHeader != "" { + refreshToken := ExtractRefreshTokenFromCookie(cookieHeader) + if refreshToken != "" { + user, _, err := a.AuthenticateByRefreshToken(ctx, refreshToken) + return user, err + } + } + + return nil, nil +} + // Authenticate tries to authenticate using the provided credentials. // Priority: 1. Access Token V2, 2. PAT // Returns nil if no valid credentials are provided. diff --git a/server/auth/context.go b/server/auth/context.go index cdeba0df5..a4010bc3b 100644 --- a/server/auth/context.go +++ b/server/auth/context.go @@ -81,3 +81,19 @@ func GetUserClaims(ctx context.Context) *UserClaims { func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context { return context.WithValue(ctx, UserClaimsContextKey, claims) } + +// ApplyToContext sets the authenticated identity from an AuthResult into the context. +// This is the canonical way to propagate auth state after a successful Authenticate call. +// Safe to call with a nil result (no-op). +func ApplyToContext(ctx context.Context, result *AuthResult) context.Context { + if result == nil { + return ctx + } + if result.Claims != nil { + ctx = SetUserClaimsInContext(ctx, result.Claims) + ctx = context.WithValue(ctx, UserIDContextKey, result.Claims.UserID) + } else if result.User != nil { + ctx = SetUserInContext(ctx, result.User, result.AccessToken) + } + return ctx +} diff --git a/server/router/api/v1/connect_interceptors.go b/server/router/api/v1/connect_interceptors.go index dab7150d9..9ea26f3b0 100644 --- a/server/router/api/v1/connect_interceptors.go +++ b/server/router/api/v1/connect_interceptors.go @@ -222,17 +222,7 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required")) } - // Set context based on auth result - if result != nil { - if result.Claims != nil { - // Access Token V2 - stateless, use claims - ctx = auth.SetUserClaimsInContext(ctx, result.Claims) - ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID) - } else if result.User != nil { - // PAT - have full user - ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken) - } - } + ctx = auth.ApplyToContext(ctx, result) return next(ctx, req) } diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 2fa17ec90..cab335c81 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -73,16 +73,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech return } - // Set context based on auth result (may be nil for public endpoints) + // Apply auth result to context (no-op when result is nil for public endpoints) if result != nil { - if result.Claims != nil { - // Access Token V2 - stateless, use claims - ctx = auth.SetUserClaimsInContext(ctx, result.Claims) - ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID) - } else if result.User != nil { - // PAT - have full user - ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken) - } + ctx = auth.ApplyToContext(ctx, result) r = r.WithContext(ctx) } diff --git a/server/router/fileserver/fileserver.go b/server/router/fileserver/fileserver.go index 3fe30b1de..5fb86aabd 100644 --- a/server/router/fileserver/fileserver.go +++ b/server/router/fileserver/fileserver.go @@ -515,58 +515,9 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *ec // getCurrentUser retrieves the current authenticated user from the request. // Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie. func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) { - // Try Bearer token authentication. - if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" { - if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil { - return user, nil - } - } - - // Fallback: Try refresh token cookie. - if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" { - if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil { - return user, nil - } - } - - return nil, nil -} - -// authenticateByBearerToken authenticates using Authorization header. -func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) { - token := auth.ExtractBearerToken(authHeader) - if token == "" { - return nil, nil - } - - // Try Access Token V2 (stateless JWT). - if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { - claims, err := s.authenticator.AuthenticateByAccessTokenV2(token) - if err == nil && claims != nil { - return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) - } - } - - // Try Personal Access Token (stateful). - if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { - user, _, err := s.authenticator.AuthenticateByPAT(ctx, token) - if err == nil { - return user, nil - } - } - - return nil, nil -} - -// authenticateByRefreshToken authenticates using refresh token cookie. -func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) { - refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader) - if refreshToken == "" { - return nil, nil - } - - user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken) - return user, err + authHeader := c.Request().Header.Get(echo.HeaderAuthorization) + cookieHeader := c.Request().Header.Get("Cookie") + return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader) } // getUserByIdentifier finds a user by either ID or username. diff --git a/server/router/mcp/auth_middleware.go b/server/router/mcp/auth_middleware.go index c0ee94b5f..02e9a2f2f 100644 --- a/server/router/mcp/auth_middleware.go +++ b/server/router/mcp/auth_middleware.go @@ -2,8 +2,6 @@ package mcp import ( "net/http" - "strings" - "time" "github.com/labstack/echo/v5" @@ -11,23 +9,21 @@ import ( "github.com/usememos/memos/store" ) -func newAuthMiddleware(s *store.Store) echo.MiddlewareFunc { +func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc { + authenticator := auth.NewAuthenticator(s, secret) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization")) - if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { + if token == "" { return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"}) } - result, err := s.GetUserByPATHash(c.Request().Context(), auth.HashPersonalAccessToken(token)) - if err != nil || result == nil { - return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"}) - } - if result.PAT.ExpiresAt != nil && result.PAT.ExpiresAt.AsTime().Before(time.Now()) { + user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token) + if err != nil || user == nil { return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"}) } - ctx := auth.SetUserInContext(c.Request().Context(), result.User, result.PAT.GetTokenId()) + ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId()) c.SetRequest(c.Request().WithContext(ctx)) return next(c) } diff --git a/server/router/mcp/mcp.go b/server/router/mcp/mcp.go index f8bd114e9..f6c42f940 100644 --- a/server/router/mcp/mcp.go +++ b/server/router/mcp/mcp.go @@ -9,11 +9,12 @@ import ( ) type MCPService struct { - store *store.Store + store *store.Store + secret string } -func NewMCPService(store *store.Store) *MCPService { - return &MCPService{store: store} +func NewMCPService(store *store.Store, secret string) *MCPService { + return &MCPService{store: store, secret: secret} } func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { @@ -26,6 +27,6 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{"*"}, })) - mcpGroup.Use(newAuthMiddleware(s.store)) + mcpGroup.Use(newAuthMiddleware(s.store, s.secret)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) } diff --git a/server/server.go b/server/server.go index 6d8ca7226..629ed01b8 100644 --- a/server/server.go +++ b/server/server.go @@ -81,7 +81,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } // Register MCP server. - mcpService := mcprouter.NewMCPService(s.Store) + mcpService := mcprouter.NewMCPService(s.Store, s.Secret) mcpService.RegisterRoutes(echoServer) return s, nil