feat: replace auto-increment ID with UID for identity provider resource names (#5687)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
memoclaw 2026-03-05 21:01:22 +08:00 committed by GitHub
parent f0c4489468
commit 92d937b1aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 245 additions and 109 deletions

View File

@ -64,8 +64,9 @@ message SignInRequest {
// Nested message for SSO authentication credentials.
message SSOCredentials {
// The ID of the SSO provider.
int32 idp_id = 1 [(google.api.field_behavior) = REQUIRED];
// The resource name of the SSO provider.
// Format: identity-providers/{uid}
string idp_name = 1 [(google.api.field_behavior) = REQUIRED];
// The authorization code from the SSO provider.
string code = 2 [(google.api.field_behavior) = REQUIRED];

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/activity_service.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/attachment_service.proto

View File

@ -440,8 +440,9 @@ func (x *SignInRequest_PasswordCredentials) GetPassword() string {
// Nested message for SSO authentication credentials.
type SignInRequest_SSOCredentials struct {
state protoimpl.MessageState `protogen:"open.v1"`
// The ID of the SSO provider.
IdpId int32 `protobuf:"varint,1,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"`
// The resource name of the SSO provider.
// Format: identity-providers/{uid}
IdpName string `protobuf:"bytes,1,opt,name=idp_name,json=idpName,proto3" json:"idp_name,omitempty"`
// The authorization code from the SSO provider.
Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"`
// The redirect URI used in the SSO flow.
@ -483,11 +484,11 @@ func (*SignInRequest_SSOCredentials) Descriptor() ([]byte, []int) {
return file_api_v1_auth_service_proto_rawDescGZIP(), []int{2, 1}
}
func (x *SignInRequest_SSOCredentials) GetIdpId() int32 {
func (x *SignInRequest_SSOCredentials) GetIdpName() string {
if x != nil {
return x.IdpId
return x.IdpName
}
return 0
return ""
}
func (x *SignInRequest_SSOCredentials) GetCode() string {
@ -518,15 +519,15 @@ const file_api_v1_auth_service_proto_rawDesc = "" +
"\x19api/v1/auth_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\x17\n" +
"\x15GetCurrentUserRequest\"@\n" +
"\x16GetCurrentUserResponse\x12&\n" +
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xce\x03\n" +
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xd2\x03\n" +
"\rSignInRequest\x12d\n" +
"\x14password_credentials\x18\x01 \x01(\v2/.memos.api.v1.SignInRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12U\n" +
"\x0fsso_credentials\x18\x02 \x01(\v2*.memos.api.v1.SignInRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" +
"\x13PasswordCredentials\x12\x1f\n" +
"\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" +
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" +
"\x0eSSOCredentials\x12\x1a\n" +
"\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" +
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x9b\x01\n" +
"\x0eSSOCredentials\x12\x1e\n" +
"\bidp_name\x18\x01 \x01(\tB\x03\xe0A\x02R\aidpName\x12\x17\n" +
"\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" +
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" +
"\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" +

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/auth_service.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/idp_service.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/instance_service.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/memo_service.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/shortcut_service.proto

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/user_service.proto

View File

@ -2757,15 +2757,16 @@ components:
description: Nested message for password-based authentication credentials.
SignInRequest_SSOCredentials:
required:
- idpId
- idpName
- code
- redirectUri
type: object
properties:
idpId:
type: integer
description: The ID of the SSO provider.
format: int32
idpName:
type: string
description: |-
The resource name of the SSO provider.
Format: identity-providers/{uid}
code:
type: string
description: The authorization code from the SSO provider.

View File

@ -74,6 +74,7 @@ type IdentityProvider struct {
Type IdentityProvider_Type `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.IdentityProvider_Type" json:"type,omitempty"`
IdentifierFilter string `protobuf:"bytes,4,opt,name=identifier_filter,json=identifierFilter,proto3" json:"identifier_filter,omitempty"`
Config *IdentityProviderConfig `protobuf:"bytes,5,opt,name=config,proto3" json:"config,omitempty"`
Uid string `protobuf:"bytes,6,opt,name=uid,proto3" json:"uid,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -143,6 +144,13 @@ func (x *IdentityProvider) GetConfig() *IdentityProviderConfig {
return nil
}
func (x *IdentityProvider) GetUid() string {
if x != nil {
return x.Uid
}
return ""
}
type IdentityProviderConfig struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Config:
@ -373,13 +381,14 @@ var File_store_idp_proto protoreflect.FileDescriptor
const file_store_idp_proto_rawDesc = "" +
"\n" +
"\x0fstore/idp.proto\x12\vmemos.store\"\x82\x02\n" +
"\x0fstore/idp.proto\x12\vmemos.store\"\x94\x02\n" +
"\x10IdentityProvider\x12\x0e\n" +
"\x02id\x18\x01 \x01(\x05R\x02id\x12\x12\n" +
"\x04name\x18\x02 \x01(\tR\x04name\x126\n" +
"\x04type\x18\x03 \x01(\x0e2\".memos.store.IdentityProvider.TypeR\x04type\x12+\n" +
"\x11identifier_filter\x18\x04 \x01(\tR\x10identifierFilter\x12;\n" +
"\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\"(\n" +
"\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\x12\x10\n" +
"\x03uid\x18\x06 \x01(\tR\x03uid\"(\n" +
"\x04Type\x12\x14\n" +
"\x10TYPE_UNSPECIFIED\x10\x00\x12\n" +
"\n" +

View File

@ -15,6 +15,7 @@ message IdentityProvider {
Type type = 3;
string identifier_filter = 4;
IdentityProviderConfig config = 5;
string uid = 6;
}
message IdentityProviderConfig {

View File

@ -16,7 +16,6 @@ import (
"time"
"github.com/disintegration/imaging"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -100,10 +99,9 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
}
// Use provided attachment_id or generate a new one
attachmentUID := request.AttachmentId
if attachmentUID == "" {
attachmentUID = shortuuid.New()
attachmentUID, err := ValidateAndGenerateUID(request.AttachmentId)
if err != nil {
return nil, err
}
create := &store.Attachment{

View File

@ -90,8 +90,12 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
existingUser = user
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
// Authentication Method 2: SSO (OAuth2) authentication
idpUID, err := ExtractIdentityProviderUIDFromName(ssoCredentials.IdpName)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &ssoCredentials.IdpId,
UID: &idpUID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)

View File

@ -25,7 +25,15 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
idpUID, err := ValidateAndGenerateUID(request.IdentityProviderId)
if err != nil {
return nil, err
}
storeIdp := convertIdentityProviderToStore(request.IdentityProvider)
storeIdp.Uid = idpUID
identityProvider, err := s.Store.CreateIdentityProvider(ctx, storeIdp)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
@ -57,12 +65,12 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
}
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
uid, err := ExtractIdentityProviderUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &id,
UID: &uid,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
@ -98,12 +106,22 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
uid, err := ExtractIdentityProviderUIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
// Look up the IdP by UID to get the internal ID for update.
existing, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if existing == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
update := &store.UpdateIdentityProviderV1{
ID: id,
ID: existing.Id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
@ -138,13 +156,13 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
id, err := ExtractIdentityProviderIDFromName(request.Name)
uid, err := ExtractIdentityProviderUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
// Check if the identity provider exists before trying to delete it
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id})
// Look up the IdP by UID to get the internal ID for deletion.
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
}
@ -152,7 +170,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProvider.Id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
return &emptypb.Empty{}, nil
@ -160,7 +178,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
temp := &v1pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
Name: fmt.Sprintf("%s%s", IdentityProviderNamePrefix, identityProvider.Uid),
Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
@ -190,10 +208,7 @@ func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider
}
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),

View File

@ -7,13 +7,11 @@ import (
"strings"
"time"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
@ -30,13 +28,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
// Use custom memo_id if provided, otherwise generate a new UUID
memoUID := strings.TrimSpace(request.MemoId)
if memoUID == "" {
memoUID = shortuuid.New()
} else if !base.UIDMatcher.MatchString(memoUID) {
// Validate custom memo ID format
return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
memoUID, err := ValidateAndGenerateUID(request.MemoId)
if err != nil {
return nil, err
}
create := &store.Memo{

View File

@ -4,8 +4,12 @@ import (
"fmt"
"strings"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/internal/util"
)
@ -133,16 +137,12 @@ func ExtractInboxIDFromName(name string) (int32, error) {
return id, nil
}
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
func ExtractIdentityProviderUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil {
return 0, err
return "", err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
}
return id, nil
return tokens[0], nil
}
func ExtractActivityIDFromName(name string) (int32, error) {
@ -156,3 +156,17 @@ func ExtractActivityIDFromName(name string) (int32, error) {
}
return id, nil
}
// ValidateAndGenerateUID validates a user-provided UID or generates a new one.
// If provided is empty, a new shortuuid is generated.
// If provided is non-empty, it is validated against base.UIDMatcher.
func ValidateAndGenerateUID(provided string) (string, error) {
uid := strings.TrimSpace(provided)
if uid == "" {
return shortuuid.New(), nil
}
if !base.UIDMatcher.MatchString(uid) {
return "", status.Errorf(codes.InvalidArgument, "invalid ID format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
}
return uid, nil
}

View File

@ -11,9 +11,9 @@ import (
)
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
placeholders := []string{"?", "?", "?", "?"}
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
placeholders := []string{"?", "?", "?", "?", "?"}
fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
result, err := d.db.ExecContext(ctx, stmt, args...)
@ -35,8 +35,11 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
if v := find.ID; v != nil {
where, args = append(where, "`id` = ?"), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "`uid` = ?"), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `uid`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
args...,
)
if err != nil {
@ -50,6 +53,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var typeString string
if err := rows.Scan(
&identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,

View File

@ -9,8 +9,8 @@ import (
)
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
fields := []string{"name", "type", "identifier_filter", "config"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
fields := []string{"uid", "name", "type", "identifier_filter", "config"}
args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
return nil, err
@ -25,10 +25,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
if v := find.ID; v != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
uid,
name,
type,
identifier_filter,
@ -48,6 +52,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var typeString string
if err := rows.Scan(
&identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,
@ -83,7 +88,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, name, type, identifier_filter, config
RETURNING id, uid, name, type, identifier_filter, config
`
args = append(args, update.ID)
@ -91,6 +96,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,

View File

@ -10,9 +10,9 @@ import (
)
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
placeholders := []string{"?", "?", "?", "?"}
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
placeholders := []string{"?", "?", "?", "?", "?"}
fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
@ -28,10 +28,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
if v := find.ID; v != nil {
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
}
if v := find.UID; v != nil {
where, args = append(where, fmt.Sprintf("uid = $%d", len(args)+1)), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
uid,
name,
type,
identifier_filter,
@ -51,6 +55,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var typeString string
if err := rows.Scan(
&identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,
@ -86,12 +91,13 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, name, type, identifier_filter, config
RETURNING id, uid, name, type, identifier_filter, config
`
var identityProvider store.IdentityProvider
var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name,
&typeString,
&identityProvider.IdentifierFilter,

View File

@ -11,6 +11,7 @@ import (
type IdentityProvider struct {
ID int32
UID string
Name string
Type storepb.IdentityProvider_Type
IdentifierFilter string
@ -18,7 +19,8 @@ type IdentityProvider struct {
}
type FindIdentityProvider struct {
ID *int32
ID *int32
UID *string
}
type UpdateIdentityProvider struct {
@ -130,6 +132,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
identityProvider := &storepb.IdentityProvider{
Id: raw.ID,
Uid: raw.UID,
Name: raw.Name,
Type: raw.Type,
IdentifierFilter: raw.IdentifierFilter,
@ -145,6 +148,7 @@ func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityPro
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
raw := &IdentityProvider{
ID: identityProvider.Id,
UID: identityProvider.Uid,
Name: identityProvider.Name,
Type: identityProvider.Type,
IdentifierFilter: identityProvider.IdentifierFilter,

View File

@ -0,0 +1,8 @@
-- Add uid column to idp table
ALTER TABLE `idp` ADD COLUMN `uid` VARCHAR(256) NOT NULL DEFAULT '';
-- Populate uid for existing rows using hex of id as a fallback
UPDATE `idp` SET `uid` = LOWER(LPAD(HEX(`id`), 8, '0')) WHERE `uid` = '';
-- Create unique index on uid
ALTER TABLE `idp` ADD UNIQUE INDEX `idx_idp_uid` (`uid`);

View File

@ -80,6 +80,7 @@ CREATE TABLE `activity` (
-- idp
CREATE TABLE `idp` (
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
`uid` VARCHAR(256) NOT NULL UNIQUE,
`name` TEXT NOT NULL,
`type` TEXT NOT NULL,
`identifier_filter` VARCHAR(256) NOT NULL DEFAULT '',

View File

@ -0,0 +1,8 @@
-- Add uid column to idp table
ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT '';
-- Populate uid for existing rows using hex of id as a fallback
UPDATE idp SET uid = LPAD(TO_HEX(id), 8, '0') WHERE uid = '';
-- Create unique index on uid
CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid);

View File

@ -80,6 +80,7 @@ CREATE TABLE activity (
-- idp
CREATE TABLE idp (
id SERIAL PRIMARY KEY,
uid TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
type TEXT NOT NULL,
identifier_filter TEXT NOT NULL DEFAULT '',

View File

@ -0,0 +1,8 @@
-- Add uid column to idp table
ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT '';
-- Populate uid for existing rows using hex of id as a fallback
UPDATE idp SET uid = printf('%08x', id) WHERE uid = '';
-- Create unique index on uid
CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid);

View File

@ -81,6 +81,7 @@ CREATE TABLE activity (
-- idp
CREATE TABLE idp (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uid TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
type TEXT NOT NULL,
identifier_filter TEXT NOT NULL DEFAULT '',

View File

@ -15,6 +15,7 @@ func TestIdentityProviderStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: "test-github-oauth",
Name: "GitHub OAuth",
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "",
@ -37,6 +38,7 @@ func TestIdentityProviderStore(t *testing.T) {
},
})
require.NoError(t, err)
require.Equal(t, "test-github-oauth", createdIDP.Uid)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &createdIDP.Id,
})
@ -66,7 +68,7 @@ func TestIdentityProviderGetByID(t *testing.T) {
ts := NewTestingStore(ctx, t)
// Create IDP
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err)
// Get by ID
@ -76,6 +78,13 @@ func TestIdentityProviderGetByID(t *testing.T) {
require.Equal(t, idp.Id, found.Id)
require.Equal(t, idp.Name, found.Name)
// Get by UID
foundByUID, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &idp.Uid})
require.NoError(t, err)
require.NotNil(t, foundByUID)
require.Equal(t, idp.Id, foundByUID.Id)
require.Equal(t, idp.Uid, foundByUID.Uid)
// Get by non-existent ID
nonExistentID := int32(99999)
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
@ -91,11 +100,11 @@ func TestIdentityProviderListMultiple(t *testing.T) {
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth"))
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth", "gitlab-oauth"))
require.NoError(t, err)
// List all
@ -112,9 +121,9 @@ func TestIdentityProviderListByID(t *testing.T) {
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth"))
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth"))
require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth"))
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth"))
require.NoError(t, err)
// List by specific ID
@ -131,7 +140,7 @@ func TestIdentityProviderUpdateName(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name"))
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name", "original-name"))
require.NoError(t, err)
require.Equal(t, "Original Name", idp.Name)
@ -158,7 +167,7 @@ func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err)
require.Equal(t, "", idp.IdentifierFilter)
@ -185,7 +194,7 @@ func TestIdentityProviderUpdateConfig(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err)
// Update config
@ -229,7 +238,7 @@ func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original"))
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original", "original"))
require.NoError(t, err)
// Update multiple fields at once
@ -253,7 +262,7 @@ func TestIdentityProviderDelete(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP"))
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err)
// Delete
@ -274,9 +283,9 @@ func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
ts := NewTestingStore(ctx, t)
// Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1"))
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1", "idp-1"))
require.NoError(t, err)
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2"))
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2", "idp-2"))
require.NoError(t, err)
// Delete first one
@ -304,6 +313,7 @@ func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
// Create IDP with multiple scopes
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: "multi-scope-oauth",
Name: "Multi-Scope OAuth",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
@ -343,6 +353,7 @@ func TestIdentityProviderFieldMapping(t *testing.T) {
// Create IDP with custom field mapping
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: "custom-field-mapping",
Name: "Custom Field Mapping",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
@ -382,17 +393,19 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
testCases := []struct {
name string
uid string
filter string
}{
{"Domain filter", "@company\\.com$"},
{"Prefix filter", "^admin_"},
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
{"Empty filter", ""},
{"Domain filter", "domain-filter", "@company\\.com$"},
{"Prefix filter", "prefix-filter", "^admin_"},
{"Complex regex", "complex-regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
{"Empty filter", "empty-filter", ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: tc.uid,
Name: tc.name,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: tc.filter,
@ -428,8 +441,9 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
}
// Helper function to create a test OAuth2 IDP.
func createTestOAuth2IDP(name string) *storepb.IdentityProvider {
func createTestOAuth2IDP(name, uid string) *storepb.IdentityProvider {
return &storepb.IdentityProvider{
Uid: uid,
Name: name,
Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "",

View File

@ -133,6 +133,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
const identityProviderTypes = [...new Set(templateList.map((t) => t.type))];
const [basicInfo, setBasicInfo] = useState({
title: "",
identifier: "",
identifierFilter: "",
});
const [type, setType] = useState<IdentityProvider_Type>(IdentityProvider_Type.OAUTH2);
@ -161,6 +162,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
// Reset to default state when dialog is closed
setBasicInfo({
title: "",
identifier: "",
identifierFilter: "",
});
setType(IdentityProvider_Type.OAUTH2);
@ -189,6 +191,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
if (open && identityProvider) {
setBasicInfo({
title: identityProvider.title,
identifier: "",
identifierFilter: identityProvider.identifierFilter,
});
setType(identityProvider.type);
@ -210,6 +213,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
if (template) {
setBasicInfo({
title: template.title,
identifier: template.title.toLowerCase().replace(/[^a-z0-9]+/g, "-"),
identifierFilter: template.identifierFilter,
});
setType(template.type);
@ -229,6 +233,9 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
if (basicInfo.title === "") {
return false;
}
if (isCreating && basicInfo.identifier === "") {
return false;
}
if (type === IdentityProvider_Type.OAUTH2) {
if (
oauth2Config.clientId === "" ||
@ -254,8 +261,10 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
try {
if (isCreating) {
await identityProviderServiceClient.createIdentityProvider({
identityProviderId: basicInfo.identifier,
identityProvider: create(IdentityProviderSchema, {
...basicInfo,
title: basicInfo.title,
identifierFilter: basicInfo.identifierFilter,
type: type,
config: create(IdentityProviderConfigSchema, {
config: {
@ -343,6 +352,32 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
<Separator className="my-2" />
</>
)}
{isCreating && (
<>
<p className="mb-1 text-sm font-medium">
ID
<span className="text-destructive">*</span>
</p>
<Input
className="mb-2 w-full font-mono"
placeholder="e.g. github, okta-corp"
maxLength={32}
value={basicInfo.identifier}
onChange={(e) =>
setBasicInfo({
...basicInfo,
identifier: e.target.value
.toLowerCase()
.replace(/[^a-z0-9-]/g, "-")
.replace(/--+/g, "-"),
})
}
/>
<p className="mb-2 text-xs text-muted-foreground">
A unique identifier for this provider. Lowercase letters, numbers, and hyphens only.
</p>
</>
)}
<p className="mb-1 text-sm font-medium">
{t("common.name")}
<span className="text-destructive">*</span>

View File

@ -16,8 +16,8 @@ export const extractMemoIdFromName = (name: string) => {
return name.split(memoNamePrefix).pop() || "";
};
export const extractIdentityProviderIdFromName = (name: string) => {
return parseInt(name.split(identityProviderNamePrefix).pop() || "", 10);
export const extractIdentityProviderUidFromName = (name: string) => {
return name.split(identityProviderNamePrefix).pop() || "";
};
// Helper function to convert InstanceSetting_Key enum value to string name

View File

@ -72,7 +72,7 @@ const AuthCallback = () => {
return;
}
const { identityProviderId, returnUrl, codeVerifier } = validatedState;
const { identityProviderName, returnUrl, codeVerifier } = validatedState;
const redirectUri = absolutifyLink("/auth/callback");
(async () => {
@ -81,7 +81,7 @@ const AuthCallback = () => {
credentials: {
case: "ssoCredentials",
value: {
idpId: identityProviderId,
idpName: identityProviderName,
code,
redirectUri,
codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange

View File

@ -7,7 +7,6 @@ import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator";
import { identityProviderServiceClient } from "@/connect";
import { useInstance } from "@/contexts/InstanceContext";
import { extractIdentityProviderIdFromName } from "@/helpers/resource-names";
import { absolutifyLink } from "@/helpers/utils";
import useCurrentUser from "@/hooks/useCurrentUser";
import { handleError } from "@/lib/error";
@ -50,8 +49,7 @@ const SignIn = () => {
try {
// Generate and store secure state parameter with CSRF protection
// Also generate PKCE parameters (code_challenge) for enhanced security if available
const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name);
const { state, codeChallenge } = await storeOAuthState(identityProviderId);
const { state, codeChallenge } = await storeOAuthState(identityProvider.name);
// Build OAuth authorization URL with secure state
// Include PKCE if available (requires HTTPS/localhost for crypto.subtle)

View File

@ -16,7 +16,7 @@ import type { Message } from "@bufbuild/protobuf";
* Describes the file api/v1/auth_service.proto.
*/
export const file_api_v1_auth_service: GenFile = /*@__PURE__*/
fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIuwCCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIabwoOU1NPQ3JlZGVudGlhbHMSEwoGaWRwX2lkGAEgASgFQgPgQQISEQoEY29kZRgCIAEoCUID4EECEhkKDHJlZGlyZWN0X3VyaRgDIAEoCUID4EECEhoKDWNvZGVfdmVyaWZpZXIYBCABKAlCA+BBAUINCgtjcmVkZW50aWFscyKFAQoOU2lnbkluUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyEhQKDGFjY2Vzc190b2tlbhgCIAEoCRI7ChdhY2Nlc3NfdG9rZW5fZXhwaXJlc19hdBgDIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAiEAoOU2lnbk91dFJlcXVlc3QiFQoTUmVmcmVzaFRva2VuUmVxdWVzdCJcChRSZWZyZXNoVG9rZW5SZXNwb25zZRIUCgxhY2Nlc3NfdG9rZW4YASABKAkSLgoKZXhwaXJlc19hdBgCIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAyvwMKC0F1dGhTZXJ2aWNlEnQKDkdldEN1cnJlbnRVc2VyEiMubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVxdWVzdBokLm1lbW9zLmFwaS52MS5HZXRDdXJyZW50VXNlclJlc3BvbnNlIheC0+STAhESDy9hcGkvdjEvYXV0aC9tZRJjCgZTaWduSW4SGy5tZW1vcy5hcGkudjEuU2lnbkluUmVxdWVzdBocLm1lbW9zLmFwaS52MS5TaWduSW5SZXNwb25zZSIegtPkkwIYOgEqIhMvYXBpL3YxL2F1dGgvc2lnbmluEl0KB1NpZ25PdXQSHC5tZW1vcy5hcGkudjEuU2lnbk91dFJlcXVlc3QaFi5nb29nbGUucHJvdG9idWYuRW1wdHkiHILT5JMCFiIUL2FwaS92MS9hdXRoL3NpZ25vdXQSdgoMUmVmcmVzaFRva2VuEiEubWVtb3MuYXBpLnYxLlJlZnJlc2hUb2tlblJlcXVlc3QaIi5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVzcG9uc2UiH4LT5JMCGToBKiIUL2FwaS92MS9hdXRoL3JlZnJlc2hCqAEKEGNvbS5tZW1vcy5hcGkudjFCEEF1dGhTZXJ2aWNlUHJvdG9QAVowZ2l0aHViLmNvbS91c2VtZW1vcy9tZW1vcy9wcm90by9nZW4vYXBpL3YxO2FwaXYxogIDTUFYqgIMTWVtb3MuQXBpLlYxygIMTWVtb3NcQXBpXFYx4gIYTWVtb3NcQXBpXFYxXEdQQk1ldGFkYXRh6gIOTWVtb3M6OkFwaTo6VjFiBnByb3RvMw", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]);
fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIu4CCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIacQoOU1NPQ3JlZGVudGlhbHMSFQoIaWRwX25hbWUYASABKAlCA+BBAhIRCgRjb2RlGAIgASgJQgPgQQISGQoMcmVkaXJlY3RfdXJpGAMgASgJQgPgQQISGgoNY29kZV92ZXJpZmllchgEIAEoCUID4EEBQg0KC2NyZWRlbnRpYWxzIoUBCg5TaWduSW5SZXNwb25zZRIgCgR1c2VyGAEgASgLMhIubWVtb3MuYXBpLnYxLlVzZXISFAoMYWNjZXNzX3Rva2VuGAIgASgJEjsKF2FjY2Vzc190b2tlbl9leHBpcmVzX2F0GAMgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcCIQCg5TaWduT3V0UmVxdWVzdCIVChNSZWZyZXNoVG9rZW5SZXF1ZXN0IlwKFFJlZnJlc2hUb2tlblJlc3BvbnNlEhQKDGFjY2Vzc190b2tlbhgBIAEoCRIuCgpleHBpcmVzX2F0GAIgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcDK/AwoLQXV0aFNlcnZpY2USdAoOR2V0Q3VycmVudFVzZXISIy5tZW1vcy5hcGkudjEuR2V0Q3VycmVudFVzZXJSZXF1ZXN0GiQubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVzcG9uc2UiF4LT5JMCERIPL2FwaS92MS9hdXRoL21lEmMKBlNpZ25JbhIbLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0GhwubWVtb3MuYXBpLnYxLlNpZ25JblJlc3BvbnNlIh6C0+STAhg6ASoiEy9hcGkvdjEvYXV0aC9zaWduaW4SXQoHU2lnbk91dBIcLm1lbW9zLmFwaS52MS5TaWduT3V0UmVxdWVzdBoWLmdvb2dsZS5wcm90b2J1Zi5FbXB0eSIcgtPkkwIWIhQvYXBpL3YxL2F1dGgvc2lnbm91dBJ2CgxSZWZyZXNoVG9rZW4SIS5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVxdWVzdBoiLm1lbW9zLmFwaS52MS5SZWZyZXNoVG9rZW5SZXNwb25zZSIfgtPkkwIZOgEqIhQvYXBpL3YxL2F1dGgvcmVmcmVzaEKoAQoQY29tLm1lbW9zLmFwaS52MUIQQXV0aFNlcnZpY2VQcm90b1ABWjBnaXRodWIuY29tL3VzZW1lbW9zL21lbW9zL3Byb3RvL2dlbi9hcGkvdjE7YXBpdjGiAgNNQViqAgxNZW1vcy5BcGkuVjHKAgxNZW1vc1xBcGlcVjHiAhhNZW1vc1xBcGlcVjFcR1BCTWV0YWRhdGHqAg5NZW1vczo6QXBpOjpWMWIGcHJvdG8z", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]);
/**
* @generated from message memos.api.v1.GetCurrentUserRequest
@ -120,11 +120,12 @@ export const SignInRequest_PasswordCredentialsSchema: GenMessage<SignInRequest_P
*/
export type SignInRequest_SSOCredentials = Message<"memos.api.v1.SignInRequest.SSOCredentials"> & {
/**
* The ID of the SSO provider.
* The resource name of the SSO provider.
* Format: identity-providers/{uid}
*
* @generated from field: int32 idp_id = 1;
* @generated from field: string idp_name = 1;
*/
idpId: number;
idpName: string;
/**
* The authorization code from the SSO provider.

View File

@ -3,7 +3,7 @@ const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes
interface OAuthState {
state: string;
identityProviderId: number;
identityProviderName: string;
timestamp: number;
returnUrl?: string;
codeVerifier?: string; // PKCE code_verifier
@ -42,7 +42,10 @@ function base64UrlEncode(buffer: Uint8Array): string {
// Store OAuth state and PKCE parameters in sessionStorage
// Returns state and optional codeChallenge for use in authorization URL
// PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth
export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge?: string }> {
export async function storeOAuthState(
identityProviderName: string,
returnUrl?: string,
): Promise<{ state: string; codeChallenge?: string }> {
const state = generateSecureState();
// Try to generate PKCE parameters if crypto.subtle is available (HTTPS/localhost)
@ -70,7 +73,7 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st
const stateData: OAuthState = {
state,
identityProviderId,
identityProviderName,
timestamp: Date.now(),
returnUrl,
codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available)
@ -87,8 +90,8 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st
}
// Validate and retrieve OAuth state from storage (CSRF protection)
// Returns identityProviderId, returnUrl, and codeVerifier for PKCE
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null {
// Returns identityProviderName, returnUrl, and codeVerifier for PKCE
export function validateOAuthState(stateParam: string): { identityProviderName: string; returnUrl?: string; codeVerifier?: string } | null {
try {
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
if (!storedData) {
@ -115,7 +118,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu
// State is valid, clean up and return data
sessionStorage.removeItem(STATE_STORAGE_KEY);
return {
identityProviderId: stateData.identityProviderId,
identityProviderName: stateData.identityProviderName,
returnUrl: stateData.returnUrl,
codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier
};