diff --git a/proto/api/v1/auth_service.proto b/proto/api/v1/auth_service.proto index 545a5f62d..45a3a1d68 100644 --- a/proto/api/v1/auth_service.proto +++ b/proto/api/v1/auth_service.proto @@ -64,8 +64,9 @@ message SignInRequest { // Nested message for SSO authentication credentials. message SSOCredentials { - // The ID of the SSO provider. - int32 idp_id = 1 [(google.api.field_behavior) = REQUIRED]; + // The resource name of the SSO provider. + // Format: identity-providers/{uid} + string idp_name = 1 [(google.api.field_behavior) = REQUIRED]; // The authorization code from the SSO provider. string code = 2 [(google.api.field_behavior) = REQUIRED]; diff --git a/proto/gen/api/v1/activity_service_grpc.pb.go b/proto/gen/api/v1/activity_service_grpc.pb.go index ee48c8471..27b4d7da9 100644 --- a/proto/gen/api/v1/activity_service_grpc.pb.go +++ b/proto/gen/api/v1/activity_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/activity_service.proto diff --git a/proto/gen/api/v1/attachment_service_grpc.pb.go b/proto/gen/api/v1/attachment_service_grpc.pb.go index 4a32b3c6f..16b0327a6 100644 --- a/proto/gen/api/v1/attachment_service_grpc.pb.go +++ b/proto/gen/api/v1/attachment_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/attachment_service.proto diff --git a/proto/gen/api/v1/auth_service.pb.go b/proto/gen/api/v1/auth_service.pb.go index dade2e019..0de77b478 100644 --- a/proto/gen/api/v1/auth_service.pb.go +++ b/proto/gen/api/v1/auth_service.pb.go @@ -440,8 +440,9 @@ func (x *SignInRequest_PasswordCredentials) GetPassword() string { // Nested message for SSO authentication credentials. type SignInRequest_SSOCredentials struct { state protoimpl.MessageState `protogen:"open.v1"` - // The ID of the SSO provider. - IdpId int32 `protobuf:"varint,1,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"` + // The resource name of the SSO provider. + // Format: identity-providers/{uid} + IdpName string `protobuf:"bytes,1,opt,name=idp_name,json=idpName,proto3" json:"idp_name,omitempty"` // The authorization code from the SSO provider. Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"` // The redirect URI used in the SSO flow. @@ -483,11 +484,11 @@ func (*SignInRequest_SSOCredentials) Descriptor() ([]byte, []int) { return file_api_v1_auth_service_proto_rawDescGZIP(), []int{2, 1} } -func (x *SignInRequest_SSOCredentials) GetIdpId() int32 { +func (x *SignInRequest_SSOCredentials) GetIdpName() string { if x != nil { - return x.IdpId + return x.IdpName } - return 0 + return "" } func (x *SignInRequest_SSOCredentials) GetCode() string { @@ -518,15 +519,15 @@ const file_api_v1_auth_service_proto_rawDesc = "" + "\x19api/v1/auth_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\x17\n" + "\x15GetCurrentUserRequest\"@\n" + "\x16GetCurrentUserResponse\x12&\n" + - "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xce\x03\n" + + "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xd2\x03\n" + "\rSignInRequest\x12d\n" + "\x14password_credentials\x18\x01 \x01(\v2/.memos.api.v1.SignInRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12U\n" + "\x0fsso_credentials\x18\x02 \x01(\v2*.memos.api.v1.SignInRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" + "\x13PasswordCredentials\x12\x1f\n" + "\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" + - "\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" + - "\x0eSSOCredentials\x12\x1a\n" + - "\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" + + "\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x9b\x01\n" + + "\x0eSSOCredentials\x12\x1e\n" + + "\bidp_name\x18\x01 \x01(\tB\x03\xe0A\x02R\aidpName\x12\x17\n" + "\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" + "\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" + "\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" + diff --git a/proto/gen/api/v1/auth_service_grpc.pb.go b/proto/gen/api/v1/auth_service_grpc.pb.go index b9cef0215..f7258cb70 100644 --- a/proto/gen/api/v1/auth_service_grpc.pb.go +++ b/proto/gen/api/v1/auth_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/auth_service.proto diff --git a/proto/gen/api/v1/idp_service_grpc.pb.go b/proto/gen/api/v1/idp_service_grpc.pb.go index bce241228..8af2d3698 100644 --- a/proto/gen/api/v1/idp_service_grpc.pb.go +++ b/proto/gen/api/v1/idp_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/idp_service.proto diff --git a/proto/gen/api/v1/instance_service_grpc.pb.go b/proto/gen/api/v1/instance_service_grpc.pb.go index aaa18dd3e..5381a4ead 100644 --- a/proto/gen/api/v1/instance_service_grpc.pb.go +++ b/proto/gen/api/v1/instance_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/instance_service.proto diff --git a/proto/gen/api/v1/memo_service_grpc.pb.go b/proto/gen/api/v1/memo_service_grpc.pb.go index a7d4f46e8..bcb9adfe4 100644 --- a/proto/gen/api/v1/memo_service_grpc.pb.go +++ b/proto/gen/api/v1/memo_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/memo_service.proto diff --git a/proto/gen/api/v1/shortcut_service_grpc.pb.go b/proto/gen/api/v1/shortcut_service_grpc.pb.go index f6913ef07..787835a52 100644 --- a/proto/gen/api/v1/shortcut_service_grpc.pb.go +++ b/proto/gen/api/v1/shortcut_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/shortcut_service.proto diff --git a/proto/gen/api/v1/user_service_grpc.pb.go b/proto/gen/api/v1/user_service_grpc.pb.go index cef54be93..09d5274e7 100644 --- a/proto/gen/api/v1/user_service_grpc.pb.go +++ b/proto/gen/api/v1/user_service_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 +// - protoc-gen-go-grpc v1.6.1 // - protoc (unknown) // source: api/v1/user_service.proto diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index d8ed9eed1..83ff67f13 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -2757,15 +2757,16 @@ components: description: Nested message for password-based authentication credentials. SignInRequest_SSOCredentials: required: - - idpId + - idpName - code - redirectUri type: object properties: - idpId: - type: integer - description: The ID of the SSO provider. - format: int32 + idpName: + type: string + description: |- + The resource name of the SSO provider. + Format: identity-providers/{uid} code: type: string description: The authorization code from the SSO provider. diff --git a/proto/gen/store/idp.pb.go b/proto/gen/store/idp.pb.go index 074fc9c04..6791ce1c9 100644 --- a/proto/gen/store/idp.pb.go +++ b/proto/gen/store/idp.pb.go @@ -74,6 +74,7 @@ type IdentityProvider struct { Type IdentityProvider_Type `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.IdentityProvider_Type" json:"type,omitempty"` IdentifierFilter string `protobuf:"bytes,4,opt,name=identifier_filter,json=identifierFilter,proto3" json:"identifier_filter,omitempty"` Config *IdentityProviderConfig `protobuf:"bytes,5,opt,name=config,proto3" json:"config,omitempty"` + Uid string `protobuf:"bytes,6,opt,name=uid,proto3" json:"uid,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -143,6 +144,13 @@ func (x *IdentityProvider) GetConfig() *IdentityProviderConfig { return nil } +func (x *IdentityProvider) GetUid() string { + if x != nil { + return x.Uid + } + return "" +} + type IdentityProviderConfig struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Config: @@ -373,13 +381,14 @@ var File_store_idp_proto protoreflect.FileDescriptor const file_store_idp_proto_rawDesc = "" + "\n" + - "\x0fstore/idp.proto\x12\vmemos.store\"\x82\x02\n" + + "\x0fstore/idp.proto\x12\vmemos.store\"\x94\x02\n" + "\x10IdentityProvider\x12\x0e\n" + "\x02id\x18\x01 \x01(\x05R\x02id\x12\x12\n" + "\x04name\x18\x02 \x01(\tR\x04name\x126\n" + "\x04type\x18\x03 \x01(\x0e2\".memos.store.IdentityProvider.TypeR\x04type\x12+\n" + "\x11identifier_filter\x18\x04 \x01(\tR\x10identifierFilter\x12;\n" + - "\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\"(\n" + + "\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\x12\x10\n" + + "\x03uid\x18\x06 \x01(\tR\x03uid\"(\n" + "\x04Type\x12\x14\n" + "\x10TYPE_UNSPECIFIED\x10\x00\x12\n" + "\n" + diff --git a/proto/store/idp.proto b/proto/store/idp.proto index 990376cda..6f69b1dd1 100644 --- a/proto/store/idp.proto +++ b/proto/store/idp.proto @@ -15,6 +15,7 @@ message IdentityProvider { Type type = 3; string identifier_filter = 4; IdentityProviderConfig config = 5; + string uid = 6; } message IdentityProviderConfig { diff --git a/server/router/api/v1/attachment_service.go b/server/router/api/v1/attachment_service.go index f2218da41..46837814c 100644 --- a/server/router/api/v1/attachment_service.go +++ b/server/router/api/v1/attachment_service.go @@ -16,7 +16,6 @@ import ( "time" "github.com/disintegration/imaging" - "github.com/lithammer/shortuuid/v4" "github.com/pkg/errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -100,10 +99,9 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format") } - // Use provided attachment_id or generate a new one - attachmentUID := request.AttachmentId - if attachmentUID == "" { - attachmentUID = shortuuid.New() + attachmentUID, err := ValidateAndGenerateUID(request.AttachmentId) + if err != nil { + return nil, err } create := &store.Attachment{ diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index e3f41a2a3..87cc55a63 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -90,8 +90,12 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) existingUser = user } else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil { // Authentication Method 2: SSO (OAuth2) authentication + idpUID, err := ExtractIdentityProviderUIDFromName(ssoCredentials.IdpName) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) + } identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ - ID: &ssoCredentials.IdpId, + UID: &idpUID, }) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err) diff --git a/server/router/api/v1/idp_service.go b/server/router/api/v1/idp_service.go index d257a49b5..b6b65b283 100644 --- a/server/router/api/v1/idp_service.go +++ b/server/router/api/v1/idp_service.go @@ -25,7 +25,15 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider)) + idpUID, err := ValidateAndGenerateUID(request.IdentityProviderId) + if err != nil { + return nil, err + } + + storeIdp := convertIdentityProviderToStore(request.IdentityProvider) + storeIdp.Uid = idpUID + + identityProvider, err := s.Store.CreateIdentityProvider(ctx, storeIdp) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err) } @@ -57,12 +65,12 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId } func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) { - id, err := ExtractIdentityProviderIDFromName(request.Name) + uid, err := ExtractIdentityProviderUIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) } identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ - ID: &id, + UID: &uid, }) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err) @@ -98,12 +106,22 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb return nil, status.Errorf(codes.InvalidArgument, "update_mask is required") } - id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name) + uid, err := ExtractIdentityProviderUIDFromName(request.IdentityProvider.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) } + + // Look up the IdP by UID to get the internal ID for update. + existing, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err) + } + if existing == nil { + return nil, status.Errorf(codes.NotFound, "identity provider not found") + } + update := &store.UpdateIdentityProviderV1{ - ID: id, + ID: existing.Id, Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]), } for _, field := range request.UpdateMask.Paths { @@ -138,13 +156,13 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - id, err := ExtractIdentityProviderIDFromName(request.Name) + uid, err := ExtractIdentityProviderUIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) } - // Check if the identity provider exists before trying to delete it - identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id}) + // Look up the IdP by UID to get the internal ID for deletion. + identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid}) if err != nil { return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err) } @@ -152,7 +170,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb return nil, status.Errorf(codes.NotFound, "identity provider not found") } - if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil { + if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProvider.Id}); err != nil { return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err) } return &emptypb.Empty{}, nil @@ -160,7 +178,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider { temp := &v1pb.IdentityProvider{ - Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id), + Name: fmt.Sprintf("%s%s", IdentityProviderNamePrefix, identityProvider.Uid), Title: identityProvider.Name, IdentifierFilter: identityProvider.IdentifierFilter, Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]), @@ -190,10 +208,7 @@ func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider } func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider { - id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name) - temp := &storepb.IdentityProvider{ - Id: id, Name: identityProvider.Title, IdentifierFilter: identityProvider.IdentifierFilter, Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]), diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index b1d881e2c..ee82ea22e 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -7,13 +7,11 @@ import ( "strings" "time" - "github.com/lithammer/shortuuid/v4" "github.com/pkg/errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" - "github.com/usememos/memos/internal/base" "github.com/usememos/memos/plugin/webhook" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" @@ -30,13 +28,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") } - // Use custom memo_id if provided, otherwise generate a new UUID - memoUID := strings.TrimSpace(request.MemoId) - if memoUID == "" { - memoUID = shortuuid.New() - } else if !base.UIDMatcher.MatchString(memoUID) { - // Validate custom memo ID format - return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen") + memoUID, err := ValidateAndGenerateUID(request.MemoId) + if err != nil { + return nil, err } create := &store.Memo{ diff --git a/server/router/api/v1/resource_name.go b/server/router/api/v1/resource_name.go index c04bbecb4..d201416e1 100644 --- a/server/router/api/v1/resource_name.go +++ b/server/router/api/v1/resource_name.go @@ -4,8 +4,12 @@ import ( "fmt" "strings" + "github.com/lithammer/shortuuid/v4" "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/usememos/memos/internal/base" "github.com/usememos/memos/internal/util" ) @@ -133,16 +137,12 @@ func ExtractInboxIDFromName(name string) (int32, error) { return id, nil } -func ExtractIdentityProviderIDFromName(name string) (int32, error) { +func ExtractIdentityProviderUIDFromName(name string) (string, error) { tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix) if err != nil { - return 0, err + return "", err } - id, err := util.ConvertStringToInt32(tokens[0]) - if err != nil { - return 0, errors.Errorf("invalid identity provider ID %q", tokens[0]) - } - return id, nil + return tokens[0], nil } func ExtractActivityIDFromName(name string) (int32, error) { @@ -156,3 +156,17 @@ func ExtractActivityIDFromName(name string) (int32, error) { } return id, nil } + +// ValidateAndGenerateUID validates a user-provided UID or generates a new one. +// If provided is empty, a new shortuuid is generated. +// If provided is non-empty, it is validated against base.UIDMatcher. +func ValidateAndGenerateUID(provided string) (string, error) { + uid := strings.TrimSpace(provided) + if uid == "" { + return shortuuid.New(), nil + } + if !base.UIDMatcher.MatchString(uid) { + return "", status.Errorf(codes.InvalidArgument, "invalid ID format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen") + } + return uid, nil +} diff --git a/store/db/mysql/idp.go b/store/db/mysql/idp.go index 4f0ba4902..9af3f8046 100644 --- a/store/db/mysql/idp.go +++ b/store/db/mysql/idp.go @@ -11,9 +11,9 @@ import ( ) func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { - placeholders := []string{"?", "?", "?", "?"} - fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} - args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config} + placeholders := []string{"?", "?", "?", "?", "?"} + fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"} + args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config} stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")" result, err := d.db.ExecContext(ctx, stmt, args...) @@ -35,8 +35,11 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity if v := find.ID; v != nil { where, args = append(where, "`id` = ?"), append(args, *v) } + if v := find.UID; v != nil { + where, args = append(where, "`uid` = ?"), append(args, *v) + } - rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC", + rows, err := d.db.QueryContext(ctx, "SELECT `id`, `uid`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC", args..., ) if err != nil { @@ -50,6 +53,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity var typeString string if err := rows.Scan( &identityProvider.ID, + &identityProvider.UID, &identityProvider.Name, &typeString, &identityProvider.IdentifierFilter, diff --git a/store/db/postgres/idp.go b/store/db/postgres/idp.go index 34a1165b9..2cfce4879 100644 --- a/store/db/postgres/idp.go +++ b/store/db/postgres/idp.go @@ -9,8 +9,8 @@ import ( ) func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { - fields := []string{"name", "type", "identifier_filter", "config"} - args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config} + fields := []string{"uid", "name", "type", "identifier_filter", "config"} + args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config} stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id" if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { return nil, err @@ -25,10 +25,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity if v := find.ID; v != nil { where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } + if v := find.UID; v != nil { + where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v) + } rows, err := d.db.QueryContext(ctx, ` SELECT id, + uid, name, type, identifier_filter, @@ -48,6 +52,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity var typeString string if err := rows.Scan( &identityProvider.ID, + &identityProvider.UID, &identityProvider.Name, &typeString, &identityProvider.IdentifierFilter, @@ -83,7 +88,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde UPDATE idp SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) + ` - RETURNING id, name, type, identifier_filter, config + RETURNING id, uid, name, type, identifier_filter, config ` args = append(args, update.ID) @@ -91,6 +96,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde var typeString string if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( &identityProvider.ID, + &identityProvider.UID, &identityProvider.Name, &typeString, &identityProvider.IdentifierFilter, diff --git a/store/db/sqlite/idp.go b/store/db/sqlite/idp.go index 608365287..41c6b5d00 100644 --- a/store/db/sqlite/idp.go +++ b/store/db/sqlite/idp.go @@ -10,9 +10,9 @@ import ( ) func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { - placeholders := []string{"?", "?", "?", "?"} - fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} - args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config} + placeholders := []string{"?", "?", "?", "?", "?"} + fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"} + args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config} stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`" if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { @@ -28,10 +28,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity if v := find.ID; v != nil { where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) } + if v := find.UID; v != nil { + where, args = append(where, fmt.Sprintf("uid = $%d", len(args)+1)), append(args, *v) + } rows, err := d.db.QueryContext(ctx, ` SELECT id, + uid, name, type, identifier_filter, @@ -51,6 +55,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity var typeString string if err := rows.Scan( &identityProvider.ID, + &identityProvider.UID, &identityProvider.Name, &typeString, &identityProvider.IdentifierFilter, @@ -86,12 +91,13 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde UPDATE idp SET ` + strings.Join(set, ", ") + ` WHERE id = ? - RETURNING id, name, type, identifier_filter, config + RETURNING id, uid, name, type, identifier_filter, config ` var identityProvider store.IdentityProvider var typeString string if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( &identityProvider.ID, + &identityProvider.UID, &identityProvider.Name, &typeString, &identityProvider.IdentifierFilter, diff --git a/store/idp.go b/store/idp.go index 88ab6f0e3..15ac7141f 100644 --- a/store/idp.go +++ b/store/idp.go @@ -11,6 +11,7 @@ import ( type IdentityProvider struct { ID int32 + UID string Name string Type storepb.IdentityProvider_Type IdentifierFilter string @@ -18,7 +19,8 @@ type IdentityProvider struct { } type FindIdentityProvider struct { - ID *int32 + ID *int32 + UID *string } type UpdateIdentityProvider struct { @@ -130,6 +132,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) { identityProvider := &storepb.IdentityProvider{ Id: raw.ID, + Uid: raw.UID, Name: raw.Name, Type: raw.Type, IdentifierFilter: raw.IdentifierFilter, @@ -145,6 +148,7 @@ func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityPro func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) { raw := &IdentityProvider{ ID: identityProvider.Id, + UID: identityProvider.Uid, Name: identityProvider.Name, Type: identityProvider.Type, IdentifierFilter: identityProvider.IdentifierFilter, diff --git a/store/migration/mysql/0.27/01__add_idp_uid.sql b/store/migration/mysql/0.27/01__add_idp_uid.sql new file mode 100644 index 000000000..a3e824731 --- /dev/null +++ b/store/migration/mysql/0.27/01__add_idp_uid.sql @@ -0,0 +1,8 @@ +-- Add uid column to idp table +ALTER TABLE `idp` ADD COLUMN `uid` VARCHAR(256) NOT NULL DEFAULT ''; + +-- Populate uid for existing rows using hex of id as a fallback +UPDATE `idp` SET `uid` = LOWER(LPAD(HEX(`id`), 8, '0')) WHERE `uid` = ''; + +-- Create unique index on uid +ALTER TABLE `idp` ADD UNIQUE INDEX `idx_idp_uid` (`uid`); diff --git a/store/migration/mysql/LATEST.sql b/store/migration/mysql/LATEST.sql index a76d7111b..017854c27 100644 --- a/store/migration/mysql/LATEST.sql +++ b/store/migration/mysql/LATEST.sql @@ -80,6 +80,7 @@ CREATE TABLE `activity` ( -- idp CREATE TABLE `idp` ( `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + `uid` VARCHAR(256) NOT NULL UNIQUE, `name` TEXT NOT NULL, `type` TEXT NOT NULL, `identifier_filter` VARCHAR(256) NOT NULL DEFAULT '', diff --git a/store/migration/postgres/0.27/01__add_idp_uid.sql b/store/migration/postgres/0.27/01__add_idp_uid.sql new file mode 100644 index 000000000..d6ac67fea --- /dev/null +++ b/store/migration/postgres/0.27/01__add_idp_uid.sql @@ -0,0 +1,8 @@ +-- Add uid column to idp table +ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT ''; + +-- Populate uid for existing rows using hex of id as a fallback +UPDATE idp SET uid = LPAD(TO_HEX(id), 8, '0') WHERE uid = ''; + +-- Create unique index on uid +CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid); diff --git a/store/migration/postgres/LATEST.sql b/store/migration/postgres/LATEST.sql index cbde126cd..b5faf9f38 100644 --- a/store/migration/postgres/LATEST.sql +++ b/store/migration/postgres/LATEST.sql @@ -80,6 +80,7 @@ CREATE TABLE activity ( -- idp CREATE TABLE idp ( id SERIAL PRIMARY KEY, + uid TEXT NOT NULL UNIQUE, name TEXT NOT NULL, type TEXT NOT NULL, identifier_filter TEXT NOT NULL DEFAULT '', diff --git a/store/migration/sqlite/0.27/01__add_idp_uid.sql b/store/migration/sqlite/0.27/01__add_idp_uid.sql new file mode 100644 index 000000000..4a6d95233 --- /dev/null +++ b/store/migration/sqlite/0.27/01__add_idp_uid.sql @@ -0,0 +1,8 @@ +-- Add uid column to idp table +ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT ''; + +-- Populate uid for existing rows using hex of id as a fallback +UPDATE idp SET uid = printf('%08x', id) WHERE uid = ''; + +-- Create unique index on uid +CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid); diff --git a/store/migration/sqlite/LATEST.sql b/store/migration/sqlite/LATEST.sql index 8b70fa68c..8daa49a11 100644 --- a/store/migration/sqlite/LATEST.sql +++ b/store/migration/sqlite/LATEST.sql @@ -81,6 +81,7 @@ CREATE TABLE activity ( -- idp CREATE TABLE idp ( id INTEGER PRIMARY KEY AUTOINCREMENT, + uid TEXT NOT NULL UNIQUE, name TEXT NOT NULL, type TEXT NOT NULL, identifier_filter TEXT NOT NULL DEFAULT '', diff --git a/store/test/idp_test.go b/store/test/idp_test.go index 8f2f1958b..79ca1912f 100644 --- a/store/test/idp_test.go +++ b/store/test/idp_test.go @@ -15,6 +15,7 @@ func TestIdentityProviderStore(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Uid: "test-github-oauth", Name: "GitHub OAuth", Type: storepb.IdentityProvider_OAUTH2, IdentifierFilter: "", @@ -37,6 +38,7 @@ func TestIdentityProviderStore(t *testing.T) { }, }) require.NoError(t, err) + require.Equal(t, "test-github-oauth", createdIDP.Uid) idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ ID: &createdIDP.Id, }) @@ -66,7 +68,7 @@ func TestIdentityProviderGetByID(t *testing.T) { ts := NewTestingStore(ctx, t) // Create IDP - idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp")) require.NoError(t, err) // Get by ID @@ -76,6 +78,13 @@ func TestIdentityProviderGetByID(t *testing.T) { require.Equal(t, idp.Id, found.Id) require.Equal(t, idp.Name, found.Name) + // Get by UID + foundByUID, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &idp.Uid}) + require.NoError(t, err) + require.NotNil(t, foundByUID) + require.Equal(t, idp.Id, foundByUID.Id) + require.Equal(t, idp.Uid, foundByUID.Uid) + // Get by non-existent ID nonExistentID := int32(99999) notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID}) @@ -91,11 +100,11 @@ func TestIdentityProviderListMultiple(t *testing.T) { ts := NewTestingStore(ctx, t) // Create multiple IDPs - _, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth")) + _, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth")) require.NoError(t, err) - _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth")) + _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth")) require.NoError(t, err) - _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth")) + _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth", "gitlab-oauth")) require.NoError(t, err) // List all @@ -112,9 +121,9 @@ func TestIdentityProviderListByID(t *testing.T) { ts := NewTestingStore(ctx, t) // Create multiple IDPs - idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth")) + idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth")) require.NoError(t, err) - _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth")) + _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth")) require.NoError(t, err) // List by specific ID @@ -131,7 +140,7 @@ func TestIdentityProviderUpdateName(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name")) + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name", "original-name")) require.NoError(t, err) require.Equal(t, "Original Name", idp.Name) @@ -158,7 +167,7 @@ func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp")) require.NoError(t, err) require.Equal(t, "", idp.IdentifierFilter) @@ -185,7 +194,7 @@ func TestIdentityProviderUpdateConfig(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp")) require.NoError(t, err) // Update config @@ -229,7 +238,7 @@ func TestIdentityProviderUpdateMultipleFields(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original")) + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original", "original")) require.NoError(t, err) // Update multiple fields at once @@ -253,7 +262,7 @@ func TestIdentityProviderDelete(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) + idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp")) require.NoError(t, err) // Delete @@ -274,9 +283,9 @@ func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) { ts := NewTestingStore(ctx, t) // Create multiple IDPs - idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1")) + idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1", "idp-1")) require.NoError(t, err) - idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2")) + idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2", "idp-2")) require.NoError(t, err) // Delete first one @@ -304,6 +313,7 @@ func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) { // Create IDP with multiple scopes idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Uid: "multi-scope-oauth", Name: "Multi-Scope OAuth", Type: storepb.IdentityProvider_OAUTH2, Config: &storepb.IdentityProviderConfig{ @@ -343,6 +353,7 @@ func TestIdentityProviderFieldMapping(t *testing.T) { // Create IDP with custom field mapping idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Uid: "custom-field-mapping", Name: "Custom Field Mapping", Type: storepb.IdentityProvider_OAUTH2, Config: &storepb.IdentityProviderConfig{ @@ -382,17 +393,19 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) { testCases := []struct { name string + uid string filter string }{ - {"Domain filter", "@company\\.com$"}, - {"Prefix filter", "^admin_"}, - {"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"}, - {"Empty filter", ""}, + {"Domain filter", "domain-filter", "@company\\.com$"}, + {"Prefix filter", "prefix-filter", "^admin_"}, + {"Complex regex", "complex-regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"}, + {"Empty filter", "empty-filter", ""}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ + Uid: tc.uid, Name: tc.name, Type: storepb.IdentityProvider_OAUTH2, IdentifierFilter: tc.filter, @@ -428,8 +441,9 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) { } // Helper function to create a test OAuth2 IDP. -func createTestOAuth2IDP(name string) *storepb.IdentityProvider { +func createTestOAuth2IDP(name, uid string) *storepb.IdentityProvider { return &storepb.IdentityProvider{ + Uid: uid, Name: name, Type: storepb.IdentityProvider_OAUTH2, IdentifierFilter: "", diff --git a/web/src/components/CreateIdentityProviderDialog.tsx b/web/src/components/CreateIdentityProviderDialog.tsx index da4ca5386..c8569d83f 100644 --- a/web/src/components/CreateIdentityProviderDialog.tsx +++ b/web/src/components/CreateIdentityProviderDialog.tsx @@ -133,6 +133,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on const identityProviderTypes = [...new Set(templateList.map((t) => t.type))]; const [basicInfo, setBasicInfo] = useState({ title: "", + identifier: "", identifierFilter: "", }); const [type, setType] = useState(IdentityProvider_Type.OAUTH2); @@ -161,6 +162,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on // Reset to default state when dialog is closed setBasicInfo({ title: "", + identifier: "", identifierFilter: "", }); setType(IdentityProvider_Type.OAUTH2); @@ -189,6 +191,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on if (open && identityProvider) { setBasicInfo({ title: identityProvider.title, + identifier: "", identifierFilter: identityProvider.identifierFilter, }); setType(identityProvider.type); @@ -210,6 +213,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on if (template) { setBasicInfo({ title: template.title, + identifier: template.title.toLowerCase().replace(/[^a-z0-9]+/g, "-"), identifierFilter: template.identifierFilter, }); setType(template.type); @@ -229,6 +233,9 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on if (basicInfo.title === "") { return false; } + if (isCreating && basicInfo.identifier === "") { + return false; + } if (type === IdentityProvider_Type.OAUTH2) { if ( oauth2Config.clientId === "" || @@ -254,8 +261,10 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on try { if (isCreating) { await identityProviderServiceClient.createIdentityProvider({ + identityProviderId: basicInfo.identifier, identityProvider: create(IdentityProviderSchema, { - ...basicInfo, + title: basicInfo.title, + identifierFilter: basicInfo.identifierFilter, type: type, config: create(IdentityProviderConfigSchema, { config: { @@ -343,6 +352,32 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on )} + {isCreating && ( + <> +

+ ID + * +

+ + setBasicInfo({ + ...basicInfo, + identifier: e.target.value + .toLowerCase() + .replace(/[^a-z0-9-]/g, "-") + .replace(/--+/g, "-"), + }) + } + /> +

