mirror of https://github.com/usememos/memos.git
fix: auth context
This commit is contained in:
parent
45df653f37
commit
6e4d1d9100
|
|
@ -23,9 +23,8 @@ import (
|
|||
type ContextKey int
|
||||
|
||||
const (
|
||||
// The key name used to store username in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
usernameContextKey ContextKey = iota
|
||||
// The key name used to store user's ID in the context (for user-based auth).
|
||||
userIDContextKey ContextKey = iota
|
||||
// The key name used to store session ID in the context (for session-based auth).
|
||||
sessionIDContextKey
|
||||
// The key name used to store access token in the context (for token-based auth).
|
||||
|
|
@ -48,11 +47,6 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep
|
|||
|
||||
// AuthenticationInterceptor is the unary interceptor for gRPC API.
|
||||
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
// Check if this method is in the allowlist first
|
||||
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
}
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
|
||||
|
|
@ -65,21 +59,25 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
|
|||
}
|
||||
|
||||
// Authenticate using access token (which also validates sessions when it's from cookie)
|
||||
username, user, err := in.authenticateByAccessToken(ctx, accessToken)
|
||||
user, err := in.authenticateByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
// Check if this method is in the allowlist first
|
||||
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check user status
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, errors.Errorf("user %q is archived", username)
|
||||
return nil, errors.Errorf("user %q is archived", user.Username)
|
||||
}
|
||||
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
|
||||
return nil, errors.Errorf("user %q is not admin", username)
|
||||
return nil, errors.Errorf("user %q is not admin", user.Username)
|
||||
}
|
||||
|
||||
// Set context values
|
||||
ctx = context.WithValue(ctx, usernameContextKey, username)
|
||||
ctx = context.WithValue(ctx, userIDContextKey, user.ID)
|
||||
|
||||
// Determine if this came from cookie (session) or header (API token)
|
||||
if _, headerErr := getAccessTokenFromMetadata(md); headerErr != nil {
|
||||
|
|
@ -96,9 +94,9 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
|
|||
}
|
||||
|
||||
// authenticateByAccessToken authenticates a user using access token from Authorization header or cookie.
|
||||
func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (string, *store.User, error) {
|
||||
func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (*store.User, error) {
|
||||
if accessToken == "" {
|
||||
return "", nil, status.Errorf(codes.Unauthenticated, "access token not found")
|
||||
return nil, status.Errorf(codes.Unauthenticated, "access token not found")
|
||||
}
|
||||
claims := &ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
|
|
@ -113,33 +111,33 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
|
|||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
|
||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
|
||||
}
|
||||
|
||||
// We either have a valid access token or we will attempt to generate new access token.
|
||||
userID, err := util.ConvertStringToInt32(claims.Subject)
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "malformed ID in the token")
|
||||
return nil, errors.Wrap(err, "malformed ID in the token")
|
||||
}
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrap(err, "failed to get user")
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return "", nil, errors.Errorf("user %q not exists", userID)
|
||||
return nil, errors.Errorf("user %q not exists", userID)
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return "", nil, errors.Errorf("user %q is archived", userID)
|
||||
return nil, errors.Errorf("user %q is archived", userID)
|
||||
}
|
||||
|
||||
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return "", nil, errors.Wrapf(err, "failed to get user access tokens")
|
||||
return nil, errors.Wrapf(err, "failed to get user access tokens")
|
||||
}
|
||||
if !validateAccessToken(accessToken, accessTokens) {
|
||||
return "", nil, status.Errorf(codes.Unauthenticated, "invalid access token")
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
|
||||
}
|
||||
|
||||
// For tokens that might be used as session IDs (from cookies), also validate session existence
|
||||
|
|
@ -148,7 +146,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
|
|||
validateUserSession(accessToken, sessions) // Result doesn't matter for API tokens
|
||||
}
|
||||
|
||||
return user.Username, user, nil
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// updateSessionLastAccessed updates the last accessed time for a user session.
|
||||
|
|
@ -204,9 +202,6 @@ func getTokenFromMetadata(md metadata.MD) (string, error) {
|
|||
accessToken = v.Value
|
||||
}
|
||||
}
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access token not found")
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ var authenticationAllowlistMethods = map[string]bool{
|
|||
"/memos.api.v1.IdentityProviderService/GetIdentityProvider": true,
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
|
||||
"/memos.api.v1.AuthService/CreateSession": true,
|
||||
"/memos.api.v1.AuthService/GetCurrentSession": true,
|
||||
"/memos.api.v1.AuthService/SignUp": true,
|
||||
"/memos.api.v1.UserService/GetUser": true,
|
||||
"/memos.api.v1.UserService/GetUserAvatar": true,
|
||||
|
|
|
|||
|
|
@ -331,16 +331,19 @@ func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken str
|
|||
}
|
||||
|
||||
func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
|
||||
username, ok := ctx.Value(usernameContextKey).(string)
|
||||
userID, ok := ctx.Value(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &username,
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.Errorf("user %d not found", userID)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ func TestCreateIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
ctx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create OAuth2 identity provider
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
|
|
@ -71,7 +71,7 @@ func TestCreateIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
ctx := ts.CreateUserContext(ctx, regularUser.Username)
|
||||
ctx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
|
|
@ -125,7 +125,7 @@ func TestListIdentityProviders(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a couple of identity providers
|
||||
createReq1 := &v1pb.CreateIdentityProviderRequest{
|
||||
|
|
@ -199,7 +199,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
|
|
@ -284,7 +284,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
|
|
@ -398,7 +398,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create identity provider
|
||||
createReq := &v1pb.CreateIdentityProviderRequest{
|
||||
|
|
@ -464,7 +464,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "identityProviders/999",
|
||||
|
|
@ -488,7 +488,7 @@ func TestIdentityProviderPermissions(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
req := &v1pb.CreateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ func TestListInboxes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List inboxes (should be empty initially)
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
|
|
@ -64,7 +64,7 @@ func TestListInboxes(t *testing.T) {
|
|||
}
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List inboxes with page size limit
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
|
|
@ -90,7 +90,7 @@ func TestListInboxes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to list user2's inboxes
|
||||
userCtx := ts.CreateUserContext(ctx, user1.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
|
|
@ -124,7 +124,7 @@ func TestListInboxes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set host user context and try to list regular user's inboxes
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: fmt.Sprintf("users/%d", regularUser.ID),
|
||||
|
|
@ -145,7 +145,7 @@ func TestListInboxes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.ListInboxesRequest{
|
||||
Parent: "invalid-parent-format",
|
||||
|
|
@ -194,7 +194,7 @@ func TestUpdateInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Update inbox status
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
|
|
@ -236,7 +236,7 @@ func TestUpdateInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to update user2's inbox
|
||||
userCtx := ts.CreateUserContext(ctx, user1.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
|
|
@ -262,7 +262,7 @@ func TestUpdateInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
|
|
@ -285,7 +285,7 @@ func TestUpdateInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
|
|
@ -311,7 +311,7 @@ func TestUpdateInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
|
|
@ -351,7 +351,7 @@ func TestUpdateInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateInboxRequest{
|
||||
Inbox: &v1pb.Inbox{
|
||||
|
|
@ -393,7 +393,7 @@ func TestDeleteInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Delete inbox
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
|
|
@ -434,7 +434,7 @@ func TestDeleteInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to delete user2's inbox
|
||||
userCtx := ts.CreateUserContext(ctx, user1.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
|
||||
|
|
@ -454,7 +454,7 @@ func TestDeleteInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: "invalid-inbox-name",
|
||||
|
|
@ -474,7 +474,7 @@ func TestDeleteInbox(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteInboxRequest{
|
||||
Name: "inboxes/99999", // Non-existent inbox
|
||||
|
|
@ -512,7 +512,7 @@ func TestInboxCRUDComplete(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// 1. List inboxes - should have 1
|
||||
listReq := &v1pb.ListInboxesRequest{
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ func TestListShortcuts(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// List shortcuts (should be empty initially)
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
|
|
@ -47,7 +47,7 @@ func TestListShortcuts(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to list user2's shortcuts
|
||||
userCtx := ts.CreateUserContext(ctx, user1.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
|
|
@ -67,7 +67,7 @@ func TestListShortcuts(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.ListShortcutsRequest{
|
||||
Parent: "invalid-parent-format",
|
||||
|
|
@ -104,7 +104,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// First create a shortcut
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
|
|
@ -142,7 +142,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.Username)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
|
|
@ -155,7 +155,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Try to get shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.Username)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
getReq := &v1pb.GetShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
|
@ -174,7 +174,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: "invalid-shortcut-name",
|
||||
|
|
@ -194,7 +194,7 @@ func TestGetShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.GetShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
|
|
@ -218,7 +218,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
|
|
@ -257,7 +257,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user1 context but try to create shortcut for user2
|
||||
userCtx := ts.CreateUserContext(ctx, user1.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user1.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user2.ID),
|
||||
|
|
@ -281,7 +281,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: "invalid-parent",
|
||||
|
|
@ -305,7 +305,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
|
|
@ -329,7 +329,7 @@ func TestCreateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user.ID),
|
||||
|
|
@ -356,7 +356,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
|
|
@ -401,7 +401,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.Username)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
|
|
@ -414,7 +414,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Try to update shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.Username)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
updateReq := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
Name: created.Name,
|
||||
|
|
@ -438,7 +438,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
// Create a user and context for authentication
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.UpdateShortcutRequest{
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
|
|
@ -480,7 +480,7 @@ func TestUpdateShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
|
|
@ -523,7 +523,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Create a shortcut first
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
|
|
@ -575,7 +575,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Create shortcut as user1
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.Username)
|
||||
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
|
||||
createReq := &v1pb.CreateShortcutRequest{
|
||||
Parent: fmt.Sprintf("users/%d", user1.ID),
|
||||
Shortcut: &v1pb.Shortcut{
|
||||
|
|
@ -588,7 +588,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Try to delete shortcut as user2
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.Username)
|
||||
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
|
||||
deleteReq := &v1pb.DeleteShortcutRequest{
|
||||
Name: created.Name,
|
||||
}
|
||||
|
|
@ -620,7 +620,7 @@ func TestDeleteShortcut(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
req := &v1pb.DeleteShortcutRequest{
|
||||
Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent",
|
||||
|
|
@ -644,7 +644,7 @@ func TestShortcutFiltering(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various valid filter formats
|
||||
validFilters := []string{
|
||||
|
|
@ -681,7 +681,7 @@ func TestShortcutFiltering(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// Test various invalid filter formats
|
||||
invalidFilters := []string{
|
||||
|
|
@ -723,7 +723,7 @@ func TestShortcutCRUDComplete(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, user.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, user.ID)
|
||||
|
||||
// 1. Create multiple shortcuts
|
||||
shortcut1Req := &v1pb.CreateShortcutRequest{
|
||||
|
|
|
|||
|
|
@ -74,8 +74,8 @@ func (ts *TestService) CreateRegularUser(ctx context.Context, username string) (
|
|||
})
|
||||
}
|
||||
|
||||
// CreateUserContext creates a context with the given username for authentication.
|
||||
func (*TestService) CreateUserContext(ctx context.Context, username string) context.Context {
|
||||
// CreateUserContext creates a context with the given user's ID for authentication.
|
||||
func (*TestService) CreateUserContext(ctx context.Context, userID int32) context.Context {
|
||||
// Use the real context key from the parent package
|
||||
return apiv1.CreateTestUserContext(ctx, username)
|
||||
return apiv1.CreateTestUserContext(ctx, userID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ func TestCreateWebhook(t *testing.T) {
|
|||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -72,7 +72,7 @@ func TestCreateWebhook(t *testing.T) {
|
|||
regularUser, err := ts.CreateRegularUser(ctx, "user1")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
|
||||
|
||||
// Try to create webhook as regular user
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -98,7 +98,7 @@ func TestCreateWebhook(t *testing.T) {
|
|||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Try to create webhook with missing URL
|
||||
req := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -127,7 +127,7 @@ func TestListWebhooks(t *testing.T) {
|
|||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// List webhooks
|
||||
req := &v1pb.ListWebhooksRequest{}
|
||||
|
|
@ -147,7 +147,7 @@ func TestListWebhooks(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -196,7 +196,7 @@ func TestGetWebhook(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -230,7 +230,7 @@ func TestGetWebhook(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Try to get webhook with invalid name
|
||||
req := &v1pb.GetWebhookRequest{
|
||||
|
|
@ -250,7 +250,7 @@ func TestGetWebhook(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Try to get non-existent webhook
|
||||
req := &v1pb.GetWebhookRequest{
|
||||
|
|
@ -275,7 +275,7 @@ func TestUpdateWebhook(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -337,7 +337,7 @@ func TestDeleteWebhook(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Create a webhook
|
||||
createReq := &v1pb.CreateWebhookRequest{
|
||||
|
|
@ -393,7 +393,7 @@ func TestDeleteWebhook(t *testing.T) {
|
|||
// Create host user and authenticate
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Try to delete non-existent webhook
|
||||
req := &v1pb.DeleteWebhookRequest{
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ func TestGetWorkspaceSetting(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Add user to context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.Username)
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
// Call GetWorkspaceSetting for storage setting
|
||||
req := &v1pb.GetWorkspaceSettingRequest{
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// CreateTestUserContext creates a context with username for testing purposes.
|
||||
// CreateTestUserContext creates a context with user's ID for testing purposes.
|
||||
// This function is only intended for use in tests.
|
||||
func CreateTestUserContext(ctx context.Context, username string) context.Context {
|
||||
return context.WithValue(ctx, usernameContextKey, username)
|
||||
func CreateTestUserContext(ctx context.Context, userID int32) context.Context {
|
||||
return context.WithValue(ctx, userIDContextKey, userID)
|
||||
}
|
||||
|
||||
// CreateTestUserContextWithUser creates a context and ensures the user exists for testing.
|
||||
// This function is only intended for use in tests.
|
||||
func CreateTestUserContextWithUser(ctx context.Context, _ *APIV1Service, user *store.User) context.Context {
|
||||
return context.WithValue(ctx, usernameContextKey, user.Username)
|
||||
return context.WithValue(ctx, userIDContextKey, user.ID)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue