mirror of https://github.com/usememos/memos.git
feat: replace auto-increment ID with UID for identity provider resource names (#5687)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f0c4489468
commit
92d937b1aa
|
|
@ -64,8 +64,9 @@ message SignInRequest {
|
|||
|
||||
// Nested message for SSO authentication credentials.
|
||||
message SSOCredentials {
|
||||
// The ID of the SSO provider.
|
||||
int32 idp_id = 1 [(google.api.field_behavior) = REQUIRED];
|
||||
// The resource name of the SSO provider.
|
||||
// Format: identity-providers/{uid}
|
||||
string idp_name = 1 [(google.api.field_behavior) = REQUIRED];
|
||||
|
||||
// The authorization code from the SSO provider.
|
||||
string code = 2 [(google.api.field_behavior) = REQUIRED];
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/activity_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/attachment_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -440,8 +440,9 @@ func (x *SignInRequest_PasswordCredentials) GetPassword() string {
|
|||
// Nested message for SSO authentication credentials.
|
||||
type SignInRequest_SSOCredentials struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// The ID of the SSO provider.
|
||||
IdpId int32 `protobuf:"varint,1,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"`
|
||||
// The resource name of the SSO provider.
|
||||
// Format: identity-providers/{uid}
|
||||
IdpName string `protobuf:"bytes,1,opt,name=idp_name,json=idpName,proto3" json:"idp_name,omitempty"`
|
||||
// The authorization code from the SSO provider.
|
||||
Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"`
|
||||
// The redirect URI used in the SSO flow.
|
||||
|
|
@ -483,11 +484,11 @@ func (*SignInRequest_SSOCredentials) Descriptor() ([]byte, []int) {
|
|||
return file_api_v1_auth_service_proto_rawDescGZIP(), []int{2, 1}
|
||||
}
|
||||
|
||||
func (x *SignInRequest_SSOCredentials) GetIdpId() int32 {
|
||||
func (x *SignInRequest_SSOCredentials) GetIdpName() string {
|
||||
if x != nil {
|
||||
return x.IdpId
|
||||
return x.IdpName
|
||||
}
|
||||
return 0
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *SignInRequest_SSOCredentials) GetCode() string {
|
||||
|
|
@ -518,15 +519,15 @@ const file_api_v1_auth_service_proto_rawDesc = "" +
|
|||
"\x19api/v1/auth_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\x17\n" +
|
||||
"\x15GetCurrentUserRequest\"@\n" +
|
||||
"\x16GetCurrentUserResponse\x12&\n" +
|
||||
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xce\x03\n" +
|
||||
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xd2\x03\n" +
|
||||
"\rSignInRequest\x12d\n" +
|
||||
"\x14password_credentials\x18\x01 \x01(\v2/.memos.api.v1.SignInRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12U\n" +
|
||||
"\x0fsso_credentials\x18\x02 \x01(\v2*.memos.api.v1.SignInRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" +
|
||||
"\x13PasswordCredentials\x12\x1f\n" +
|
||||
"\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" +
|
||||
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" +
|
||||
"\x0eSSOCredentials\x12\x1a\n" +
|
||||
"\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" +
|
||||
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x9b\x01\n" +
|
||||
"\x0eSSOCredentials\x12\x1e\n" +
|
||||
"\bidp_name\x18\x01 \x01(\tB\x03\xe0A\x02R\aidpName\x12\x17\n" +
|
||||
"\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" +
|
||||
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" +
|
||||
"\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" +
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/auth_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/idp_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/instance_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/memo_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/shortcut_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc (unknown)
|
||||
// source: api/v1/user_service.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -2757,15 +2757,16 @@ components:
|
|||
description: Nested message for password-based authentication credentials.
|
||||
SignInRequest_SSOCredentials:
|
||||
required:
|
||||
- idpId
|
||||
- idpName
|
||||
- code
|
||||
- redirectUri
|
||||
type: object
|
||||
properties:
|
||||
idpId:
|
||||
type: integer
|
||||
description: The ID of the SSO provider.
|
||||
format: int32
|
||||
idpName:
|
||||
type: string
|
||||
description: |-
|
||||
The resource name of the SSO provider.
|
||||
Format: identity-providers/{uid}
|
||||
code:
|
||||
type: string
|
||||
description: The authorization code from the SSO provider.
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ type IdentityProvider struct {
|
|||
Type IdentityProvider_Type `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.IdentityProvider_Type" json:"type,omitempty"`
|
||||
IdentifierFilter string `protobuf:"bytes,4,opt,name=identifier_filter,json=identifierFilter,proto3" json:"identifier_filter,omitempty"`
|
||||
Config *IdentityProviderConfig `protobuf:"bytes,5,opt,name=config,proto3" json:"config,omitempty"`
|
||||
Uid string `protobuf:"bytes,6,opt,name=uid,proto3" json:"uid,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
|
@ -143,6 +144,13 @@ func (x *IdentityProvider) GetConfig() *IdentityProviderConfig {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (x *IdentityProvider) GetUid() string {
|
||||
if x != nil {
|
||||
return x.Uid
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type IdentityProviderConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Types that are valid to be assigned to Config:
|
||||
|
|
@ -373,13 +381,14 @@ var File_store_idp_proto protoreflect.FileDescriptor
|
|||
|
||||
const file_store_idp_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x0fstore/idp.proto\x12\vmemos.store\"\x82\x02\n" +
|
||||
"\x0fstore/idp.proto\x12\vmemos.store\"\x94\x02\n" +
|
||||
"\x10IdentityProvider\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\x05R\x02id\x12\x12\n" +
|
||||
"\x04name\x18\x02 \x01(\tR\x04name\x126\n" +
|
||||
"\x04type\x18\x03 \x01(\x0e2\".memos.store.IdentityProvider.TypeR\x04type\x12+\n" +
|
||||
"\x11identifier_filter\x18\x04 \x01(\tR\x10identifierFilter\x12;\n" +
|
||||
"\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\"(\n" +
|
||||
"\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\x12\x10\n" +
|
||||
"\x03uid\x18\x06 \x01(\tR\x03uid\"(\n" +
|
||||
"\x04Type\x12\x14\n" +
|
||||
"\x10TYPE_UNSPECIFIED\x10\x00\x12\n" +
|
||||
"\n" +
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ message IdentityProvider {
|
|||
Type type = 3;
|
||||
string identifier_filter = 4;
|
||||
IdentityProviderConfig config = 5;
|
||||
string uid = 6;
|
||||
}
|
||||
|
||||
message IdentityProviderConfig {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
|
@ -100,10 +99,9 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
|
|||
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
|
||||
}
|
||||
|
||||
// Use provided attachment_id or generate a new one
|
||||
attachmentUID := request.AttachmentId
|
||||
if attachmentUID == "" {
|
||||
attachmentUID = shortuuid.New()
|
||||
attachmentUID, err := ValidateAndGenerateUID(request.AttachmentId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create := &store.Attachment{
|
||||
|
|
|
|||
|
|
@ -90,8 +90,12 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
|
|||
existingUser = user
|
||||
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
|
||||
// Authentication Method 2: SSO (OAuth2) authentication
|
||||
idpUID, err := ExtractIdentityProviderUIDFromName(ssoCredentials.IdpName)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &ssoCredentials.IdpId,
|
||||
UID: &idpUID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,15 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
|
||||
idpUID, err := ValidateAndGenerateUID(request.IdentityProviderId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeIdp := convertIdentityProviderToStore(request.IdentityProvider)
|
||||
storeIdp.Uid = idpUID
|
||||
|
||||
identityProvider, err := s.Store.CreateIdentityProvider(ctx, storeIdp)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
|
||||
}
|
||||
|
|
@ -57,12 +65,12 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
|
|||
}
|
||||
|
||||
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
uid, err := ExtractIdentityProviderUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &id,
|
||||
UID: &uid,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
|
||||
|
|
@ -98,12 +106,22 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
|
||||
uid, err := ExtractIdentityProviderUIDFromName(request.IdentityProvider.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
|
||||
// Look up the IdP by UID to get the internal ID for update.
|
||||
existing, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
|
||||
}
|
||||
if existing == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
|
||||
update := &store.UpdateIdentityProviderV1{
|
||||
ID: id,
|
||||
ID: existing.Id,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
|
|
@ -138,13 +156,13 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
uid, err := ExtractIdentityProviderUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
|
||||
// Check if the identity provider exists before trying to delete it
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
|
||||
// Look up the IdP by UID to get the internal ID for deletion.
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
|
||||
}
|
||||
|
|
@ -152,7 +170,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
|
|||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
|
||||
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProvider.Id}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
|
|
@ -160,7 +178,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
|
|||
|
||||
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
|
||||
temp := &v1pb.IdentityProvider{
|
||||
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
|
||||
Name: fmt.Sprintf("%s%s", IdentityProviderNamePrefix, identityProvider.Uid),
|
||||
Title: identityProvider.Name,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
|
|
@ -190,10 +208,7 @@ func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider
|
|||
}
|
||||
|
||||
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
|
||||
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
|
||||
|
||||
temp := &storepb.IdentityProvider{
|
||||
Id: id,
|
||||
Name: identityProvider.Title,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
|
|
|
|||
|
|
@ -7,13 +7,11 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
"github.com/usememos/memos/plugin/webhook"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
|
|
@ -30,13 +28,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
|
|||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
// Use custom memo_id if provided, otherwise generate a new UUID
|
||||
memoUID := strings.TrimSpace(request.MemoId)
|
||||
if memoUID == "" {
|
||||
memoUID = shortuuid.New()
|
||||
} else if !base.UIDMatcher.MatchString(memoUID) {
|
||||
// Validate custom memo ID format
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
|
||||
memoUID, err := ValidateAndGenerateUID(request.MemoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create := &store.Memo{
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/usememos/memos/internal/base"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
|
|
@ -133,16 +137,12 @@ func ExtractInboxIDFromName(name string) (int32, error) {
|
|||
return id, nil
|
||||
}
|
||||
|
||||
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
|
||||
func ExtractIdentityProviderUIDFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return "", err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
return tokens[0], nil
|
||||
}
|
||||
|
||||
func ExtractActivityIDFromName(name string) (int32, error) {
|
||||
|
|
@ -156,3 +156,17 @@ func ExtractActivityIDFromName(name string) (int32, error) {
|
|||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ValidateAndGenerateUID validates a user-provided UID or generates a new one.
|
||||
// If provided is empty, a new shortuuid is generated.
|
||||
// If provided is non-empty, it is validated against base.UIDMatcher.
|
||||
func ValidateAndGenerateUID(provided string) (string, error) {
|
||||
uid := strings.TrimSpace(provided)
|
||||
if uid == "" {
|
||||
return shortuuid.New(), nil
|
||||
}
|
||||
if !base.UIDMatcher.MatchString(uid) {
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid ID format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ import (
|
|||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
placeholders := []string{"?", "?", "?", "?", "?"}
|
||||
fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
|
|
@ -35,8 +35,11 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
|
|||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "`uid` = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
|
||||
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `uid`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -50,6 +53,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
|
|||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.UID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ import (
|
|||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
fields := []string{"name", "type", "identifier_filter", "config"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
fields := []string{"uid", "name", "type", "identifier_filter", "config"}
|
||||
args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -25,10 +25,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
|
|||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
uid,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
|
|
@ -48,6 +52,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
|
|||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.UID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
|
|
@ -83,7 +88,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
|
|||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
RETURNING id, uid, name, type, identifier_filter, config
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
|
||||
|
|
@ -91,6 +96,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
|
|||
var typeString string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.UID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import (
|
|||
)
|
||||
|
||||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
|
||||
placeholders := []string{"?", "?", "?", "?"}
|
||||
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
placeholders := []string{"?", "?", "?", "?", "?"}
|
||||
fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"}
|
||||
args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
|
||||
|
||||
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
|
|
@ -28,10 +28,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
|
|||
if v := find.ID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.UID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("uid = $%d", len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
uid,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
|
|
@ -51,6 +55,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
|
|||
var typeString string
|
||||
if err := rows.Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.UID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
|
|
@ -86,12 +91,13 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
|
|||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
RETURNING id, uid, name, type, identifier_filter, config
|
||||
`
|
||||
var identityProvider store.IdentityProvider
|
||||
var typeString string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.UID,
|
||||
&identityProvider.Name,
|
||||
&typeString,
|
||||
&identityProvider.IdentifierFilter,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
type IdentityProvider struct {
|
||||
ID int32
|
||||
UID string
|
||||
Name string
|
||||
Type storepb.IdentityProvider_Type
|
||||
IdentifierFilter string
|
||||
|
|
@ -18,7 +19,8 @@ type IdentityProvider struct {
|
|||
}
|
||||
|
||||
type FindIdentityProvider struct {
|
||||
ID *int32
|
||||
ID *int32
|
||||
UID *string
|
||||
}
|
||||
|
||||
type UpdateIdentityProvider struct {
|
||||
|
|
@ -130,6 +132,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
|
|||
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
|
||||
identityProvider := &storepb.IdentityProvider{
|
||||
Id: raw.ID,
|
||||
Uid: raw.UID,
|
||||
Name: raw.Name,
|
||||
Type: raw.Type,
|
||||
IdentifierFilter: raw.IdentifierFilter,
|
||||
|
|
@ -145,6 +148,7 @@ func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityPro
|
|||
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
|
||||
raw := &IdentityProvider{
|
||||
ID: identityProvider.Id,
|
||||
UID: identityProvider.Uid,
|
||||
Name: identityProvider.Name,
|
||||
Type: identityProvider.Type,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
-- Add uid column to idp table
|
||||
ALTER TABLE `idp` ADD COLUMN `uid` VARCHAR(256) NOT NULL DEFAULT '';
|
||||
|
||||
-- Populate uid for existing rows using hex of id as a fallback
|
||||
UPDATE `idp` SET `uid` = LOWER(LPAD(HEX(`id`), 8, '0')) WHERE `uid` = '';
|
||||
|
||||
-- Create unique index on uid
|
||||
ALTER TABLE `idp` ADD UNIQUE INDEX `idx_idp_uid` (`uid`);
|
||||
|
|
@ -80,6 +80,7 @@ CREATE TABLE `activity` (
|
|||
-- idp
|
||||
CREATE TABLE `idp` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
`uid` VARCHAR(256) NOT NULL UNIQUE,
|
||||
`name` TEXT NOT NULL,
|
||||
`type` TEXT NOT NULL,
|
||||
`identifier_filter` VARCHAR(256) NOT NULL DEFAULT '',
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
-- Add uid column to idp table
|
||||
ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT '';
|
||||
|
||||
-- Populate uid for existing rows using hex of id as a fallback
|
||||
UPDATE idp SET uid = LPAD(TO_HEX(id), 8, '0') WHERE uid = '';
|
||||
|
||||
-- Create unique index on uid
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid);
|
||||
|
|
@ -80,6 +80,7 @@ CREATE TABLE activity (
|
|||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id SERIAL PRIMARY KEY,
|
||||
uid TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter TEXT NOT NULL DEFAULT '',
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
-- Add uid column to idp table
|
||||
ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT '';
|
||||
|
||||
-- Populate uid for existing rows using hex of id as a fallback
|
||||
UPDATE idp SET uid = printf('%08x', id) WHERE uid = '';
|
||||
|
||||
-- Create unique index on uid
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid);
|
||||
|
|
@ -81,6 +81,7 @@ CREATE TABLE activity (
|
|||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
uid TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter TEXT NOT NULL DEFAULT '',
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ func TestIdentityProviderStore(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Uid: "test-github-oauth",
|
||||
Name: "GitHub OAuth",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: "",
|
||||
|
|
@ -37,6 +38,7 @@ func TestIdentityProviderStore(t *testing.T) {
|
|||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test-github-oauth", createdIDP.Uid)
|
||||
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &createdIDP.Id,
|
||||
})
|
||||
|
|
@ -66,7 +68,7 @@ func TestIdentityProviderGetByID(t *testing.T) {
|
|||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create IDP
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by ID
|
||||
|
|
@ -76,6 +78,13 @@ func TestIdentityProviderGetByID(t *testing.T) {
|
|||
require.Equal(t, idp.Id, found.Id)
|
||||
require.Equal(t, idp.Name, found.Name)
|
||||
|
||||
// Get by UID
|
||||
foundByUID, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &idp.Uid})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, foundByUID)
|
||||
require.Equal(t, idp.Id, foundByUID.Id)
|
||||
require.Equal(t, idp.Uid, foundByUID.Uid)
|
||||
|
||||
// Get by non-existent ID
|
||||
nonExistentID := int32(99999)
|
||||
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
|
||||
|
|
@ -91,11 +100,11 @@ func TestIdentityProviderListMultiple(t *testing.T) {
|
|||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
|
||||
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth"))
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth", "gitlab-oauth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all
|
||||
|
|
@ -112,9 +121,9 @@ func TestIdentityProviderListByID(t *testing.T) {
|
|||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth"))
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
|
||||
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// List by specific ID
|
||||
|
|
@ -131,7 +140,7 @@ func TestIdentityProviderUpdateName(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name"))
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name", "original-name"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Original Name", idp.Name)
|
||||
|
||||
|
|
@ -158,7 +167,7 @@ func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", idp.IdentifierFilter)
|
||||
|
||||
|
|
@ -185,7 +194,7 @@ func TestIdentityProviderUpdateConfig(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update config
|
||||
|
|
@ -229,7 +238,7 @@ func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original"))
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original", "original"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update multiple fields at once
|
||||
|
|
@ -253,7 +262,7 @@ func TestIdentityProviderDelete(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
|
||||
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete
|
||||
|
|
@ -274,9 +283,9 @@ func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
|
|||
ts := NewTestingStore(ctx, t)
|
||||
|
||||
// Create multiple IDPs
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1"))
|
||||
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1", "idp-1"))
|
||||
require.NoError(t, err)
|
||||
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2"))
|
||||
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2", "idp-2"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete first one
|
||||
|
|
@ -304,6 +313,7 @@ func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
|
|||
|
||||
// Create IDP with multiple scopes
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Uid: "multi-scope-oauth",
|
||||
Name: "Multi-Scope OAuth",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
|
|
@ -343,6 +353,7 @@ func TestIdentityProviderFieldMapping(t *testing.T) {
|
|||
|
||||
// Create IDP with custom field mapping
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Uid: "custom-field-mapping",
|
||||
Name: "Custom Field Mapping",
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
Config: &storepb.IdentityProviderConfig{
|
||||
|
|
@ -382,17 +393,19 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
|
|||
|
||||
testCases := []struct {
|
||||
name string
|
||||
uid string
|
||||
filter string
|
||||
}{
|
||||
{"Domain filter", "@company\\.com$"},
|
||||
{"Prefix filter", "^admin_"},
|
||||
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
|
||||
{"Empty filter", ""},
|
||||
{"Domain filter", "domain-filter", "@company\\.com$"},
|
||||
{"Prefix filter", "prefix-filter", "^admin_"},
|
||||
{"Complex regex", "complex-regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
|
||||
{"Empty filter", "empty-filter", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
|
||||
Uid: tc.uid,
|
||||
Name: tc.name,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: tc.filter,
|
||||
|
|
@ -428,8 +441,9 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
|
|||
}
|
||||
|
||||
// Helper function to create a test OAuth2 IDP.
|
||||
func createTestOAuth2IDP(name string) *storepb.IdentityProvider {
|
||||
func createTestOAuth2IDP(name, uid string) *storepb.IdentityProvider {
|
||||
return &storepb.IdentityProvider{
|
||||
Uid: uid,
|
||||
Name: name,
|
||||
Type: storepb.IdentityProvider_OAUTH2,
|
||||
IdentifierFilter: "",
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
const identityProviderTypes = [...new Set(templateList.map((t) => t.type))];
|
||||
const [basicInfo, setBasicInfo] = useState({
|
||||
title: "",
|
||||
identifier: "",
|
||||
identifierFilter: "",
|
||||
});
|
||||
const [type, setType] = useState<IdentityProvider_Type>(IdentityProvider_Type.OAUTH2);
|
||||
|
|
@ -161,6 +162,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
// Reset to default state when dialog is closed
|
||||
setBasicInfo({
|
||||
title: "",
|
||||
identifier: "",
|
||||
identifierFilter: "",
|
||||
});
|
||||
setType(IdentityProvider_Type.OAUTH2);
|
||||
|
|
@ -189,6 +191,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
if (open && identityProvider) {
|
||||
setBasicInfo({
|
||||
title: identityProvider.title,
|
||||
identifier: "",
|
||||
identifierFilter: identityProvider.identifierFilter,
|
||||
});
|
||||
setType(identityProvider.type);
|
||||
|
|
@ -210,6 +213,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
if (template) {
|
||||
setBasicInfo({
|
||||
title: template.title,
|
||||
identifier: template.title.toLowerCase().replace(/[^a-z0-9]+/g, "-"),
|
||||
identifierFilter: template.identifierFilter,
|
||||
});
|
||||
setType(template.type);
|
||||
|
|
@ -229,6 +233,9 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
if (basicInfo.title === "") {
|
||||
return false;
|
||||
}
|
||||
if (isCreating && basicInfo.identifier === "") {
|
||||
return false;
|
||||
}
|
||||
if (type === IdentityProvider_Type.OAUTH2) {
|
||||
if (
|
||||
oauth2Config.clientId === "" ||
|
||||
|
|
@ -254,8 +261,10 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
try {
|
||||
if (isCreating) {
|
||||
await identityProviderServiceClient.createIdentityProvider({
|
||||
identityProviderId: basicInfo.identifier,
|
||||
identityProvider: create(IdentityProviderSchema, {
|
||||
...basicInfo,
|
||||
title: basicInfo.title,
|
||||
identifierFilter: basicInfo.identifierFilter,
|
||||
type: type,
|
||||
config: create(IdentityProviderConfigSchema, {
|
||||
config: {
|
||||
|
|
@ -343,6 +352,32 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
|
|||
<Separator className="my-2" />
|
||||
</>
|
||||
)}
|
||||
{isCreating && (
|
||||
<>
|
||||
<p className="mb-1 text-sm font-medium">
|
||||
ID
|
||||
<span className="text-destructive">*</span>
|
||||
</p>
|
||||
<Input
|
||||
className="mb-2 w-full font-mono"
|
||||
placeholder="e.g. github, okta-corp"
|
||||
maxLength={32}
|
||||
value={basicInfo.identifier}
|
||||
onChange={(e) =>
|
||||
setBasicInfo({
|
||||
...basicInfo,
|
||||
identifier: e.target.value
|
||||
.toLowerCase()
|
||||
.replace(/[^a-z0-9-]/g, "-")
|
||||
.replace(/--+/g, "-"),
|
||||
})
|
||||
}
|
||||
/>
|
||||
<p className="mb-2 text-xs text-muted-foreground">
|
||||
A unique identifier for this provider. Lowercase letters, numbers, and hyphens only.
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
<p className="mb-1 text-sm font-medium">
|
||||
{t("common.name")}
|
||||
<span className="text-destructive">*</span>
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ export const extractMemoIdFromName = (name: string) => {
|
|||
return name.split(memoNamePrefix).pop() || "";
|
||||
};
|
||||
|
||||
export const extractIdentityProviderIdFromName = (name: string) => {
|
||||
return parseInt(name.split(identityProviderNamePrefix).pop() || "", 10);
|
||||
export const extractIdentityProviderUidFromName = (name: string) => {
|
||||
return name.split(identityProviderNamePrefix).pop() || "";
|
||||
};
|
||||
|
||||
// Helper function to convert InstanceSetting_Key enum value to string name
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ const AuthCallback = () => {
|
|||
return;
|
||||
}
|
||||
|
||||
const { identityProviderId, returnUrl, codeVerifier } = validatedState;
|
||||
const { identityProviderName, returnUrl, codeVerifier } = validatedState;
|
||||
const redirectUri = absolutifyLink("/auth/callback");
|
||||
|
||||
(async () => {
|
||||
|
|
@ -81,7 +81,7 @@ const AuthCallback = () => {
|
|||
credentials: {
|
||||
case: "ssoCredentials",
|
||||
value: {
|
||||
idpId: identityProviderId,
|
||||
idpName: identityProviderName,
|
||||
code,
|
||||
redirectUri,
|
||||
codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import { Button } from "@/components/ui/button";
|
|||
import { Separator } from "@/components/ui/separator";
|
||||
import { identityProviderServiceClient } from "@/connect";
|
||||
import { useInstance } from "@/contexts/InstanceContext";
|
||||
import { extractIdentityProviderIdFromName } from "@/helpers/resource-names";
|
||||
import { absolutifyLink } from "@/helpers/utils";
|
||||
import useCurrentUser from "@/hooks/useCurrentUser";
|
||||
import { handleError } from "@/lib/error";
|
||||
|
|
@ -50,8 +49,7 @@ const SignIn = () => {
|
|||
try {
|
||||
// Generate and store secure state parameter with CSRF protection
|
||||
// Also generate PKCE parameters (code_challenge) for enhanced security if available
|
||||
const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name);
|
||||
const { state, codeChallenge } = await storeOAuthState(identityProviderId);
|
||||
const { state, codeChallenge } = await storeOAuthState(identityProvider.name);
|
||||
|
||||
// Build OAuth authorization URL with secure state
|
||||
// Include PKCE if available (requires HTTPS/localhost for crypto.subtle)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import type { Message } from "@bufbuild/protobuf";
|
|||
* Describes the file api/v1/auth_service.proto.
|
||||
*/
|
||||
export const file_api_v1_auth_service: GenFile = /*@__PURE__*/
|
||||
fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIuwCCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIabwoOU1NPQ3JlZGVudGlhbHMSEwoGaWRwX2lkGAEgASgFQgPgQQISEQoEY29kZRgCIAEoCUID4EECEhkKDHJlZGlyZWN0X3VyaRgDIAEoCUID4EECEhoKDWNvZGVfdmVyaWZpZXIYBCABKAlCA+BBAUINCgtjcmVkZW50aWFscyKFAQoOU2lnbkluUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyEhQKDGFjY2Vzc190b2tlbhgCIAEoCRI7ChdhY2Nlc3NfdG9rZW5fZXhwaXJlc19hdBgDIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAiEAoOU2lnbk91dFJlcXVlc3QiFQoTUmVmcmVzaFRva2VuUmVxdWVzdCJcChRSZWZyZXNoVG9rZW5SZXNwb25zZRIUCgxhY2Nlc3NfdG9rZW4YASABKAkSLgoKZXhwaXJlc19hdBgCIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAyvwMKC0F1dGhTZXJ2aWNlEnQKDkdldEN1cnJlbnRVc2VyEiMubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVxdWVzdBokLm1lbW9zLmFwaS52MS5HZXRDdXJyZW50VXNlclJlc3BvbnNlIheC0+STAhESDy9hcGkvdjEvYXV0aC9tZRJjCgZTaWduSW4SGy5tZW1vcy5hcGkudjEuU2lnbkluUmVxdWVzdBocLm1lbW9zLmFwaS52MS5TaWduSW5SZXNwb25zZSIegtPkkwIYOgEqIhMvYXBpL3YxL2F1dGgvc2lnbmluEl0KB1NpZ25PdXQSHC5tZW1vcy5hcGkudjEuU2lnbk91dFJlcXVlc3QaFi5nb29nbGUucHJvdG9idWYuRW1wdHkiHILT5JMCFiIUL2FwaS92MS9hdXRoL3NpZ25vdXQSdgoMUmVmcmVzaFRva2VuEiEubWVtb3MuYXBpLnYxLlJlZnJlc2hUb2tlblJlcXVlc3QaIi5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVzcG9uc2UiH4LT5JMCGToBKiIUL2FwaS92MS9hdXRoL3JlZnJlc2hCqAEKEGNvbS5tZW1vcy5hcGkudjFCEEF1dGhTZXJ2aWNlUHJvdG9QAVowZ2l0aHViLmNvbS91c2VtZW1vcy9tZW1vcy9wcm90by9nZW4vYXBpL3YxO2FwaXYxogIDTUFYqgIMTWVtb3MuQXBpLlYxygIMTWVtb3NcQXBpXFYx4gIYTWVtb3NcQXBpXFYxXEdQQk1ldGFkYXRh6gIOTWVtb3M6OkFwaTo6VjFiBnByb3RvMw", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]);
|
||||
fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIu4CCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIacQoOU1NPQ3JlZGVudGlhbHMSFQoIaWRwX25hbWUYASABKAlCA+BBAhIRCgRjb2RlGAIgASgJQgPgQQISGQoMcmVkaXJlY3RfdXJpGAMgASgJQgPgQQISGgoNY29kZV92ZXJpZmllchgEIAEoCUID4EEBQg0KC2NyZWRlbnRpYWxzIoUBCg5TaWduSW5SZXNwb25zZRIgCgR1c2VyGAEgASgLMhIubWVtb3MuYXBpLnYxLlVzZXISFAoMYWNjZXNzX3Rva2VuGAIgASgJEjsKF2FjY2Vzc190b2tlbl9leHBpcmVzX2F0GAMgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcCIQCg5TaWduT3V0UmVxdWVzdCIVChNSZWZyZXNoVG9rZW5SZXF1ZXN0IlwKFFJlZnJlc2hUb2tlblJlc3BvbnNlEhQKDGFjY2Vzc190b2tlbhgBIAEoCRIuCgpleHBpcmVzX2F0GAIgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcDK/AwoLQXV0aFNlcnZpY2USdAoOR2V0Q3VycmVudFVzZXISIy5tZW1vcy5hcGkudjEuR2V0Q3VycmVudFVzZXJSZXF1ZXN0GiQubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVzcG9uc2UiF4LT5JMCERIPL2FwaS92MS9hdXRoL21lEmMKBlNpZ25JbhIbLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0GhwubWVtb3MuYXBpLnYxLlNpZ25JblJlc3BvbnNlIh6C0+STAhg6ASoiEy9hcGkvdjEvYXV0aC9zaWduaW4SXQoHU2lnbk91dBIcLm1lbW9zLmFwaS52MS5TaWduT3V0UmVxdWVzdBoWLmdvb2dsZS5wcm90b2J1Zi5FbXB0eSIcgtPkkwIWIhQvYXBpL3YxL2F1dGgvc2lnbm91dBJ2CgxSZWZyZXNoVG9rZW4SIS5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVxdWVzdBoiLm1lbW9zLmFwaS52MS5SZWZyZXNoVG9rZW5SZXNwb25zZSIfgtPkkwIZOgEqIhQvYXBpL3YxL2F1dGgvcmVmcmVzaEKoAQoQY29tLm1lbW9zLmFwaS52MUIQQXV0aFNlcnZpY2VQcm90b1ABWjBnaXRodWIuY29tL3VzZW1lbW9zL21lbW9zL3Byb3RvL2dlbi9hcGkvdjE7YXBpdjGiAgNNQViqAgxNZW1vcy5BcGkuVjHKAgxNZW1vc1xBcGlcVjHiAhhNZW1vc1xBcGlcVjFcR1BCTWV0YWRhdGHqAg5NZW1vczo6QXBpOjpWMWIGcHJvdG8z", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]);
|
||||
|
||||
/**
|
||||
* @generated from message memos.api.v1.GetCurrentUserRequest
|
||||
|
|
@ -120,11 +120,12 @@ export const SignInRequest_PasswordCredentialsSchema: GenMessage<SignInRequest_P
|
|||
*/
|
||||
export type SignInRequest_SSOCredentials = Message<"memos.api.v1.SignInRequest.SSOCredentials"> & {
|
||||
/**
|
||||
* The ID of the SSO provider.
|
||||
* The resource name of the SSO provider.
|
||||
* Format: identity-providers/{uid}
|
||||
*
|
||||
* @generated from field: int32 idp_id = 1;
|
||||
* @generated from field: string idp_name = 1;
|
||||
*/
|
||||
idpId: number;
|
||||
idpName: string;
|
||||
|
||||
/**
|
||||
* The authorization code from the SSO provider.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes
|
|||
|
||||
interface OAuthState {
|
||||
state: string;
|
||||
identityProviderId: number;
|
||||
identityProviderName: string;
|
||||
timestamp: number;
|
||||
returnUrl?: string;
|
||||
codeVerifier?: string; // PKCE code_verifier
|
||||
|
|
@ -42,7 +42,10 @@ function base64UrlEncode(buffer: Uint8Array): string {
|
|||
// Store OAuth state and PKCE parameters in sessionStorage
|
||||
// Returns state and optional codeChallenge for use in authorization URL
|
||||
// PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth
|
||||
export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge?: string }> {
|
||||
export async function storeOAuthState(
|
||||
identityProviderName: string,
|
||||
returnUrl?: string,
|
||||
): Promise<{ state: string; codeChallenge?: string }> {
|
||||
const state = generateSecureState();
|
||||
|
||||
// Try to generate PKCE parameters if crypto.subtle is available (HTTPS/localhost)
|
||||
|
|
@ -70,7 +73,7 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st
|
|||
|
||||
const stateData: OAuthState = {
|
||||
state,
|
||||
identityProviderId,
|
||||
identityProviderName,
|
||||
timestamp: Date.now(),
|
||||
returnUrl,
|
||||
codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available)
|
||||
|
|
@ -87,8 +90,8 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st
|
|||
}
|
||||
|
||||
// Validate and retrieve OAuth state from storage (CSRF protection)
|
||||
// Returns identityProviderId, returnUrl, and codeVerifier for PKCE
|
||||
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null {
|
||||
// Returns identityProviderName, returnUrl, and codeVerifier for PKCE
|
||||
export function validateOAuthState(stateParam: string): { identityProviderName: string; returnUrl?: string; codeVerifier?: string } | null {
|
||||
try {
|
||||
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
|
||||
if (!storedData) {
|
||||
|
|
@ -115,7 +118,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu
|
|||
// State is valid, clean up and return data
|
||||
sessionStorage.removeItem(STATE_STORAGE_KEY);
|
||||
return {
|
||||
identityProviderId: stateData.identityProviderId,
|
||||
identityProviderName: stateData.identityProviderName,
|
||||
returnUrl: stateData.returnUrl,
|
||||
codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue