refactor: consolidate duplicated auth logic into auth package

Add ApplyToContext and AuthenticateToUser helpers to the auth package,
then remove the duplicated auth code spread across the MCP middleware,
file server, Connect interceptor, and gRPC-Gateway middleware.

- auth.ApplyToContext: single place to set claims/user into context after Authenticate()
- auth.AuthenticateToUser: resolves any credential (bearer token or refresh cookie) to a *store.User
- MCP middleware: replaced manual PAT DB lookup + expiry check with Authenticator.AuthenticateByPAT
- File server: replaced authenticateByBearerToken/authenticateByRefreshToken with AuthenticateToUser
- Connect interceptor + Gateway middleware: replaced duplicated context-setting block with ApplyToContext
- MCPService now accepts secret to construct its own Authenticator
This commit is contained in:
Steven 2026-02-24 23:08:16 +08:00
parent 47d9414702
commit 26d10212c6
8 changed files with 68 additions and 87 deletions

View File

@ -130,6 +130,40 @@ type AuthResult struct {
AccessToken string // Non-empty if authenticated via JWT
}
// AuthenticateToUser resolves the current request to a *store.User, checking the
// Authorization header first (access token or PAT), then falling back to the
// refresh token cookie. Returns (nil, nil) when no credentials are present.
func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cookieHeader string) (*store.User, error) {
// Try Bearer token first.
if authHeader != "" {
token := ExtractBearerToken(authHeader)
if token != "" {
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
claims, err := a.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
}
} else {
user, _, err := a.AuthenticateByPAT(ctx, token)
if err == nil {
return user, nil
}
}
}
}
// Fallback: refresh token cookie.
if cookieHeader != "" {
refreshToken := ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken != "" {
user, _, err := a.AuthenticateByRefreshToken(ctx, refreshToken)
return user, err
}
}
return nil, nil
}
// Authenticate tries to authenticate using the provided credentials.
// Priority: 1. Access Token V2, 2. PAT
// Returns nil if no valid credentials are provided.

View File

@ -81,3 +81,19 @@ func GetUserClaims(ctx context.Context) *UserClaims {
func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context {
return context.WithValue(ctx, UserClaimsContextKey, claims)
}
// ApplyToContext sets the authenticated identity from an AuthResult into the context.
// This is the canonical way to propagate auth state after a successful Authenticate call.
// Safe to call with a nil result (no-op).
func ApplyToContext(ctx context.Context, result *AuthResult) context.Context {
if result == nil {
return ctx
}
if result.Claims != nil {
ctx = SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
ctx = SetUserInContext(ctx, result.User, result.AccessToken)
}
return ctx
}

View File

@ -222,17 +222,7 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
}
// Set context based on auth result
if result != nil {
if result.Claims != nil {
// Access Token V2 - stateless, use claims
ctx = auth.SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
// PAT - have full user
ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken)
}
}
ctx = auth.ApplyToContext(ctx, result)
return next(ctx, req)
}

View File

@ -73,16 +73,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
return
}
// Set context based on auth result (may be nil for public endpoints)
// Apply auth result to context (no-op when result is nil for public endpoints)
if result != nil {
if result.Claims != nil {
// Access Token V2 - stateless, use claims
ctx = auth.SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
// PAT - have full user
ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken)
}
ctx = auth.ApplyToContext(ctx, result)
r = r.WithContext(ctx)
}

View File

@ -515,58 +515,9 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *ec
// getCurrentUser retrieves the current authenticated user from the request.
// Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) {
// Try Bearer token authentication.
if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" {
if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil {
return user, nil
}
}
// Fallback: Try refresh token cookie.
if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" {
if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil {
return user, nil
}
}
return nil, nil
}
// authenticateByBearerToken authenticates using Authorization header.
func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) {
token := auth.ExtractBearerToken(authHeader)
if token == "" {
return nil, nil
}
// Try Access Token V2 (stateless JWT).
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
claims, err := s.authenticator.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
}
}
// Try Personal Access Token (stateful).
if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
user, _, err := s.authenticator.AuthenticateByPAT(ctx, token)
if err == nil {
return user, nil
}
}
return nil, nil
}
// authenticateByRefreshToken authenticates using refresh token cookie.
func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) {
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken == "" {
return nil, nil
}
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
return user, err
authHeader := c.Request().Header.Get(echo.HeaderAuthorization)
cookieHeader := c.Request().Header.Get("Cookie")
return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
}
// getUserByIdentifier finds a user by either ID or username.

View File

@ -2,8 +2,6 @@ package mcp
import (
"net/http"
"strings"
"time"
"github.com/labstack/echo/v5"
@ -11,23 +9,21 @@ import (
"github.com/usememos/memos/store"
)
func newAuthMiddleware(s *store.Store) echo.MiddlewareFunc {
func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc {
authenticator := auth.NewAuthenticator(s, secret)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization"))
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
if token == "" {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"})
}
result, err := s.GetUserByPATHash(c.Request().Context(), auth.HashPersonalAccessToken(token))
if err != nil || result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
}
if result.PAT.ExpiresAt != nil && result.PAT.ExpiresAt.AsTime().Before(time.Now()) {
user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token)
if err != nil || user == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
}
ctx := auth.SetUserInContext(c.Request().Context(), result.User, result.PAT.GetTokenId())
ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId())
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}

View File

@ -9,11 +9,12 @@ import (
)
type MCPService struct {
store *store.Store
store *store.Store
secret string
}
func NewMCPService(store *store.Store) *MCPService {
return &MCPService{store: store}
func NewMCPService(store *store.Store, secret string) *MCPService {
return &MCPService{store: store, secret: secret}
}
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
@ -26,6 +27,6 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
mcpGroup.Use(newAuthMiddleware(s.store))
mcpGroup.Use(newAuthMiddleware(s.store, s.secret))
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
}

View File

@ -81,7 +81,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
}
// Register MCP server.
mcpService := mcprouter.NewMCPService(s.Store)
mcpService := mcprouter.NewMCPService(s.Store, s.Secret)
mcpService.RegisterRoutes(echoServer)
return s, nil