+ A unique identifier for this provider. Lowercase letters, numbers, and hyphens only. +

+ + )}

{t("common.name")} * diff --git a/web/src/helpers/resource-names.ts b/web/src/helpers/resource-names.ts index e3cc8f067..3e9d0bfcc 100644 --- a/web/src/helpers/resource-names.ts +++ b/web/src/helpers/resource-names.ts @@ -16,8 +16,8 @@ export const extractMemoIdFromName = (name: string) => { return name.split(memoNamePrefix).pop() || ""; }; -export const extractIdentityProviderIdFromName = (name: string) => { - return parseInt(name.split(identityProviderNamePrefix).pop() || "", 10); +export const extractIdentityProviderUidFromName = (name: string) => { + return name.split(identityProviderNamePrefix).pop() || ""; }; // Helper function to convert InstanceSetting_Key enum value to string name diff --git a/web/src/pages/AuthCallback.tsx b/web/src/pages/AuthCallback.tsx index bc79ef2bd..eb46fb091 100644 --- a/web/src/pages/AuthCallback.tsx +++ b/web/src/pages/AuthCallback.tsx @@ -72,7 +72,7 @@ const AuthCallback = () => { return; } - const { identityProviderId, returnUrl, codeVerifier } = validatedState; + const { identityProviderName, returnUrl, codeVerifier } = validatedState; const redirectUri = absolutifyLink("/auth/callback"); (async () => { @@ -81,7 +81,7 @@ const AuthCallback = () => { credentials: { case: "ssoCredentials", value: { - idpId: identityProviderId, + idpName: identityProviderName, code, redirectUri, codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange diff --git a/web/src/pages/SignIn.tsx b/web/src/pages/SignIn.tsx index 2b39cc163..2553e5072 100644 --- a/web/src/pages/SignIn.tsx +++ b/web/src/pages/SignIn.tsx @@ -7,7 +7,6 @@ import { Button } from "@/components/ui/button"; import { Separator } from "@/components/ui/separator"; import { identityProviderServiceClient } from "@/connect"; import { useInstance } from "@/contexts/InstanceContext"; -import { extractIdentityProviderIdFromName } from "@/helpers/resource-names"; import { absolutifyLink } from "@/helpers/utils"; import useCurrentUser from "@/hooks/useCurrentUser"; import { handleError } from "@/lib/error"; @@ -50,8 +49,7 @@ const SignIn = () => { try { // Generate and store secure state parameter with CSRF protection // Also generate PKCE parameters (code_challenge) for enhanced security if available - const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name); - const { state, codeChallenge } = await storeOAuthState(identityProviderId); + const { state, codeChallenge } = await storeOAuthState(identityProvider.name); // Build OAuth authorization URL with secure state // Include PKCE if available (requires HTTPS/localhost for crypto.subtle) diff --git a/web/src/types/proto/api/v1/auth_service_pb.ts b/web/src/types/proto/api/v1/auth_service_pb.ts index b54f4cb68..972f3f6ac 100644 --- a/web/src/types/proto/api/v1/auth_service_pb.ts +++ b/web/src/types/proto/api/v1/auth_service_pb.ts @@ -16,7 +16,7 @@ import type { Message } from "@bufbuild/protobuf"; * Describes the file api/v1/auth_service.proto. */ export const file_api_v1_auth_service: GenFile = /*@__PURE__*/ - fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIuwCCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIabwoOU1NPQ3JlZGVudGlhbHMSEwoGaWRwX2lkGAEgASgFQgPgQQISEQoEY29kZRgCIAEoCUID4EECEhkKDHJlZGlyZWN0X3VyaRgDIAEoCUID4EECEhoKDWNvZGVfdmVyaWZpZXIYBCABKAlCA+BBAUINCgtjcmVkZW50aWFscyKFAQoOU2lnbkluUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyEhQKDGFjY2Vzc190b2tlbhgCIAEoCRI7ChdhY2Nlc3NfdG9rZW5fZXhwaXJlc19hdBgDIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAiEAoOU2lnbk91dFJlcXVlc3QiFQoTUmVmcmVzaFRva2VuUmVxdWVzdCJcChRSZWZyZXNoVG9rZW5SZXNwb25zZRIUCgxhY2Nlc3NfdG9rZW4YASABKAkSLgoKZXhwaXJlc19hdBgCIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAyvwMKC0F1dGhTZXJ2aWNlEnQKDkdldEN1cnJlbnRVc2VyEiMubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVxdWVzdBokLm1lbW9zLmFwaS52MS5HZXRDdXJyZW50VXNlclJlc3BvbnNlIheC0+STAhESDy9hcGkvdjEvYXV0aC9tZRJjCgZTaWduSW4SGy5tZW1vcy5hcGkudjEuU2lnbkluUmVxdWVzdBocLm1lbW9zLmFwaS52MS5TaWduSW5SZXNwb25zZSIegtPkkwIYOgEqIhMvYXBpL3YxL2F1dGgvc2lnbmluEl0KB1NpZ25PdXQSHC5tZW1vcy5hcGkudjEuU2lnbk91dFJlcXVlc3QaFi5nb29nbGUucHJvdG9idWYuRW1wdHkiHILT5JMCFiIUL2FwaS92MS9hdXRoL3NpZ25vdXQSdgoMUmVmcmVzaFRva2VuEiEubWVtb3MuYXBpLnYxLlJlZnJlc2hUb2tlblJlcXVlc3QaIi5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVzcG9uc2UiH4LT5JMCGToBKiIUL2FwaS92MS9hdXRoL3JlZnJlc2hCqAEKEGNvbS5tZW1vcy5hcGkudjFCEEF1dGhTZXJ2aWNlUHJvdG9QAVowZ2l0aHViLmNvbS91c2VtZW1vcy9tZW1vcy9wcm90by9nZW4vYXBpL3YxO2FwaXYxogIDTUFYqgIMTWVtb3MuQXBpLlYxygIMTWVtb3NcQXBpXFYx4gIYTWVtb3NcQXBpXFYxXEdQQk1ldGFkYXRh6gIOTWVtb3M6OkFwaTo6VjFiBnByb3RvMw", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]); + fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIu4CCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIacQoOU1NPQ3JlZGVudGlhbHMSFQoIaWRwX25hbWUYASABKAlCA+BBAhIRCgRjb2RlGAIgASgJQgPgQQISGQoMcmVkaXJlY3RfdXJpGAMgASgJQgPgQQISGgoNY29kZV92ZXJpZmllchgEIAEoCUID4EEBQg0KC2NyZWRlbnRpYWxzIoUBCg5TaWduSW5SZXNwb25zZRIgCgR1c2VyGAEgASgLMhIubWVtb3MuYXBpLnYxLlVzZXISFAoMYWNjZXNzX3Rva2VuGAIgASgJEjsKF2FjY2Vzc190b2tlbl9leHBpcmVzX2F0GAMgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcCIQCg5TaWduT3V0UmVxdWVzdCIVChNSZWZyZXNoVG9rZW5SZXF1ZXN0IlwKFFJlZnJlc2hUb2tlblJlc3BvbnNlEhQKDGFjY2Vzc190b2tlbhgBIAEoCRIuCgpleHBpcmVzX2F0GAIgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcDK/AwoLQXV0aFNlcnZpY2USdAoOR2V0Q3VycmVudFVzZXISIy5tZW1vcy5hcGkudjEuR2V0Q3VycmVudFVzZXJSZXF1ZXN0GiQubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVzcG9uc2UiF4LT5JMCERIPL2FwaS92MS9hdXRoL21lEmMKBlNpZ25JbhIbLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0GhwubWVtb3MuYXBpLnYxLlNpZ25JblJlc3BvbnNlIh6C0+STAhg6ASoiEy9hcGkvdjEvYXV0aC9zaWduaW4SXQoHU2lnbk91dBIcLm1lbW9zLmFwaS52MS5TaWduT3V0UmVxdWVzdBoWLmdvb2dsZS5wcm90b2J1Zi5FbXB0eSIcgtPkkwIWIhQvYXBpL3YxL2F1dGgvc2lnbm91dBJ2CgxSZWZyZXNoVG9rZW4SIS5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVxdWVzdBoiLm1lbW9zLmFwaS52MS5SZWZyZXNoVG9rZW5SZXNwb25zZSIfgtPkkwIZOgEqIhQvYXBpL3YxL2F1dGgvcmVmcmVzaEKoAQoQY29tLm1lbW9zLmFwaS52MUIQQXV0aFNlcnZpY2VQcm90b1ABWjBnaXRodWIuY29tL3VzZW1lbW9zL21lbW9zL3Byb3RvL2dlbi9hcGkvdjE7YXBpdjGiAgNNQViqAgxNZW1vcy5BcGkuVjHKAgxNZW1vc1xBcGlcVjHiAhhNZW1vc1xBcGlcVjFcR1BCTWV0YWRhdGHqAg5NZW1vczo6QXBpOjpWMWIGcHJvdG8z", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]); /** * @generated from message memos.api.v1.GetCurrentUserRequest @@ -120,11 +120,12 @@ export const SignInRequest_PasswordCredentialsSchema: GenMessage & { /** - * The ID of the SSO provider. + * The resource name of the SSO provider. + * Format: identity-providers/{uid} * - * @generated from field: int32 idp_id = 1; + * @generated from field: string idp_name = 1; */ - idpId: number; + idpName: string; /** * The authorization code from the SSO provider. diff --git a/web/src/utils/oauth.ts b/web/src/utils/oauth.ts index a3cb5c9a4..d3bb8f4e9 100644 --- a/web/src/utils/oauth.ts +++ b/web/src/utils/oauth.ts @@ -3,7 +3,7 @@ const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes interface OAuthState { state: string; - identityProviderId: number; + identityProviderName: string; timestamp: number; returnUrl?: string; codeVerifier?: string; // PKCE code_verifier @@ -42,7 +42,10 @@ function base64UrlEncode(buffer: Uint8Array): string { // Store OAuth state and PKCE parameters in sessionStorage // Returns state and optional codeChallenge for use in authorization URL // PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth -export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge?: string }> { +export async function storeOAuthState( + identityProviderName: string, + returnUrl?: string, +): Promise<{ state: string; codeChallenge?: string }> { const state = generateSecureState(); // Try to generate PKCE parameters if crypto.subtle is available (HTTPS/localhost) @@ -70,7 +73,7 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st const stateData: OAuthState = { state, - identityProviderId, + identityProviderName, timestamp: Date.now(), returnUrl, codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available) @@ -87,8 +90,8 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st } // Validate and retrieve OAuth state from storage (CSRF protection) -// Returns identityProviderId, returnUrl, and codeVerifier for PKCE -export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null { +// Returns identityProviderName, returnUrl, and codeVerifier for PKCE +export function validateOAuthState(stateParam: string): { identityProviderName: string; returnUrl?: string; codeVerifier?: string } | null { try { const storedData = sessionStorage.getItem(STATE_STORAGE_KEY); if (!storedData) { @@ -115,7 +118,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu // State is valid, clean up and return data sessionStorage.removeItem(STATE_STORAGE_KEY); return { - identityProviderId: stateData.identityProviderId, + identityProviderName: stateData.identityProviderName, returnUrl: stateData.returnUrl, codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier };