diff --git a/server/router/api/v1/acl.go b/server/router/api/v1/acl.go index fdb400297..99c24c979 100644 --- a/server/router/api/v1/acl.go +++ b/server/router/api/v1/acl.go @@ -23,9 +23,9 @@ import ( type ContextKey int const ( - // userIDContextKey stores the authenticated user's ID in the context. + // UserIDContextKey stores the authenticated user's ID in the context. // Set for both session-based and token-based authentication. - userIDContextKey ContextKey = iota + UserIDContextKey ContextKey = iota // sessionIDContextKey stores the session ID in the context. // Only set for session-based authentication (cookie auth). @@ -59,7 +59,7 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep // 4. Reject: Return 401 Unauthenticated if none of the above succeed // // On successful authentication, sets context values: -// - userIDContextKey: The authenticated user's ID (always set) +// - UserIDContextKey: The authenticated user's ID (always set) // - sessionIDContextKey: Session ID (only for cookie auth) // - accessTokenContextKey: JWT token (only for Bearer token auth). func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { @@ -115,7 +115,7 @@ func (in *GRPCAuthInterceptor) handleAuthenticatedRequest(ctx context.Context, r } // Set context values - ctx = context.WithValue(ctx, userIDContextKey, user.ID) + ctx = context.WithValue(ctx, UserIDContextKey, user.ID) if sessionID != "" { // Session-based authentication diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 74ac00299..461bf8b72 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -325,7 +325,7 @@ func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue } func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) { - userID, ok := ctx.Value(userIDContextKey).(int32) + userID, ok := ctx.Value(UserIDContextKey).(int32) if !ok { return nil, nil } diff --git a/server/router/api/v1/test/test_helper.go b/server/router/api/v1/test/test_helper.go index eb9ef93b2..63883e4f7 100644 --- a/server/router/api/v1/test/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -82,5 +82,5 @@ func (ts *TestService) CreateRegularUser(ctx context.Context, username string) ( // CreateUserContext creates a context with the given user's ID for authentication. func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context { // Use the real context key from the parent package - return apiv1.CreateTestUserContext(ctx, userID) + return context.WithValue(ctx, apiv1.UserIDContextKey, userID) } diff --git a/server/router/api/v1/test_auth.go b/server/router/api/v1/test_auth.go deleted file mode 100644 index f2f09bd1e..000000000 --- a/server/router/api/v1/test_auth.go +++ /dev/null @@ -1,19 +0,0 @@ -package v1 - -import ( - "context" - - "github.com/usememos/memos/store" -) - -// CreateTestUserContext creates a context with user's ID for testing purposes. -// This function is only intended for use in tests. -func CreateTestUserContext(ctx context.Context, userID int32) context.Context { - return context.WithValue(ctx, userIDContextKey, userID) -} - -// CreateTestUserContextWithUser creates a context and ensures the user exists for testing. -// This function is only intended for use in tests. -func CreateTestUserContextWithUser(ctx context.Context, _ *APIV1Service, user *store.User) context.Context { - return context.WithValue(ctx, userIDContextKey, user.ID) -}