diff --git a/server/auth/authenticator.go b/server/auth/authenticator.go index 3876406a0..d66961b1a 100644 --- a/server/auth/authenticator.go +++ b/server/auth/authenticator.go @@ -130,6 +130,40 @@ type AuthResult struct { AccessToken string // Non-empty if authenticated via JWT } +// AuthenticateToUser resolves the current request to a *store.User, checking the +// Authorization header first (access token or PAT), then falling back to the +// refresh token cookie. Returns (nil, nil) when no credentials are present. +func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cookieHeader string) (*store.User, error) { + // Try Bearer token first. + if authHeader != "" { + token := ExtractBearerToken(authHeader) + if token != "" { + if !strings.HasPrefix(token, PersonalAccessTokenPrefix) { + claims, err := a.AuthenticateByAccessTokenV2(token) + if err == nil && claims != nil { + return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) + } + } else { + user, _, err := a.AuthenticateByPAT(ctx, token) + if err == nil { + return user, nil + } + } + } + } + + // Fallback: refresh token cookie. + if cookieHeader != "" { + refreshToken := ExtractRefreshTokenFromCookie(cookieHeader) + if refreshToken != "" { + user, _, err := a.AuthenticateByRefreshToken(ctx, refreshToken) + return user, err + } + } + + return nil, nil +} + // Authenticate tries to authenticate using the provided credentials. // Priority: 1. Access Token V2, 2. PAT // Returns nil if no valid credentials are provided. diff --git a/server/auth/context.go b/server/auth/context.go index cdeba0df5..a4010bc3b 100644 --- a/server/auth/context.go +++ b/server/auth/context.go @@ -81,3 +81,19 @@ func GetUserClaims(ctx context.Context) *UserClaims { func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context { return context.WithValue(ctx, UserClaimsContextKey, claims) } + +// ApplyToContext sets the authenticated identity from an AuthResult into the context. +// This is the canonical way to propagate auth state after a successful Authenticate call. +// Safe to call with a nil result (no-op). +func ApplyToContext(ctx context.Context, result *AuthResult) context.Context { + if result == nil { + return ctx + } + if result.Claims != nil { + ctx = SetUserClaimsInContext(ctx, result.Claims) + ctx = context.WithValue(ctx, UserIDContextKey, result.Claims.UserID) + } else if result.User != nil { + ctx = SetUserInContext(ctx, result.User, result.AccessToken) + } + return ctx +} diff --git a/server/router/api/v1/connect_interceptors.go b/server/router/api/v1/connect_interceptors.go index dab7150d9..9ea26f3b0 100644 --- a/server/router/api/v1/connect_interceptors.go +++ b/server/router/api/v1/connect_interceptors.go @@ -222,17 +222,7 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required")) } - // Set context based on auth result - if result != nil { - if result.Claims != nil { - // Access Token V2 - stateless, use claims - ctx = auth.SetUserClaimsInContext(ctx, result.Claims) - ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID) - } else if result.User != nil { - // PAT - have full user - ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken) - } - } + ctx = auth.ApplyToContext(ctx, result) return next(ctx, req) } diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 2fa17ec90..cab335c81 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -73,16 +73,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech return } - // Set context based on auth result (may be nil for public endpoints) + // Apply auth result to context (no-op when result is nil for public endpoints) if result != nil { - if result.Claims != nil { - // Access Token V2 - stateless, use claims - ctx = auth.SetUserClaimsInContext(ctx, result.Claims) - ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID) - } else if result.User != nil { - // PAT - have full user - ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken) - } + ctx = auth.ApplyToContext(ctx, result) r = r.WithContext(ctx) } diff --git a/server/router/fileserver/fileserver.go b/server/router/fileserver/fileserver.go index 3fe30b1de..5fb86aabd 100644 --- a/server/router/fileserver/fileserver.go +++ b/server/router/fileserver/fileserver.go @@ -515,58 +515,9 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *ec // getCurrentUser retrieves the current authenticated user from the request. // Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie. func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) { - // Try Bearer token authentication. - if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" { - if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil { - return user, nil - } - } - - // Fallback: Try refresh token cookie. - if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" { - if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil { - return user, nil - } - } - - return nil, nil -} - -// authenticateByBearerToken authenticates using Authorization header. -func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) { - token := auth.ExtractBearerToken(authHeader) - if token == "" { - return nil, nil - } - - // Try Access Token V2 (stateless JWT). - if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { - claims, err := s.authenticator.AuthenticateByAccessTokenV2(token) - if err == nil && claims != nil { - return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) - } - } - - // Try Personal Access Token (stateful). - if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { - user, _, err := s.authenticator.AuthenticateByPAT(ctx, token) - if err == nil { - return user, nil - } - } - - return nil, nil -} - -// authenticateByRefreshToken authenticates using refresh token cookie. -func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) { - refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader) - if refreshToken == "" { - return nil, nil - } - - user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken) - return user, err + authHeader := c.Request().Header.Get(echo.HeaderAuthorization) + cookieHeader := c.Request().Header.Get("Cookie") + return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader) } // getUserByIdentifier finds a user by either ID or username. diff --git a/server/router/mcp/auth_middleware.go b/server/router/mcp/auth_middleware.go index c0ee94b5f..02e9a2f2f 100644 --- a/server/router/mcp/auth_middleware.go +++ b/server/router/mcp/auth_middleware.go @@ -2,8 +2,6 @@ package mcp import ( "net/http" - "strings" - "time" "github.com/labstack/echo/v5" @@ -11,23 +9,21 @@ import ( "github.com/usememos/memos/store" ) -func newAuthMiddleware(s *store.Store) echo.MiddlewareFunc { +func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc { + authenticator := auth.NewAuthenticator(s, secret) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization")) - if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { + if token == "" { return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"}) } - result, err := s.GetUserByPATHash(c.Request().Context(), auth.HashPersonalAccessToken(token)) - if err != nil || result == nil { - return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"}) - } - if result.PAT.ExpiresAt != nil && result.PAT.ExpiresAt.AsTime().Before(time.Now()) { + user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token) + if err != nil || user == nil { return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"}) } - ctx := auth.SetUserInContext(c.Request().Context(), result.User, result.PAT.GetTokenId()) + ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId()) c.SetRequest(c.Request().WithContext(ctx)) return next(c) } diff --git a/server/router/mcp/mcp.go b/server/router/mcp/mcp.go index f8bd114e9..f6c42f940 100644 --- a/server/router/mcp/mcp.go +++ b/server/router/mcp/mcp.go @@ -9,11 +9,12 @@ import ( ) type MCPService struct { - store *store.Store + store *store.Store + secret string } -func NewMCPService(store *store.Store) *MCPService { - return &MCPService{store: store} +func NewMCPService(store *store.Store, secret string) *MCPService { + return &MCPService{store: store, secret: secret} } func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { @@ -26,6 +27,6 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{"*"}, })) - mcpGroup.Use(newAuthMiddleware(s.store)) + mcpGroup.Use(newAuthMiddleware(s.store, s.secret)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) } diff --git a/server/server.go b/server/server.go index 6d8ca7226..629ed01b8 100644 --- a/server/server.go +++ b/server/server.go @@ -81,7 +81,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } // Register MCP server. - mcpService := mcprouter.NewMCPService(s.Store) + mcpService := mcprouter.NewMCPService(s.Store, s.Secret) mcpService.RegisterRoutes(echoServer) return s, nil