From acbc914dea8d4e533028b03929193a4af1ac32b4 Mon Sep 17 00:00:00 2001 From: memoclaw Date: Mon, 30 Mar 2026 20:01:16 +0800 Subject: [PATCH 1/8] fix(webhooks): trigger memo updates for attachment and relation changes (#5795) Co-authored-by: memoclaw <265580040+memoclaw@users.noreply.github.com> --- .../router/api/v1/memo_attachment_service.go | 38 ++++++--- server/router/api/v1/memo_relation_service.go | 30 +++++-- server/router/api/v1/memo_service.go | 51 ++---------- server/router/api/v1/memo_update_helpers.go | 78 +++++++++++++++++ server/router/api/v1/sse_service_test.go | 83 +++++++++++++++++++ 5 files changed, 216 insertions(+), 64 deletions(-) create mode 100644 server/router/api/v1/memo_update_helpers.go diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index d9aa0a3ee..d687c59e9 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -35,20 +35,36 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set if memo.CreatorID != user.ID && !isSuperUser(user) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } + if err := s.setMemoAttachmentsInternal(ctx, memo, request.Attachments); err != nil { + return nil, err + } + if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil { + return nil, err + } + updatedMemo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to build updated memo state") + } + s.dispatchMemoUpdatedSideEffects(ctx, updatedMemo, parentMemo, memoMessage) + + return &emptypb.Empty{}, nil +} + +func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *store.Memo, requestAttachments []*v1pb.Attachment) error { attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ MemoID: &memo.ID, }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list attachments") + return status.Errorf(codes.Internal, "failed to list attachments") } // Delete attachments that are not in the request. for _, attachment := range attachments { found := false - for _, requestAttachment := range request.Attachments { + for _, requestAttachment := range requestAttachments { requestAttachmentUID, err := ExtractAttachmentUIDFromName(requestAttachment.Name) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err) + return status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err) } if attachment.UID == requestAttachmentUID { found = true @@ -60,24 +76,24 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set ID: int32(attachment.ID), MemoID: &memo.ID, }); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete attachment") + return status.Errorf(codes.Internal, "failed to delete attachment") } } } - slices.Reverse(request.Attachments) + slices.Reverse(requestAttachments) // Update attachments' memo_id in the request. - for index, attachment := range request.Attachments { + for index, attachment := range requestAttachments { attachmentUID, err := ExtractAttachmentUIDFromName(attachment.Name) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err) + return status.Errorf(codes.InvalidArgument, "invalid attachment name: %v", err) } tempAttachment, err := s.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get attachment: %v", err) + return status.Errorf(codes.Internal, "failed to get attachment: %v", err) } if tempAttachment == nil { - return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID) + return status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID) } updatedTs := time.Now().Unix() + int64(index) if err := s.Store.UpdateAttachment(ctx, &store.UpdateAttachment{ @@ -85,11 +101,11 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set MemoID: &memo.ID, UpdatedTs: &updatedTs, }); err != nil { - return nil, status.Errorf(codes.Internal, "failed to update attachment: %v", err) + return status.Errorf(codes.Internal, "failed to update attachment: %v", err) } } - return &emptypb.Empty{}, nil + return nil } func (s *APIV1Service) ListMemoAttachments(ctx context.Context, request *v1pb.ListMemoAttachmentsRequest) (*v1pb.ListMemoAttachmentsResponse, error) { diff --git a/server/router/api/v1/memo_relation_service.go b/server/router/api/v1/memo_relation_service.go index 25d97f24a..382d72143 100644 --- a/server/router/api/v1/memo_relation_service.go +++ b/server/router/api/v1/memo_relation_service.go @@ -35,18 +35,34 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe if memo.CreatorID != user.ID && !isSuperUser(user) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } + if err := s.setMemoRelationsInternal(ctx, memo, request.Relations); err != nil { + return nil, err + } + if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil { + return nil, err + } + updatedMemo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to build updated memo state") + } + s.dispatchMemoUpdatedSideEffects(ctx, updatedMemo, parentMemo, memoMessage) + + return &emptypb.Empty{}, nil +} + +func (s *APIV1Service) setMemoRelationsInternal(ctx context.Context, memo *store.Memo, relations []*v1pb.MemoRelation) error { referenceType := store.MemoRelationReference // Delete all reference relations first. if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ MemoID: &memo.ID, Type: &referenceType, }); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete memo relation") + return status.Errorf(codes.Internal, "failed to delete memo relation") } - for _, relation := range request.Relations { + for _, relation := range relations { // Ignore reflexive relations. - if request.Name == relation.RelatedMemo.Name { + if buildMemoName(memo.UID) == relation.RelatedMemo.Name { continue } // Ignore comment relations as there's no need to update a comment's relation. @@ -56,22 +72,22 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe } relatedMemoUID, err := ExtractMemoUIDFromName(relation.RelatedMemo.Name) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err) + return status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err) } relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &relatedMemoUID}) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get related memo") + return status.Errorf(codes.Internal, "failed to get related memo") } if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{ MemoID: memo.ID, RelatedMemoID: relatedMemo.ID, Type: convertMemoRelationTypeToStore(relation.Type), }); err != nil { - return nil, status.Errorf(codes.Internal, "failed to upsert memo relation") + return status.Errorf(codes.Internal, "failed to upsert memo relation") } } - return &emptypb.Empty{}, nil + return nil } func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) { diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 31016ae1e..830ee1eb6 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -469,19 +469,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR payload.Location = convertLocationToStore(request.Memo.Location) update.Payload = payload } else if path == "attachments" { - _, err := s.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{ - Name: request.Memo.Name, - Attachments: request.Memo.Attachments, - }) - if err != nil { + if err := s.setMemoAttachmentsInternal(ctx, memo, request.Memo.Attachments); err != nil { return nil, errors.Wrap(err, "failed to set memo attachments") } } else if path == "relations" { - _, err := s.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{ - Name: request.Memo.Name, - Relations: request.Memo.Relations, - }) - if err != nil { + if err := s.setMemoRelationsInternal(ctx, memo, request.Memo.Relations); err != nil { return nil, errors.Wrap(err, "failed to set memo relations") } } @@ -497,44 +489,11 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR if err != nil { return nil, errors.Wrap(err, "failed to get memo") } - reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ - ContentID: &request.Memo.Name, - }) + memo, parentMemo, memoMessage, err := s.buildUpdatedMemoState(ctx, memo.ID) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list reactions") + return nil, errors.Wrap(err, "failed to build updated memo state") } - attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ - MemoID: &memo.ID, - }) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list attachments") - } - - relations, err := s.loadMemoRelations(ctx, memo) - if err != nil { - return nil, errors.Wrap(err, "failed to load memo relations") - } - memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations) - if err != nil { - return nil, errors.Wrap(err, "failed to convert memo") - } - var parentMemo *store.Memo - if memo.ParentUID != nil { - parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) - } - // Try to dispatch webhook when memo is updated. - if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil { - slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err)) - } - - // Broadcast live refresh event. - s.SSEHub.Broadcast(&SSEEvent{ - Type: SSEEventMemoUpdated, - Name: memoMessage.Name, - Parent: memoMessage.GetParent(), - Visibility: memo.Visibility, - CreatorID: resolveSSECreatorID(memo, parentMemo), - }) + s.dispatchMemoUpdatedSideEffects(ctx, memo, parentMemo, memoMessage) return memoMessage, nil } diff --git a/server/router/api/v1/memo_update_helpers.go b/server/router/api/v1/memo_update_helpers.go new file mode 100644 index 000000000..2c5f7e044 --- /dev/null +++ b/server/router/api/v1/memo_update_helpers.go @@ -0,0 +1,78 @@ +package v1 + +import ( + "context" + "log/slog" + "time" + + "github.com/pkg/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" + "github.com/usememos/memos/store" +) + +func (s *APIV1Service) touchMemoUpdatedTimestamp(ctx context.Context, memoID int32) error { + updatedTs := time.Now().Unix() + if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{ + ID: memoID, + UpdatedTs: &updatedTs, + }); err != nil { + return status.Errorf(codes.Internal, "failed to update memo timestamp") + } + return nil +} + +func (s *APIV1Service) buildUpdatedMemoState(ctx context.Context, memoID int32) (*store.Memo, *store.Memo, *v1pb.Memo, error) { + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID}) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to get memo") + } + if memo == nil { + return nil, nil, nil, errors.New("memo not found") + } + + memoName := buildMemoName(memo.UID) + reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ + ContentID: &memoName, + }) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to list reactions") + } + attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ + MemoID: &memo.ID, + }) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to list attachments") + } + relations, err := s.loadMemoRelations(ctx, memo) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to load memo relations") + } + memoMessage, err := s.convertMemoFromStore(ctx, memo, reactions, attachments, relations) + if err != nil { + return nil, nil, nil, errors.Wrap(err, "failed to convert memo") + } + + var parentMemo *store.Memo + if memo.ParentUID != nil { + parentMemo, _ = s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) + } + + return memo, parentMemo, memoMessage, nil +} + +func (s *APIV1Service) dispatchMemoUpdatedSideEffects(ctx context.Context, memo *store.Memo, parentMemo *store.Memo, memoMessage *v1pb.Memo) { + if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil { + slog.Warn("Failed to dispatch memo updated webhook", slog.Any("err", err)) + } + + s.SSEHub.Broadcast(&SSEEvent{ + Type: SSEEventMemoUpdated, + Name: memoMessage.Name, + Parent: memoMessage.GetParent(), + Visibility: memo.Visibility, + CreatorID: resolveSSECreatorID(memo, parentMemo), + }) +} diff --git a/server/router/api/v1/sse_service_test.go b/server/router/api/v1/sse_service_test.go index 180b2aec6..689a35b31 100644 --- a/server/router/api/v1/sse_service_test.go +++ b/server/router/api/v1/sse_service_test.go @@ -187,3 +187,86 @@ func TestDeleteMemoReaction_SSEEvent(t *testing.T) { assert.Contains(t, payload, memo.Name) mustNotReceive(t, client.events, 100*time.Millisecond) } + +func TestSetMemoAttachments_EmitsMemoUpdatedSSEEvent(t *testing.T) { + ctx := context.Background() + svc := newIntegrationService(t) + + user, err := svc.Store.CreateUser(ctx, &store.User{ + Username: "user", Role: store.RoleAdmin, Email: "user@example.com", + }) + require.NoError(t, err) + uctx := userCtx(ctx, user.ID) + + memo, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{Content: "memo with attachments", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + + attachment, err := svc.CreateAttachment(uctx, &v1pb.CreateAttachmentRequest{ + Attachment: &v1pb.Attachment{ + Filename: "test.txt", + Size: 5, + Type: "text/plain", + Content: []byte("hello"), + }, + }) + require.NoError(t, err) + + client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin) + defer svc.SSEHub.Unsubscribe(client) + + _, err = svc.SetMemoAttachments(uctx, &v1pb.SetMemoAttachmentsRequest{ + Name: memo.Name, + Attachments: []*v1pb.Attachment{ + {Name: attachment.Name}, + }, + }) + require.NoError(t, err) + + data := mustReceive(t, client.events, time.Second) + payload := string(data) + assert.Contains(t, payload, `"memo.updated"`) + assert.Contains(t, payload, memo.Name) + mustNotReceive(t, client.events, 100*time.Millisecond) +} + +func TestSetMemoRelations_EmitsMemoUpdatedSSEEvent(t *testing.T) { + ctx := context.Background() + svc := newIntegrationService(t) + + user, err := svc.Store.CreateUser(ctx, &store.User{ + Username: "user", Role: store.RoleAdmin, Email: "user@example.com", + }) + require.NoError(t, err) + uctx := userCtx(ctx, user.ID) + + memo1, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{Content: "memo one", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + memo2, err := svc.CreateMemo(uctx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{Content: "memo two", Visibility: v1pb.Visibility_PUBLIC}, + }) + require.NoError(t, err) + + client := svc.SSEHub.Subscribe(user.ID, store.RoleAdmin) + defer svc.SSEHub.Unsubscribe(client) + + _, err = svc.SetMemoRelations(uctx, &v1pb.SetMemoRelationsRequest{ + Name: memo1.Name, + Relations: []*v1pb.MemoRelation{ + { + RelatedMemo: &v1pb.MemoRelation_Memo{Name: memo2.Name}, + Type: v1pb.MemoRelation_REFERENCE, + }, + }, + }) + require.NoError(t, err) + + data := mustReceive(t, client.events, time.Second) + payload := string(data) + assert.Contains(t, payload, `"memo.updated"`) + assert.Contains(t, payload, memo1.Name) + mustNotReceive(t, client.events, 100*time.Millisecond) +} From e520b637fd8d0f6331674003abe72f9dfbae8231 Mon Sep 17 00:00:00 2001 From: memoclaw <265580040+memoclaw@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:37:07 +0800 Subject: [PATCH 2/8] fix: prevent stale comment drafts from being restored --- .../MemoEditor/services/cacheService.ts | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/web/src/components/MemoEditor/services/cacheService.ts b/web/src/components/MemoEditor/services/cacheService.ts index bc311952e..5c93e3b4f 100644 --- a/web/src/components/MemoEditor/services/cacheService.ts +++ b/web/src/components/MemoEditor/services/cacheService.ts @@ -1,25 +1,49 @@ -import { debounce } from "lodash-es"; - export const CACHE_DEBOUNCE_DELAY = 500; +const pendingSaves = new Map>(); + export const cacheService = { key: (username: string, cacheKey?: string): string => { return `${username}-${cacheKey || ""}`; }, - save: debounce((key: string, content: string) => { - if (content.trim()) { - localStorage.setItem(key, content); - } else { - localStorage.removeItem(key); + save: (key: string, content: string) => { + const pendingSave = pendingSaves.get(key); + if (pendingSave) { + window.clearTimeout(pendingSave); } - }, CACHE_DEBOUNCE_DELAY), + + const timeoutId = window.setTimeout(() => { + pendingSaves.delete(key); + + if (content.trim()) { + localStorage.setItem(key, content); + } else { + localStorage.removeItem(key); + } + }, CACHE_DEBOUNCE_DELAY); + + pendingSaves.set(key, timeoutId); + }, load(key: string): string { return localStorage.getItem(key) || ""; }, clear(key: string): void { + const pendingSave = pendingSaves.get(key); + if (pendingSave) { + window.clearTimeout(pendingSave); + pendingSaves.delete(key); + } + localStorage.removeItem(key); }, + + clearAll(): void { + for (const timeoutId of pendingSaves.values()) { + window.clearTimeout(timeoutId); + } + pendingSaves.clear(); + }, }; From 7c708ee27e323688e7a40538f80bd12b8f0df90d Mon Sep 17 00:00:00 2001 From: boojack Date: Mon, 30 Mar 2026 23:51:57 +0800 Subject: [PATCH 3/8] chore: add migration upgrade coverage (#5796) --- .../api/v1/test/memo_share_service_test.go | 106 +++++++ store/test/containers.go | 62 +++- store/test/migrator_upgrade_test.go | 274 ++++++++++++++++++ 3 files changed, 437 insertions(+), 5 deletions(-) create mode 100644 store/test/migrator_upgrade_test.go diff --git a/server/router/api/v1/test/memo_share_service_test.go b/server/router/api/v1/test/memo_share_service_test.go index 110b83a35..946597d92 100644 --- a/server/router/api/v1/test/memo_share_service_test.go +++ b/server/router/api/v1/test/memo_share_service_test.go @@ -4,12 +4,14 @@ import ( "context" "strings" "testing" + "time" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" apiv1 "github.com/usememos/memos/proto/gen/api/v1" + "github.com/usememos/memos/store" ) func TestDeleteMemoShare_VerifiesShareBelongsToMemo(t *testing.T) { @@ -107,3 +109,107 @@ func TestGetMemoByShare_IncludesReactions(t *testing.T) { require.Equal(t, "👍", sharedMemo.Reactions[0].ReactionType) require.Equal(t, memo.Name, sharedMemo.Reactions[0].ContentId) } + +func TestGetMemoByShare_ReturnsNotFoundForUnknownShare(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{ + ShareId: "missing-share-token", + }) + require.Error(t, err) + require.Equal(t, codes.NotFound, status.Code(err)) +} + +func TestGetMemoByShare_ReturnsNotFoundForExpiredShare(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "share-expired") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "memo with expired share", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + + expiredTs := time.Now().Add(-time.Hour).Unix() + expiredShare, err := ts.Store.CreateMemoShare(ctx, &store.MemoShare{ + UID: "expired-share-token", + MemoID: parseMemoIDFromNameForTest(t, ts, memo.Name), + CreatorID: user.ID, + ExpiresTs: &expiredTs, + }) + require.NoError(t, err) + + _, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{ + ShareId: expiredShare.UID, + }) + require.Error(t, err) + require.Equal(t, codes.NotFound, status.Code(err)) +} + +func TestGetMemoByShare_ReturnsNotFoundForArchivedMemo(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "share-archived") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + memoResp, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "memo that will be archived", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + + share, err := ts.Service.CreateMemoShare(userCtx, &apiv1.CreateMemoShareRequest{ + Parent: memoResp.Name, + MemoShare: &apiv1.MemoShare{}, + }) + require.NoError(t, err) + + memoID := parseMemoIDFromNameForTest(t, ts, memoResp.Name) + memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID}) + require.NoError(t, err) + require.NotNil(t, memo) + + archived := store.Archived + err = ts.Store.UpdateMemo(ctx, &store.UpdateMemo{ + ID: memo.ID, + RowStatus: &archived, + }) + require.NoError(t, err) + + shareToken := share.Name[strings.LastIndex(share.Name, "/")+1:] + _, err = ts.Service.GetMemoByShare(ctx, &apiv1.GetMemoByShareRequest{ + ShareId: shareToken, + }) + require.Error(t, err) + require.Equal(t, codes.NotFound, status.Code(err)) +} + +func parseMemoIDFromNameForTest(t *testing.T, ts *TestService, memoName string) int32 { + t.Helper() + + memoUID, ok := strings.CutPrefix(memoName, "memos/") + require.True(t, ok, "memo name must start with memos/: %s", memoName) + + memo, err := ts.Store.GetMemo(context.Background(), &store.FindMemo{UID: &memoUID}) + require.NoError(t, err) + require.NotNil(t, memo) + + return memo.ID +} diff --git a/store/test/containers.go b/store/test/containers.go index 0b98c5c2b..e9760b6a1 100644 --- a/store/test/containers.go +++ b/store/test/containers.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + "net" + "net/url" "os" "strings" "sync" @@ -12,6 +14,7 @@ import ( "time" "github.com/docker/docker/api/types/container" + mysqldriver "github.com/go-sql-driver/mysql" "github.com/pkg/errors" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" @@ -20,7 +23,6 @@ import ( "github.com/testcontainers/testcontainers-go/wait" // Database drivers for connection verification. - _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" ) @@ -31,6 +33,9 @@ const ( // Memos container settings for migration testing. MemosDockerImage = "neosmemo/memos" StableMemosVersion = "stable" // Always points to the latest stable release + + mysqlNetworkAlias = "memos-mysql" + postgresNetworkAlias = "memos-postgres" ) var ( @@ -62,12 +67,23 @@ func getTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) return testDockerNetwork.Load(), networkErr } +func requireTestNetwork(ctx context.Context) (*testcontainers.DockerNetwork, error) { + nw, err := getTestNetwork(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to create test network") + } + if nw == nil { + return nil, errors.New("test network is unavailable") + } + return nw, nil +} + // GetMySQLDSN starts a MySQL container (if not already running) and creates a fresh database for this test. func GetMySQLDSN(t *testing.T) string { ctx := context.Background() mysqlOnce.Do(func() { - nw, err := getTestNetwork(ctx) + nw, err := requireTestNetwork(ctx) if err != nil { t.Fatalf("failed to create test network: %v", err) } @@ -86,7 +102,7 @@ func GetMySQLDSN(t *testing.T) string { wait.ForListeningPort("3306/tcp"), ).WithDeadline(120*time.Second), ), - network.WithNetwork(nil, nw), + network.WithNetwork([]string{mysqlNetworkAlias}, nw), ) if err != nil { t.Fatalf("failed to start MySQL container: %v", err) @@ -167,7 +183,7 @@ func GetPostgresDSN(t *testing.T) string { ctx := context.Background() postgresOnce.Do(func() { - nw, err := getTestNetwork(ctx) + nw, err := requireTestNetwork(ctx) if err != nil { t.Fatalf("failed to create test network: %v", err) } @@ -183,7 +199,7 @@ func GetPostgresDSN(t *testing.T) string { wait.ForListeningPort("5432/tcp"), ).WithDeadline(120*time.Second), ), - network.WithNetwork(nil, nw), + network.WithNetwork([]string{postgresNetworkAlias}, nw), ) if err != nil { t.Fatalf("failed to start PostgreSQL container: %v", err) @@ -264,6 +280,11 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon "MEMOS_MODE": "prod", } + nw, err := requireTestNetwork(ctx) + if err != nil { + return nil, err + } + var opts []testcontainers.ContainerCustomizer switch cfg.Driver { @@ -272,6 +293,12 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon opts = append(opts, testcontainers.WithHostConfigModifier(func(hc *container.HostConfig) { hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", cfg.DataDir, "/var/opt/memos")) })) + case "mysql", "postgres": + if cfg.DSN == "" { + return nil, errors.Errorf("dsn is required for %s migration testing", cfg.Driver) + } + env["MEMOS_DRIVER"] = cfg.Driver + env["MEMOS_DSN"] = cfg.DSN default: return nil, errors.Errorf("unsupported driver for migration testing: %s", cfg.Driver) } @@ -303,6 +330,7 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon } // Apply options + opts = append(opts, network.WithNetwork(nil, nw)) for _, opt := range opts { if err := opt.Customize(&genericReq); err != nil { return nil, errors.Wrap(err, "failed to apply container option") @@ -316,3 +344,27 @@ func StartMemosContainer(ctx context.Context, cfg MemosContainerConfig) (testcon return ctr, nil } + +func getContainerDSN(driver, hostDSN string) (string, error) { + switch driver { + case "mysql": + cfg, err := mysqldriver.ParseDSN(hostDSN) + if err != nil { + return "", errors.Wrap(err, "failed to parse mysql dsn") + } + cfg.Net = "tcp" + cfg.Addr = net.JoinHostPort(mysqlNetworkAlias, "3306") + return cfg.FormatDSN(), nil + case "postgres": + u, err := url.Parse(hostDSN) + if err != nil { + return "", errors.Wrap(err, "failed to parse postgres dsn") + } + u.Host = net.JoinHostPort(postgresNetworkAlias, "5432") + return u.String(), nil + case "sqlite": + return hostDSN, nil + default: + return "", errors.Errorf("unsupported driver for container dsn: %s", driver) + } +} diff --git a/store/test/migrator_upgrade_test.go b/store/test/migrator_upgrade_test.go new file mode 100644 index 000000000..4b8254ed1 --- /dev/null +++ b/store/test/migrator_upgrade_test.go @@ -0,0 +1,274 @@ +package test + +import ( + "context" + "database/sql" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func TestMigrationFromV0262PreservesLegacyData(t *testing.T) { + if testing.Short() { + t.Skip("skipping container-based upgrade test in short mode") + } + if os.Getenv("SKIP_CONTAINER_TESTS") == "1" { + t.Skip("skipping container-based test (SKIP_CONTAINER_TESTS=1)") + } + + ctx := context.Background() + driver := getDriverFromEnv() + + cfg, hostDSN := prepareV0262MigrationTest(t, driver) + t.Logf("Starting Memos %s container for %s schema bootstrap...", cfg.Version, driver) + container, err := StartMemosContainer(ctx, cfg) + require.NoError(t, err, "failed to start v0.26.2 memos container") + t.Cleanup(func() { + if container != nil { + _ = container.Terminate(ctx) + } + }) + + legacyStore := NewTestingStoreWithDSN(ctx, t, driver, hostDSN) + require.Eventually(t, func() bool { + setting, err := legacyStore.GetInstanceBasicSetting(ctx) + return err == nil && setting != nil && setting.SchemaVersion != "" + }, 45*time.Second, 500*time.Millisecond, "legacy schema should be initialized by old container") + + settingBeforeSeed, err := legacyStore.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + t.Logf("Legacy schema version before migration: %s", settingBeforeSeed.SchemaVersion) + + err = container.Terminate(ctx) + require.NoError(t, err, "failed to stop v0.26.2 memos container") + container = nil + + db := openMigrationSQLDB(t, driver, hostDSN) + defer db.Close() + + seedLegacyMigrationData(ctx, t, driver, db) + + count, err := countSystemSetting(ctx, db, "STORAGE") + require.NoError(t, err) + require.Zero(t, count, "v0.26.2 database should not have a STORAGE setting before migration") + + ts := NewTestingStoreWithDSN(ctx, t, driver, hostDSN) + err = ts.Migrate(ctx) + require.NoError(t, err, "migration from v0.26.2 should succeed for %s", driver) + + currentVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + currentSetting, err := ts.GetInstanceBasicSetting(ctx) + require.NoError(t, err) + require.Equal(t, currentVersion, currentSetting.SchemaVersion, "schema version should be updated") + + storageSetting, err := ts.GetInstanceStorageSetting(ctx) + require.NoError(t, err) + require.Equal(t, storepb.InstanceStorageSetting_DATABASE, storageSetting.StorageType, "existing installs should stay on DATABASE storage") + + idps, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) + require.NoError(t, err) + require.Len(t, idps, 2) + idpUIDsByName := map[string]string{} + for _, idp := range idps { + idpUIDsByName[idp.Name] = idp.Uid + } + require.Equal(t, "00000191", idpUIDsByName["Legacy Google"]) + require.Equal(t, "00000192", idpUIDsByName["Legacy GitHub"]) + + inboxes, err := ts.ListInboxes(ctx, &store.FindInbox{}) + require.NoError(t, err) + require.Len(t, inboxes, 1) + require.NotNil(t, inboxes[0].Message) + require.Equal(t, storepb.InboxMessage_MEMO_COMMENT, inboxes[0].Message.Type) + require.Equal(t, int32(102), inboxes[0].Message.GetMemoComment().MemoId) + require.Equal(t, int32(101), inboxes[0].Message.GetMemoComment().RelatedMemoId) + + activityExists, err := tableExists(ctx, db, driver, "activity") + require.NoError(t, err) + require.False(t, activityExists, "activity table should be removed after migration") + + memoShareExists, err := tableExists(ctx, db, driver, "memo_share") + require.NoError(t, err) + require.True(t, memoShareExists, "memo_share table should be created") + + share, err := ts.CreateMemoShare(ctx, &store.MemoShare{ + UID: "post-upgrade-share", + MemoID: 101, + CreatorID: 11, + }) + require.NoError(t, err) + require.Equal(t, "post-upgrade-share", share.UID) + + postUpgradeUser, err := createTestingUserWithRole(ctx, ts, "postupgrade", store.RoleUser) + require.NoError(t, err) + postUpgradeMemo, err := ts.CreateMemo(ctx, &store.Memo{ + UID: "post-upgrade-memo-v0262", + CreatorID: postUpgradeUser.ID, + Content: "created after v0.26.2 migration", + Visibility: store.Public, + }) + require.NoError(t, err) + require.Equal(t, "created after v0.26.2 migration", postUpgradeMemo.Content) +} + +func prepareV0262MigrationTest(t *testing.T, driver string) (MemosContainerConfig, string) { + t.Helper() + + const version = "0.26.2" + + switch driver { + case "sqlite": + dataDir := t.TempDir() + return MemosContainerConfig{ + Version: version, + Driver: driver, + DataDir: dataDir, + }, fmt.Sprintf("%s/memos_prod.db", dataDir) + case "mysql": + hostDSN := GetMySQLDSN(t) + containerDSN, err := getContainerDSN(driver, hostDSN) + require.NoError(t, err) + return MemosContainerConfig{ + Version: version, + Driver: driver, + DSN: containerDSN, + }, hostDSN + case "postgres": + hostDSN := GetPostgresDSN(t) + containerDSN, err := getContainerDSN(driver, hostDSN) + require.NoError(t, err) + return MemosContainerConfig{ + Version: version, + Driver: driver, + DSN: containerDSN, + }, hostDSN + default: + t.Fatalf("unsupported driver: %s", driver) + return MemosContainerConfig{}, "" + } +} + +func openMigrationSQLDB(t *testing.T, driver, dsn string) *sql.DB { + t.Helper() + + db, err := sql.Open(driver, dsn) + require.NoError(t, err) + require.NoError(t, db.Ping()) + return db +} + +func seedLegacyMigrationData(ctx context.Context, t *testing.T, driver string, db *sql.DB) { + t.Helper() + + execMigrationSQL(t, db, legacyInsertUserSQL(driver, 11, "owner")) + execMigrationSQL(t, db, legacyInsertUserSQL(driver, 12, "commenter")) + execMigrationSQL(t, db, legacyInsertMemoSQL(101, 11, "legacy-parent", "parent memo")) + execMigrationSQL(t, db, legacyInsertMemoSQL(102, 12, "legacy-comment", "comment memo")) + execMigrationSQL(t, db, legacyInsertActivitySQL(201, 12)) + execMigrationSQL(t, db, legacyInsertInboxSQL(301, 12, 11, 201)) + execMigrationSQL(t, db, legacyInsertIDPSQL(401, "Legacy Google")) + execMigrationSQL(t, db, legacyInsertIDPSQL(402, "Legacy GitHub")) + + var message string + err := db.QueryRowContext(ctx, "SELECT message FROM inbox WHERE id = 301").Scan(&message) + require.NoError(t, err) + require.Contains(t, message, "\"activityId\":201") + require.NotContains(t, message, "\"memoComment\"") +} + +func execMigrationSQL(t *testing.T, db *sql.DB, query string) { + t.Helper() + _, err := db.Exec(query) + require.NoError(t, err, "failed to execute SQL: %s", query) +} + +func legacyInsertUserSQL(driver string, id int, username string) string { + table := "user" + switch driver { + case "mysql": + table = "`user`" + case "postgres", "sqlite": + table = `"user"` + default: + // Keep the unquoted fallback for unknown test drivers. + } + + return fmt.Sprintf( + "INSERT INTO %s (id, username, role, email, nickname, password_hash, avatar_url, description) VALUES (%d, '%s', 'USER', '%s@example.com', '%s', 'legacy-hash', '', 'legacy user')", + table, id, username, username, username, + ) +} + +func legacyInsertMemoSQL(id, creatorID int, uid, content string) string { + payload := "{}" + return fmt.Sprintf( + "INSERT INTO memo (id, uid, creator_id, content, visibility, payload) VALUES (%d, '%s', %d, '%s', 'PRIVATE', '%s')", + id, uid, creatorID, content, payload, + ) +} + +func legacyInsertActivitySQL(id, creatorID int) string { + payload := `{"memoComment":{"memoId":102,"relatedMemoId":101}}` + return fmt.Sprintf( + "INSERT INTO activity (id, creator_id, type, level, payload) VALUES (%d, %d, 'MEMO_COMMENT', 'INFO', '%s')", + id, creatorID, payload, + ) +} + +func legacyInsertInboxSQL(id, senderID, receiverID, activityID int) string { + message := fmt.Sprintf(`{"type":"MEMO_COMMENT","activityId":%d}`, activityID) + return fmt.Sprintf( + "INSERT INTO inbox (id, sender_id, receiver_id, status, message) VALUES (%d, %d, %d, 'UNREAD', '%s')", + id, senderID, receiverID, message, + ) +} + +func legacyInsertIDPSQL(id int, name string) string { + config := `{"clientId":"legacy-client","clientSecret":"legacy-secret","authUrl":"https://example.com/auth","tokenUrl":"https://example.com/token","userInfoUrl":"https://example.com/userinfo"}` + return fmt.Sprintf( + "INSERT INTO idp (id, name, type, identifier_filter, config) VALUES (%d, '%s', 'OAUTH2', '', '%s')", + id, name, config, + ) +} + +func countSystemSetting(ctx context.Context, db *sql.DB, name string) (int, error) { + var count int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = ?", name).Scan(&count) + if err == nil { + return count, nil + } + + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_setting WHERE name = $1", name).Scan(&count) + return count, err +} + +func tableExists(ctx context.Context, db *sql.DB, driver, table string) (bool, error) { + switch driver { + case "sqlite": + var name string + err := db.QueryRowContext(ctx, "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name) + if err == sql.ErrNoRows { + return false, nil + } + return err == nil, err + case "mysql": + var count int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?", table).Scan(&count) + return count > 0, err + case "postgres": + var regclass sql.NullString + err := db.QueryRowContext(ctx, "SELECT to_regclass($1)", "public."+table).Scan(®class) + return regclass.Valid && strings.EqualFold(regclass.String, table), err + default: + return false, errors.Errorf("unsupported driver: %s", driver) + } +} From d3f6e8ee31e7c1942ac99f71d5ca30c3ffa410b1 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 31 Mar 2026 00:12:28 +0800 Subject: [PATCH 4/8] chore: harden MCP access control and origin validation --- server/router/mcp/README.md | 20 +- server/router/mcp/access.go | 113 +++++++++++ server/router/mcp/mcp.go | 18 +- server/router/mcp/mcp_test.go | 275 ++++++++++++++++++++++++++ server/router/mcp/tools_attachment.go | 20 +- server/router/mcp/tools_memo.go | 39 +--- server/router/mcp/tools_relation.go | 31 ++- server/router/mcp/tools_tag.go | 2 +- 8 files changed, 452 insertions(+), 66 deletions(-) create mode 100644 server/router/mcp/access.go create mode 100644 server/router/mcp/mcp_test.go diff --git a/server/router/mcp/README.md b/server/router/mcp/README.md index 86cb16b1e..78feb5732 100644 --- a/server/router/mcp/README.md +++ b/server/router/mcp/README.md @@ -7,6 +7,7 @@ This package implements a [Model Context Protocol (MCP)](https://modelcontextpro ``` POST /mcp (tool calls, initialize) GET /mcp (optional SSE stream for server-to-client messages) +DELETE /mcp (optional session termination) ``` Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26). @@ -24,13 +25,22 @@ The server advertises the following MCP capabilities: ## Authentication -Every request must include a Personal Access Token (PAT): +Public reads can be used without authentication. Personal Access Tokens (PATs) or short-lived JWT session tokens are required for: + +- Reading non-public memos or attachments +- Any tool that mutates data + +When authenticating, send a Bearer token: ``` Authorization: Bearer ``` -PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests without a valid token receive `HTTP 401`. +PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests with an invalid token receive `HTTP 401`. + +## Origin Validation + +For Streamable HTTP safety, requests with an `Origin` header must be same-origin with the current request host or match the configured `instance-url`. Requests without an `Origin` header, such as desktop MCP clients and CLI tools, are allowed. ## Tools @@ -60,15 +70,15 @@ PATs are long-lived tokens created in Settings → My Account → Access Tokens. | `list_attachments` | List user's attachments | — | `page_size`, `page`, `memo` | | `get_attachment` | Get attachment metadata | `name` | — | | `delete_attachment` | Delete an attachment | `name` | — | -| `link_attachment_to_memo` | Link attachment to memo | `name`, `memo` | — | +| `link_attachment_to_memo` | Link attachment to a memo you own | `name`, `memo` | — | ### Relation Tools | Tool | Description | Required params | Optional params | |---|---|---|---| | `list_memo_relations` | List relations (refs + comments) | `name` | `type` | -| `create_memo_relation` | Create a reference relation | `name`, `related_memo` | — | -| `delete_memo_relation` | Delete a reference relation | `name`, `related_memo` | — | +| `create_memo_relation` | Create a reference relation from a memo you own to a memo you can read | `name`, `related_memo` | — | +| `delete_memo_relation` | Delete a reference relation from a memo you own | `name`, `related_memo` | — | ### Reaction Tools diff --git a/server/router/mcp/access.go b/server/router/mcp/access.go new file mode 100644 index 000000000..0e950b228 --- /dev/null +++ b/server/router/mcp/access.go @@ -0,0 +1,113 @@ +package mcp + +import ( + "context" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +// checkMemoAccess returns an error if the caller cannot read the memo. +// userID == 0 means anonymous. +func checkMemoAccess(memo *store.Memo, userID int32) error { + if memo.RowStatus == store.Archived && memo.CreatorID != userID { + return errors.New("permission denied") + } + + switch memo.Visibility { + case store.Protected: + if userID == 0 { + return errors.New("permission denied") + } + case store.Private: + if memo.CreatorID != userID { + return errors.New("permission denied") + } + default: + // store.Public and any unknown visibility: allow. + } + return nil +} + +func checkMemoOwnership(memo *store.Memo, userID int32) error { + if memo.CreatorID != userID { + return errors.New("permission denied") + } + return nil +} + +// applyVisibilityFilter restricts find to memos the caller may see. +func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) { + if rowStatus != nil && *rowStatus == store.Archived { + if userID == 0 { + impossibleCreatorID := int32(-1) + find.CreatorID = &impossibleCreatorID + return + } + find.CreatorID = &userID + return + } + if userID == 0 { + find.VisibilityList = []store.Visibility{store.Public} + return + } + find.Filters = append(find.Filters, "creator_id == "+itoa32(userID)+` || visibility in ["PUBLIC", "PROTECTED"]`) +} + +func (s *MCPService) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment, userID int32) error { + if attachment.CreatorID == userID { + return nil + } + if attachment.MemoID == nil { + return errors.New("permission denied") + } + + memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID}) + if err != nil { + return errors.Wrap(err, "failed to get linked memo") + } + if memo == nil { + return errors.New("linked memo not found") + } + return checkMemoAccess(memo, userID) +} + +func (s *MCPService) isAllowedOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Scheme == "" || originURL.Host == "" { + return false + } + + if sameOriginHost(originURL.Host, r.Host) { + return true + } + + if s.profile.InstanceURL == "" { + return false + } + + instanceURL, err := url.Parse(s.profile.InstanceURL) + if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" { + return false + } + + return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && sameOriginHost(originURL.Host, instanceURL.Host) +} + +func sameOriginHost(a, b string) bool { + return strings.EqualFold(a, b) +} + +func itoa32(v int32) string { + return strconv.FormatInt(int64(v), 10) +} diff --git a/server/router/mcp/mcp.go b/server/router/mcp/mcp.go index dc499487c..93dcb6c82 100644 --- a/server/router/mcp/mcp.go +++ b/server/router/mcp/mcp.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/labstack/echo/v5" - "github.com/labstack/echo/v5/middleware" mcpserver "github.com/mark3labs/mcp-go/server" "github.com/usememos/memos/internal/profile" @@ -44,11 +43,22 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) mcpGroup := echoServer.Group("") - mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: []string{"*"}, - })) mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { + if !s.isAllowedOrigin(c.Request()) { + return c.JSON(http.StatusForbidden, map[string]string{"message": "invalid origin"}) + } + if origin := c.Request().Header.Get("Origin"); origin != "" { + headers := c.Response().Header() + headers.Set("Vary", "Origin") + headers.Set("Access-Control-Allow-Origin", origin) + headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID") + headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + if c.Request().Method == http.MethodOptions { + return c.NoContent(http.StatusNoContent) + } + } + authHeader := c.Request().Header.Get("Authorization") if authHeader != "" { result := s.authenticator.Authenticate(c.Request().Context(), authHeader) diff --git a/server/router/mcp/mcp_test.go b/server/router/mcp/mcp_test.go new file mode 100644 index 000000000..a4dd1c489 --- /dev/null +++ b/server/router/mcp/mcp_test.go @@ -0,0 +1,275 @@ +package mcp + +import ( + "context" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/lithammer/shortuuid/v4" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/internal/profile" + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" + teststore "github.com/usememos/memos/store/test" +) + +type testMCPService struct { + service *MCPService + store *store.Store +} + +func newTestMCPService(t *testing.T) *testMCPService { + t.Helper() + + ctx := context.Background() + stores := teststore.NewTestingStore(ctx, t) + t.Cleanup(func() { + require.NoError(t, stores.Close()) + }) + + svc := NewMCPService(&profile.Profile{ + Driver: "sqlite", + InstanceURL: "https://notes.example.com", + }, stores, "test-secret") + return &testMCPService{ + service: svc, + store: stores, + } +} + +func (s *testMCPService) createUser(t *testing.T, username string) *store.User { + t.Helper() + + user, err := s.store.CreateUser(context.Background(), &store.User{ + Username: username, + Role: store.RoleUser, + Email: username + "@example.com", + }) + require.NoError(t, err) + return user +} + +func (s *testMCPService) createMemo(t *testing.T, creatorID int32, visibility store.Visibility, content string) *store.Memo { + t.Helper() + + memo, err := s.store.CreateMemo(context.Background(), &store.Memo{ + UID: shortuuid.New(), + CreatorID: creatorID, + RowStatus: store.Normal, + Visibility: visibility, + Content: content, + }) + require.NoError(t, err) + return memo +} + +func (s *testMCPService) archiveMemo(t *testing.T, memoID int32) { + t.Helper() + + rowStatus := store.Archived + require.NoError(t, s.store.UpdateMemo(context.Background(), &store.UpdateMemo{ + ID: memoID, + RowStatus: &rowStatus, + })) +} + +func (s *testMCPService) createAttachment(t *testing.T, creatorID int32, memoID *int32) *store.Attachment { + t.Helper() + + attachment, err := s.store.CreateAttachment(context.Background(), &store.Attachment{ + UID: shortuuid.New(), + CreatorID: creatorID, + Filename: "note.txt", + Type: "text/plain", + Size: 4, + StorageType: storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED, + Reference: "db://attachment/note.txt", + MemoID: memoID, + }) + require.NoError(t, err) + return attachment +} + +func withUser(ctx context.Context, userID int32) context.Context { + return context.WithValue(ctx, auth.UserIDContextKey, userID) +} + +func toolRequest(name string, arguments map[string]any) mcp.CallToolRequest { + return mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: name, + Arguments: arguments, + }, + } +} + +func firstText(t *testing.T, result *mcp.CallToolResult) string { + t.Helper() + require.NotEmpty(t, result.Content) + text, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok) + return text.Text +} + +func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + other := ts.createUser(t, "other") + + memo := ts.createMemo(t, owner.ID, store.Public, "archived") + ts.archiveMemo(t, memo.ID) + + ctx := withUser(context.Background(), other.ID) + result, err := ts.service.handleGetMemo(ctx, toolRequest("get_memo", map[string]any{ + "name": "memos/" + memo.UID, + })) + require.NoError(t, err) + require.True(t, result.IsError) + require.Contains(t, firstText(t, result), "permission denied") + + _, err = ts.service.handleReadMemoResource(ctx, mcp.ReadResourceRequest{ + Params: mcp.ReadResourceParams{ + URI: "memo://memos/" + memo.UID, + }, + }) + require.ErrorContains(t, err, "permission denied") +} + +func TestHandleListMemosArchivedOnlyReturnsCreatorMemos(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + other := ts.createUser(t, "other") + + ownerMemo := ts.createMemo(t, owner.ID, store.Public, "owner archived") + ts.archiveMemo(t, ownerMemo.ID) + otherMemo := ts.createMemo(t, other.ID, store.Public, "other archived") + ts.archiveMemo(t, otherMemo.ID) + + result, err := ts.service.handleListMemos(withUser(context.Background(), owner.ID), toolRequest("list_memos", map[string]any{ + "state": "ARCHIVED", + })) + require.NoError(t, err) + require.False(t, result.IsError) + + var payload struct { + Memos []memoJSON `json:"memos"` + } + require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &payload)) + require.Len(t, payload.Memos, 1) + require.Equal(t, "memos/"+ownerMemo.UID, payload.Memos[0].Name) + + anonResult, err := ts.service.handleListMemos(context.Background(), toolRequest("list_memos", map[string]any{ + "state": "ARCHIVED", + })) + require.NoError(t, err) + require.NoError(t, json.Unmarshal([]byte(firstText(t, anonResult)), &payload)) + require.Empty(t, payload.Memos) +} + +func TestHandleListMemoRelationsFiltersUnreadableTargets(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + privateUser := ts.createUser(t, "private-user") + publicUser := ts.createUser(t, "public-user") + + source := ts.createMemo(t, owner.ID, store.Public, "source") + privateTarget := ts.createMemo(t, privateUser.ID, store.Private, "private") + publicTarget := ts.createMemo(t, publicUser.ID, store.Public, "public") + + _, err := ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{ + MemoID: source.ID, + RelatedMemoID: privateTarget.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + _, err = ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{ + MemoID: source.ID, + RelatedMemoID: publicTarget.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + result, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{ + "name": "memos/" + source.UID, + })) + require.NoError(t, err) + require.False(t, result.IsError) + + var relations []relationJSON + require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &relations)) + require.Len(t, relations, 1) + require.Equal(t, "memos/"+publicTarget.UID, relations[0].RelatedMemo) + + denied, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{ + "name": "memos/" + privateTarget.UID, + })) + require.NoError(t, err) + require.True(t, denied.IsError) + require.Contains(t, firstText(t, denied), "permission denied") +} + +func TestHandleLinkAttachmentToMemoRequiresMemoOwnership(t *testing.T) { + ts := newTestMCPService(t) + attachmentOwner := ts.createUser(t, "attachment-owner") + memoOwner := ts.createUser(t, "memo-owner") + + attachment := ts.createAttachment(t, attachmentOwner.ID, nil) + memo := ts.createMemo(t, memoOwner.ID, store.Public, "target") + + result, err := ts.service.handleLinkAttachmentToMemo(withUser(context.Background(), attachmentOwner.ID), toolRequest("link_attachment_to_memo", map[string]any{ + "name": "attachments/" + attachment.UID, + "memo": "memos/" + memo.UID, + })) + require.NoError(t, err) + require.True(t, result.IsError) + require.Contains(t, firstText(t, result), "permission denied") +} + +func TestHandleGetAttachmentDeniesArchivedLinkedMemoToNonCreator(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + other := ts.createUser(t, "other") + + memo := ts.createMemo(t, owner.ID, store.Public, "memo") + ts.archiveMemo(t, memo.ID) + attachment := ts.createAttachment(t, owner.ID, &memo.ID) + + result, err := ts.service.handleGetAttachment(withUser(context.Background(), other.ID), toolRequest("get_attachment", map[string]any{ + "name": "attachments/" + attachment.UID, + })) + require.NoError(t, err) + require.True(t, result.IsError) + require.Contains(t, firstText(t, result), "permission denied") +} + +func TestIsAllowedOrigin(t *testing.T) { + ts := newTestMCPService(t) + + t.Run("allow missing origin", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil) + require.True(t, ts.service.isAllowedOrigin(req)) + }) + + t.Run("allow same origin as request host", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil) + req.Header.Set("Origin", "http://localhost:5230") + require.True(t, ts.service.isAllowedOrigin(req)) + }) + + t.Run("allow configured instance origin", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://127.0.0.1:5230/mcp", nil) + req.Host = "127.0.0.1:5230" + req.Header.Set("Origin", "https://notes.example.com") + require.True(t, ts.service.isAllowedOrigin(req)) + }) + + t.Run("reject cross origin", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil) + req.Header.Set("Origin", "https://evil.example.com") + require.False(t, ts.service.isAllowedOrigin(req)) + }) +} diff --git a/server/router/mcp/tools_attachment.go b/server/router/mcp/tools_attachment.go index 4bcd098a6..2e2b3f571 100644 --- a/server/router/mcp/tools_attachment.go +++ b/server/router/mcp/tools_attachment.go @@ -216,21 +216,8 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe return mcp.NewToolResultError("attachment not found"), nil } - // Check access: creator can always access; linked memo visibility applies otherwise. - if attachment.CreatorID != userID { - if attachment.MemoID != nil { - memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID}) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get linked memo: %v", err)), nil - } - if memo != nil { - if err := checkMemoAccess(memo, userID); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - } - } else { - return mcp.NewToolResultError("permission denied"), nil - } + if err := s.checkAttachmentAccess(ctx, attachment, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } result, err := storeAttachmentToJSON(ctx, s.store, attachment) @@ -302,6 +289,9 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal if memo == nil { return mcp.NewToolResultError("memo not found"), nil } + if err := checkMemoOwnership(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{ ID: attachment.ID, diff --git a/server/router/mcp/tools_memo.go b/server/router/mcp/tools_memo.go index 2fab7fbe5..2e106c805 100644 --- a/server/router/mcp/tools_memo.go +++ b/server/router/mcp/tools_memo.go @@ -168,33 +168,6 @@ func storeMemoToJSONWithUsernames(m *store.Memo, usernamesByID map[int32]string) return j, nil } -// checkMemoAccess returns an error if the caller cannot read memo. -// userID == 0 means anonymous. -func checkMemoAccess(memo *store.Memo, userID int32) error { - switch memo.Visibility { - case store.Protected: - if userID == 0 { - return errors.New("permission denied") - } - case store.Private: - if memo.CreatorID != userID { - return errors.New("permission denied") - } - default: - // store.Public and any unknown visibility: allow - } - return nil -} - -// applyVisibilityFilter restricts find to memos the caller may see. -func applyVisibilityFilter(find *store.FindMemo, userID int32) { - if userID == 0 { - find.VisibilityList = []store.Visibility{store.Public} - } else { - find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID)) - } -} - // parseMemoUID extracts the UID from a "memos/" resource name. func parseMemoUID(name string) (string, error) { uid, ok := strings.CutPrefix(name, "memos/") @@ -337,7 +310,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques Offset: &offset, OrderByPinned: req.GetBool("order_by_pinned", false), } - applyVisibilityFilter(find, userID) + applyVisibilityFilter(find, userID, rowStatus) if filter := req.GetString("filter", ""); filter != "" { find.Filters = append(find.Filters, filter) } @@ -465,8 +438,8 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque if memo == nil { return mcp.NewToolResultError("memo not found"), nil } - if memo.CreatorID != userID { - return mcp.NewToolResultError("permission denied"), nil + if err := checkMemoOwnership(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } update := &store.UpdateMemo{ID: memo.ID} @@ -533,8 +506,8 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque if memo == nil { return mcp.NewToolResultError("memo not found"), nil } - if memo.CreatorID != userID { - return mcp.NewToolResultError("permission denied"), nil + if err := checkMemoOwnership(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { @@ -561,7 +534,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ Offset: &zero, Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)}, } - applyVisibilityFilter(find, userID) + applyVisibilityFilter(find, userID, find.RowStatus) memos, err := s.store.ListMemos(ctx, find) if err != nil { diff --git a/server/router/mcp/tools_relation.go b/server/router/mcp/tools_relation.go index 773f63eb3..6a7e886ab 100644 --- a/server/router/mcp/tools_relation.go +++ b/server/router/mcp/tools_relation.go @@ -7,6 +7,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) @@ -40,6 +41,8 @@ func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) { } func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := auth.GetUserID(ctx) + uid, err := parseMemoUID(req.GetString("name", "")) if err != nil { return mcp.NewToolResultError(err.Error()), nil @@ -52,6 +55,9 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo if memo == nil { return mcp.NewToolResultError("memo not found"), nil } + if err := checkMemoAccess(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } find := &store.FindMemoRelation{ MemoIDList: []int32{memo.ID}, @@ -85,21 +91,24 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memos: %v", err)), nil } - uidByID := make(map[int32]string, len(memos)) + memoByID := make(map[int32]*store.Memo, len(memos)) for _, m := range memos { - uidByID[m.ID] = m.UID + memoByID[m.ID] = m } results := make([]relationJSON, 0, len(relations)) for _, r := range relations { - memoUID, ok1 := uidByID[r.MemoID] - relatedUID, ok2 := uidByID[r.RelatedMemoID] + srcMemo, ok1 := memoByID[r.MemoID] + relatedMemo, ok2 := memoByID[r.RelatedMemoID] if !ok1 || !ok2 { continue } + if checkMemoAccess(srcMemo, userID) != nil || checkMemoAccess(relatedMemo, userID) != nil { + continue + } results = append(results, relationJSON{ - Memo: "memos/" + memoUID, - RelatedMemo: "memos/" + relatedUID, + Memo: "memos/" + srcMemo.UID, + RelatedMemo: "memos/" + relatedMemo.UID, Type: string(r.Type), }) } @@ -133,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if srcMemo.CreatorID != userID { + if err := checkMemoOwnership(srcMemo, userID); err != nil { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } @@ -144,6 +153,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if dstMemo == nil { return mcp.NewToolResultError("related memo not found"), nil } + if err := checkMemoAccess(dstMemo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ MemoID: srcMemo.ID, @@ -187,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if srcMemo.CreatorID != userID { + if err := checkMemoOwnership(srcMemo, userID); err != nil { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } @@ -198,6 +210,9 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT if dstMemo == nil { return mcp.NewToolResultError("related memo not found"), nil } + if err := checkMemoAccess(dstMemo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } refType := store.MemoRelationReference if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ diff --git a/server/router/mcp/tools_tag.go b/server/router/mcp/tools_tag.go index 55fabd632..ded3c5849 100644 --- a/server/router/mcp/tools_tag.go +++ b/server/router/mcp/tools_tag.go @@ -32,7 +32,7 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) ExcludeContent: true, RowStatus: &rowStatus, } - applyVisibilityFilter(find, userID) + applyVisibilityFilter(find, userID, find.RowStatus) memos, err := s.store.ListMemos(ctx, find) if err != nil { From 0e89407ee91deda87e0df7464a89d26d1e9a88b3 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 31 Mar 2026 08:10:49 +0800 Subject: [PATCH 5/8] fix(filter): enforce CEL syntax semantics Reject non-standard truthy numeric expressions in filters and document the parser as a supported subset of standard CEL syntax. - remove legacy filter rewrites - support standard equality in tag exists predicates - add regression coverage for accepted and rejected expressions --- plugin/filter/README.md | 8 ++-- plugin/filter/engine.go | 73 ----------------------------- plugin/filter/engine_test.go | 39 +++++++++++++++ plugin/filter/ir.go | 7 +++ plugin/filter/parser.go | 48 +++++++++++++++---- plugin/filter/render.go | 18 +++++++ server/router/mcp/README.md | 2 +- server/router/mcp/access.go | 4 ++ server/router/mcp/tools_memo.go | 2 +- server/router/mcp/tools_relation.go | 4 +- store/test/memo_filter_test.go | 25 ++++++++++ 11 files changed, 141 insertions(+), 89 deletions(-) create mode 100644 plugin/filter/engine_test.go diff --git a/plugin/filter/README.md b/plugin/filter/README.md index 35961615f..ac1aec4b6 100644 --- a/plugin/filter/README.md +++ b/plugin/filter/README.md @@ -1,12 +1,14 @@ # Memo Filter Engine -This package houses the memo-only filter engine that turns CEL expressions into -SQL fragments. The engine follows a three phase pipeline inspired by systems +This package houses the memo-only filter engine that turns standard CEL syntax +into SQL fragments for the subset of expressions supported by the memo schema. +The engine follows a three phase pipeline inspired by systems such as Calcite or Prisma: 1. **Parsing** – CEL expressions are parsed with `cel-go` and validated against the memo-specific environment declared in `schema.go`. Only fields that - exist in the schema can surface in the filter. + exist in the schema can surface in the filter, and non-standard legacy + coercions are rejected. 2. **Normalization** – the raw CEL AST is converted into an intermediate representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of conditions (logical operators, comparisons, list membership, etc.). This diff --git a/plugin/filter/engine.go b/plugin/filter/engine.go index c9fcfba7f..9dab7a0ba 100644 --- a/plugin/filter/engine.go +++ b/plugin/filter/engine.go @@ -2,7 +2,6 @@ package filter import ( "context" - "fmt" "strings" "sync" @@ -45,8 +44,6 @@ func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) { return nil, errors.New("filter expression is empty") } - filter = normalizeLegacyFilter(filter) - ast, issues := e.env.Compile(filter) if issues != nil && issues.Err() != nil { return nil, errors.Wrap(issues.Err(), "failed to compile filter") @@ -119,73 +116,3 @@ func DefaultAttachmentEngine() (*Engine, error) { }) return defaultAttachmentInst, defaultAttachmentErr } - -func normalizeLegacyFilter(expr string) string { - expr = rewriteNumericLogicalOperand(expr, "&&") - expr = rewriteNumericLogicalOperand(expr, "||") - return expr -} - -func rewriteNumericLogicalOperand(expr, op string) string { - var builder strings.Builder - n := len(expr) - i := 0 - var inQuote rune - - for i < n { - ch := expr[i] - - if inQuote != 0 { - builder.WriteByte(ch) - if ch == '\\' && i+1 < n { - builder.WriteByte(expr[i+1]) - i += 2 - continue - } - if ch == byte(inQuote) { - inQuote = 0 - } - i++ - continue - } - - if ch == '\'' || ch == '"' { - inQuote = rune(ch) - builder.WriteByte(ch) - i++ - continue - } - - if strings.HasPrefix(expr[i:], op) { - builder.WriteString(op) - i += len(op) - - // Preserve whitespace following the operator. - wsStart := i - for i < n && (expr[i] == ' ' || expr[i] == '\t') { - i++ - } - builder.WriteString(expr[wsStart:i]) - - signStart := i - if i < n && (expr[i] == '+' || expr[i] == '-') { - i++ - } - for i < n && expr[i] >= '0' && expr[i] <= '9' { - i++ - } - if i > signStart { - numLiteral := expr[signStart:i] - fmt.Fprintf(&builder, "(%s != 0)", numLiteral) - } else { - builder.WriteString(expr[signStart:i]) - } - continue - } - - builder.WriteByte(ch) - i++ - } - - return builder.String() -} diff --git a/plugin/filter/engine_test.go b/plugin/filter/engine_test.go new file mode 100644 index 000000000..f9e72c224 --- /dev/null +++ b/plugin/filter/engine_test.go @@ -0,0 +1,39 @@ +package filter + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCompileAcceptsStandardTagEqualityPredicate(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(NewSchema()) + require.NoError(t, err) + + _, err = engine.Compile(context.Background(), `tags.exists(t, t == "1231")`) + require.NoError(t, err) +} + +func TestCompileRejectsLegacyNumericLogicalOperand(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(NewSchema()) + require.NoError(t, err) + + _, err = engine.Compile(context.Background(), `pinned && 1`) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to compile filter") +} + +func TestCompileRejectsNonBooleanTopLevelConstant(t *testing.T) { + t.Parallel() + + engine, err := NewEngine(NewSchema()) + require.NoError(t, err) + + _, err = engine.Compile(context.Background(), `1`) + require.EqualError(t, err, "filter must evaluate to a boolean value") +} diff --git a/plugin/filter/ir.go b/plugin/filter/ir.go index 10cb13df1..b5a995dda 100644 --- a/plugin/filter/ir.go +++ b/plugin/filter/ir.go @@ -157,3 +157,10 @@ type ContainsPredicate struct { } func (*ContainsPredicate) isPredicateExpr() {} + +// EqualsPredicate represents t == "value". +type EqualsPredicate struct { + Value string +} + +func (*EqualsPredicate) isPredicateExpr() {} diff --git a/plugin/filter/parser.go b/plugin/filter/parser.go index 36e52d1db..2aff1074e 100644 --- a/plugin/filter/parser.go +++ b/plugin/filter/parser.go @@ -16,16 +16,10 @@ func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) { if err != nil { return nil, err } - switch v := val.(type) { - case bool: + if v, ok := val.(bool); ok { return &ConstantCondition{Value: v}, nil - case int64: - return &ConstantCondition{Value: v != 0}, nil - case float64: - return &ConstantCondition{Value: v != 0}, nil - default: - return nil, errors.New("filter must evaluate to a boolean value") } + return nil, errors.New("filter must evaluate to a boolean value") case *exprv1.Expr_IdentExpr: name := v.IdentExpr.GetName() field, ok := schema.Field(name) @@ -504,6 +498,8 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, // Handle different predicate functions switch predicateCall.Function { + case "_==_": + return buildEqualsPredicate(predicateCall, comp.IterVar) case "startsWith": return buildStartsWithPredicate(predicateCall, comp.IterVar) case "endsWith": @@ -511,10 +507,44 @@ func extractPredicate(comp *exprv1.Expr_Comprehension, _ Schema) (PredicateExpr, case "contains": return buildContainsPredicate(predicateCall, comp.IterVar) default: - return nil, errors.Errorf("unsupported predicate function %q in comprehension (supported: startsWith, endsWith, contains)", predicateCall.Function) + return nil, errors.Errorf(`unsupported predicate function %q in comprehension (supported: ==, startsWith, endsWith, contains)`, predicateCall.Function) } } +// buildEqualsPredicate extracts the value from t == "value". +func buildEqualsPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { + if len(call.Args) != 2 { + return nil, errors.New("equality predicate expects exactly two arguments") + } + + var constExpr *exprv1.Expr + switch { + case isIterVarExpr(call.Args[0], iterVar): + constExpr = call.Args[1] + case isIterVarExpr(call.Args[1], iterVar): + constExpr = call.Args[0] + default: + return nil, errors.Errorf("equality predicate must compare against the iteration variable %q", iterVar) + } + + value, err := getConstValue(constExpr) + if err != nil { + return nil, errors.Wrap(err, "equality argument must be a constant string") + } + + valueStr, ok := value.(string) + if !ok { + return nil, errors.New("equality argument must be a string") + } + + return &EqualsPredicate{Value: valueStr}, nil +} + +func isIterVarExpr(expr *exprv1.Expr, iterVar string) bool { + target := expr.GetIdentExpr() + return target != nil && target.GetName() == iterVar +} + // buildStartsWithPredicate extracts the pattern from t.startsWith("prefix"). func buildStartsWithPredicate(call *exprv1.Expr_Call, iterVar string) (PredicateExpr, error) { // Verify the target is the iteration variable diff --git a/plugin/filter/render.go b/plugin/filter/render.go index c91096a7b..39eaaec01 100644 --- a/plugin/filter/render.go +++ b/plugin/filter/render.go @@ -480,6 +480,8 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re // Render based on predicate type switch pred := cond.Predicate.(type) { + case *EqualsPredicate: + return r.renderTagEquals(field, pred.Value, cond.Kind) case *StartsWithPredicate: return r.renderTagStartsWith(field, pred.Prefix, cond.Kind) case *EndsWithPredicate: @@ -491,6 +493,22 @@ func (r *renderer) renderListComprehension(cond *ListComprehensionCondition) (re } } +// renderTagEquals generates SQL for tags.exists(t, t == "value"). +func (r *renderer) renderTagEquals(field Field, value string, _ ComprehensionKind) (renderResult, error) { + arrayExpr := jsonArrayExpr(r.dialect, field) + + switch r.dialect { + case DialectSQLite, DialectMySQL: + exactMatch := r.buildJSONArrayLike(arrayExpr, fmt.Sprintf(`%%"%s"%%`, value)) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil + case DialectPostgres: + exactMatch := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", arrayExpr, r.addArg(fmt.Sprintf(`"%s"`, value))) + return renderResult{sql: r.wrapWithNullCheck(arrayExpr, exactMatch)}, nil + default: + return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect) + } +} + // renderTagStartsWith generates SQL for tags.exists(t, t.startsWith("prefix")). func (r *renderer) renderTagStartsWith(field Field, prefix string, _ ComprehensionKind) (renderResult, error) { arrayExpr := jsonArrayExpr(r.dialect, field) diff --git a/server/router/mcp/README.md b/server/router/mcp/README.md index 78feb5732..0af991fc0 100644 --- a/server/router/mcp/README.md +++ b/server/router/mcp/README.md @@ -48,7 +48,7 @@ For Streamable HTTP safety, requests with an `Origin` header must be same-origin | Tool | Description | Required params | Optional params | |---|---|---|---| -| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (CEL) | +| `list_memos` | List memos | — | `page_size`, `page`, `state`, `order_by_pinned`, `filter` (supported subset of standard CEL syntax) | | `get_memo` | Get a single memo | `name` | — | | `search_memos` | Full-text search | `query` | — | | `create_memo` | Create a memo | `content` | `visibility` | diff --git a/server/router/mcp/access.go b/server/router/mcp/access.go index 0e950b228..eadc11f83 100644 --- a/server/router/mcp/access.go +++ b/server/router/mcp/access.go @@ -41,6 +41,10 @@ func checkMemoOwnership(memo *store.Memo, userID int32) error { return nil } +func hasMemoOwnership(memo *store.Memo, userID int32) bool { + return memo.CreatorID == userID +} + // applyVisibilityFilter restricts find to memos the caller may see. func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) { if rowStatus != nil && *rowStatus == store.Archived { diff --git a/server/router/mcp/tools_memo.go b/server/router/mcp/tools_memo.go index 2e106c805..47e8a2298 100644 --- a/server/router/mcp/tools_memo.go +++ b/server/router/mcp/tools_memo.go @@ -223,7 +223,7 @@ func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) { mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"), ), mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")), - mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), + mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), ), s.handleListMemos) mcpSrv.AddTool(mcp.NewTool("get_memo", diff --git a/server/router/mcp/tools_relation.go b/server/router/mcp/tools_relation.go index 6a7e886ab..127bb16fe 100644 --- a/server/router/mcp/tools_relation.go +++ b/server/router/mcp/tools_relation.go @@ -142,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if err := checkMemoOwnership(srcMemo, userID); err != nil { + if !hasMemoOwnership(srcMemo, userID) { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } @@ -199,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if err := checkMemoOwnership(srcMemo, userID); err != nil { + if !hasMemoOwnership(srcMemo, userID) { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } diff --git a/store/test/memo_filter_test.go b/store/test/memo_filter_test.go index aaa25488d..09f49854c 100644 --- a/store/test/memo_filter_test.go +++ b/store/test/memo_filter_test.go @@ -730,6 +730,31 @@ func TestMemoFilterTagsExistsContains(t *testing.T) { require.Len(t, memos, 1, "Should find 1 non-todo memo") } +func TestMemoFilterTagsExistsEquals(t *testing.T) { + t.Parallel() + tc := NewMemoFilterTestContext(t) + defer tc.Close() + + tc.CreateMemo(NewMemoBuilder("memo-1231", tc.User.ID). + Content("Memo with exact numeric tag"). + Tags("1231", "project")) + + tc.CreateMemo(NewMemoBuilder("memo-1231-suffix", tc.User.ID). + Content("Memo with related tag"). + Tags("tag/1231", "other")) + + tc.CreateMemo(NewMemoBuilder("memo-other", tc.User.ID). + Content("Memo with different tag"). + Tags("9999")) + + memos := tc.ListWithFilter(`tags.exists(t, t == "1231")`) + require.Len(t, memos, 1, "Should find only the memo with exact matching tag") + require.Equal(t, "memo-1231", memos[0].UID) + + memos = tc.ListWithFilter(`!tags.exists(t, t == "1231")`) + require.Len(t, memos, 2, "Should exclude only the memo with exact matching tag") +} + func TestMemoFilterTagsExistsEndsWith(t *testing.T) { t.Parallel() tc := NewMemoFilterTestContext(t) From 201c8a8ea96f506b52ce187528a8b8dc691f827b Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 31 Mar 2026 08:34:36 +0800 Subject: [PATCH 6/8] chore: add rc release handling --- .github/workflows/release.yml | 40 ++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7e946a7b9..298266b1d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -24,6 +24,8 @@ jobs: outputs: version: ${{ steps.version.outputs.version }} tag: ${{ steps.version.outputs.tag }} + major_minor: ${{ steps.version.outputs.major_minor }} + is_prerelease: ${{ steps.version.outputs.is_prerelease }} steps: - name: Extract version id: version @@ -34,11 +36,27 @@ jobs: if [ "$EVENT_NAME" = "workflow_dispatch" ]; then echo "tag=" >> "$GITHUB_OUTPUT" echo "version=manual-${GITHUB_SHA::7}" >> "$GITHUB_OUTPUT" + echo "major_minor=" >> "$GITHUB_OUTPUT" + echo "is_prerelease=false" >> "$GITHUB_OUTPUT" exit 0 fi + if [[ ! "$REF_NAME" =~ ^v([0-9]+\.[0-9]+\.[0-9]+)(-rc\.[0-9]+)?$ ]]; then + echo "Unsupported release tag format: $REF_NAME" >&2 + exit 1 + fi + + version="${BASH_REMATCH[1]}${BASH_REMATCH[2]}" + major_minor="${BASH_REMATCH[1]%.*}" + is_prerelease=false + if [ -n "${BASH_REMATCH[2]}" ]; then + is_prerelease=true + fi + echo "tag=${REF_NAME}" >> "$GITHUB_OUTPUT" - echo "version=${REF_NAME#v}" >> "$GITHUB_OUTPUT" + echo "version=${version}" >> "$GITHUB_OUTPUT" + echo "major_minor=${major_minor}" >> "$GITHUB_OUTPUT" + echo "is_prerelease=${is_prerelease}" >> "$GITHUB_OUTPUT" build-frontend: name: Build Frontend @@ -226,6 +244,7 @@ jobs: tag_name: ${{ needs.prepare.outputs.tag }} name: ${{ needs.prepare.outputs.tag }} generate_release_notes: true + prerelease: ${{ needs.prepare.outputs.is_prerelease == 'true' }} files: artifacts/* build-push: @@ -301,7 +320,7 @@ jobs: retention-days: 1 merge-images: - name: Publish Stable Image Tags + name: Publish Release Image Tags needs: [prepare, build-push] if: github.event_name != 'workflow_dispatch' runs-on: ubuntu-latest @@ -336,17 +355,28 @@ jobs: working-directory: /tmp/digests run: | version="${{ needs.prepare.outputs.version }}" - major_minor=$(echo "$version" | cut -d. -f1,2) + if [ "${{ needs.prepare.outputs.is_prerelease }}" = "true" ]; then + docker buildx imagetools create \ + -t "neosmemo/memos:${version}" \ + -t "ghcr.io/usememos/memos:${version}" \ + $(printf 'neosmemo/memos@sha256:%s ' *) + exit 0 + fi + docker buildx imagetools create \ -t "neosmemo/memos:${version}" \ - -t "neosmemo/memos:${major_minor}" \ + -t "neosmemo/memos:${{ needs.prepare.outputs.major_minor }}" \ -t "neosmemo/memos:stable" \ -t "ghcr.io/usememos/memos:${version}" \ - -t "ghcr.io/usememos/memos:${major_minor}" \ + -t "ghcr.io/usememos/memos:${{ needs.prepare.outputs.major_minor }}" \ -t "ghcr.io/usememos/memos:stable" \ $(printf 'neosmemo/memos@sha256:%s ' *) - name: Inspect images run: | docker buildx imagetools inspect neosmemo/memos:${{ needs.prepare.outputs.version }} + if [ "${{ needs.prepare.outputs.is_prerelease }}" = "true" ]; then + exit 0 + fi + docker buildx imagetools inspect neosmemo/memos:stable From 1921b57662c2129d179930f71e5c42caf9070a19 Mon Sep 17 00:00:00 2001 From: memoclaw Date: Tue, 31 Mar 2026 21:38:55 +0800 Subject: [PATCH 7/8] fix(tags): allow blur-only tag metadata (#5800) Co-authored-by: memoclaw <265580040+memoclaw@users.noreply.github.com> --- proto/api/v1/instance_service.proto | 3 +- proto/gen/api/v1/instance_service.pb.go | 3 +- proto/gen/openapi.yaml | 4 +- proto/gen/store/instance_setting.pb.go | 3 +- proto/store/instance_setting.proto | 3 +- server/router/api/v1/instance_service.go | 9 ++- .../api/v1/test/instance_service_test.go | 28 ++++++++ store/test/instance_setting_test.go | 28 ++++++++ web/src/components/Settings/TagsSection.tsx | 64 ++++++++++--------- web/src/locales/en.json | 5 +- .../types/proto/api/v1/instance_service_pb.ts | 3 +- 11 files changed, 111 insertions(+), 42 deletions(-) diff --git a/proto/api/v1/instance_service.proto b/proto/api/v1/instance_service.proto index 6dc8f9fb2..1f500b3e3 100644 --- a/proto/api/v1/instance_service.proto +++ b/proto/api/v1/instance_service.proto @@ -167,7 +167,8 @@ message InstanceSetting { // Metadata for a tag. message TagMetadata { - // Background color for the tag label. + // Optional background color for the tag label. + // When unset, the default tag color is used. google.type.Color background_color = 1; // Whether memos with this tag should have their content blurred. bool blur_content = 2; diff --git a/proto/gen/api/v1/instance_service.pb.go b/proto/gen/api/v1/instance_service.pb.go index 37d48328c..be77651a0 100644 --- a/proto/gen/api/v1/instance_service.pb.go +++ b/proto/gen/api/v1/instance_service.pb.go @@ -759,7 +759,8 @@ func (x *InstanceSetting_MemoRelatedSetting) GetReactions() []string { // Metadata for a tag. type InstanceSetting_TagMetadata struct { state protoimpl.MessageState `protogen:"open.v1"` - // Background color for the tag label. + // Optional background color for the tag label. + // When unset, the default tag color is used. BackgroundColor *color.Color `protobuf:"bytes,1,opt,name=background_color,json=backgroundColor,proto3" json:"background_color,omitempty"` // Whether memos with this tag should have their content blurred. BlurContent bool `protobuf:"varint,2,opt,name=blur_content,json=blurContent,proto3" json:"blur_content,omitempty"` diff --git a/proto/gen/openapi.yaml b/proto/gen/openapi.yaml index 6a6901410..a315fcf95 100644 --- a/proto/gen/openapi.yaml +++ b/proto/gen/openapi.yaml @@ -2396,7 +2396,9 @@ components: backgroundColor: allOf: - $ref: '#/components/schemas/Color' - description: Background color for the tag label. + description: |- + Optional background color for the tag label. + When unset, the default tag color is used. blurContent: type: boolean description: Whether memos with this tag should have their content blurred. diff --git a/proto/gen/store/instance_setting.pb.go b/proto/gen/store/instance_setting.pb.go index d51be76ae..b12889b40 100644 --- a/proto/gen/store/instance_setting.pb.go +++ b/proto/gen/store/instance_setting.pb.go @@ -754,7 +754,8 @@ func (x *InstanceMemoRelatedSetting) GetReactions() []string { type InstanceTagMetadata struct { state protoimpl.MessageState `protogen:"open.v1"` - // Background color for the tag label. + // Optional background color for the tag label. + // When unset, the default tag color is used. BackgroundColor *color.Color `protobuf:"bytes,1,opt,name=background_color,json=backgroundColor,proto3" json:"background_color,omitempty"` // Whether memos with this tag should have their content blurred. BlurContent bool `protobuf:"varint,2,opt,name=blur_content,json=blurContent,proto3" json:"blur_content,omitempty"` diff --git a/proto/store/instance_setting.proto b/proto/store/instance_setting.proto index a6701bd7a..2c7848d3c 100644 --- a/proto/store/instance_setting.proto +++ b/proto/store/instance_setting.proto @@ -111,7 +111,8 @@ message InstanceMemoRelatedSetting { } message InstanceTagMetadata { - // Background color for the tag label. + // Optional background color for the tag label. + // When unset, the default tag color is used. google.type.Color background_color = 1; // Whether memos with this tag should have their content blurred. bool blur_content = 2; diff --git a/server/router/api/v1/instance_service.go b/server/router/api/v1/instance_service.go index 4be6986b7..f6e6d519c 100644 --- a/server/router/api/v1/instance_service.go +++ b/server/router/api/v1/instance_service.go @@ -423,11 +423,10 @@ func validateInstanceTagsSetting(setting *v1pb.InstanceSetting_TagsSetting) erro if metadata == nil { return errors.Errorf("tag metadata is required for %q", tag) } - if metadata.GetBackgroundColor() == nil { - return errors.Errorf("background_color is required for %q", tag) - } - if err := validateInstanceColor(metadata.GetBackgroundColor()); err != nil { - return errors.Wrapf(err, "background_color for %q", tag) + if metadata.GetBackgroundColor() != nil { + if err := validateInstanceColor(metadata.GetBackgroundColor()); err != nil { + return errors.Wrapf(err, "background_color for %q", tag) + } } } return nil diff --git a/server/router/api/v1/test/instance_service_test.go b/server/router/api/v1/test/instance_service_test.go index bd7e3f288..eac97b58e 100644 --- a/server/router/api/v1/test/instance_service_test.go +++ b/server/router/api/v1/test/instance_service_test.go @@ -318,6 +318,34 @@ func TestUpdateInstanceSetting(t *testing.T) { require.Contains(t, err.Error(), "invalid instance setting") }) + t.Run("UpdateInstanceSetting - tags setting without color", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + resp, err := ts.Service.UpdateInstanceSetting(ts.CreateUserContext(ctx, hostUser.ID), &v1pb.UpdateInstanceSettingRequest{ + Setting: &v1pb.InstanceSetting{ + Name: "instance/settings/TAGS", + Value: &v1pb.InstanceSetting_TagsSetting_{ + TagsSetting: &v1pb.InstanceSetting_TagsSetting{ + Tags: map[string]*v1pb.InstanceSetting_TagMetadata{ + "spoiler": { + BlurContent: true, + }, + }, + }, + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, resp.GetTagsSetting()) + require.Contains(t, resp.GetTagsSetting().GetTags(), "spoiler") + require.Nil(t, resp.GetTagsSetting().GetTags()["spoiler"].GetBackgroundColor()) + require.True(t, resp.GetTagsSetting().GetTags()["spoiler"].GetBlurContent()) + }) + t.Run("UpdateInstanceSetting - notification setting password is write-only", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() diff --git a/store/test/instance_setting_test.go b/store/test/instance_setting_test.go index 1451234ec..bf63b1fe3 100644 --- a/store/test/instance_setting_test.go +++ b/store/test/instance_setting_test.go @@ -257,6 +257,34 @@ func TestInstanceSettingTagsSetting(t *testing.T) { ts.Close() } +func TestInstanceSettingTagsSettingWithoutColor(t *testing.T) { + t.Parallel() + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + _, err := ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_TAGS, + Value: &storepb.InstanceSetting_TagsSetting{ + TagsSetting: &storepb.InstanceTagsSetting{ + Tags: map[string]*storepb.InstanceTagMetadata{ + "spoiler": { + BlurContent: true, + }, + }, + }, + }, + }) + require.NoError(t, err) + + tagsSetting, err := ts.GetInstanceTagsSetting(ctx) + require.NoError(t, err) + require.Contains(t, tagsSetting.Tags, "spoiler") + require.Nil(t, tagsSetting.Tags["spoiler"].GetBackgroundColor()) + require.True(t, tagsSetting.Tags["spoiler"].GetBlurContent()) + + ts.Close() +} + func TestInstanceSettingNotificationSetting(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/web/src/components/Settings/TagsSection.tsx b/web/src/components/Settings/TagsSection.tsx index 8c6ef8b68..ab368fedf 100644 --- a/web/src/components/Settings/TagsSection.tsx +++ b/web/src/components/Settings/TagsSection.tsx @@ -22,8 +22,7 @@ import SettingGroup from "./SettingGroup"; import SettingSection from "./SettingSection"; import SettingTable from "./SettingTable"; -// Fallback to white when no color is stored. -const tagColorToHex = (color?: { red?: number; green?: number; blue?: number }): string => colorToHex(color) ?? "#ffffff"; +const DEFAULT_TAG_COLOR = "#ffffff"; // Converts a CSS hex string to a google.type.Color message. const hexToColor = (hex: string) => @@ -34,10 +33,18 @@ const hexToColor = (hex: string) => }); interface LocalTagMeta { - color: string; + color?: string; blur: boolean; } +const toLocalTagMeta = (meta: { + backgroundColor?: { red?: number; green?: number; blue?: number }; + blurContent: boolean; +}): LocalTagMeta => ({ + color: colorToHex(meta.backgroundColor), + blur: meta.blurContent, +}); + const TagsSection = () => { const t = useTranslate(); const { tagsSetting: originalSetting, updateSetting, fetchSetting } = useInstance(); @@ -45,28 +52,16 @@ const TagsSection = () => { // Local state: map of tagName → { color, blur } for editing. const [localTags, setLocalTags] = useState>(() => - Object.fromEntries( - Object.entries(originalSetting.tags).map(([name, meta]) => [ - name, - { color: tagColorToHex(meta.backgroundColor), blur: meta.blurContent }, - ]), - ), + Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])), ); const [newTagName, setNewTagName] = useState(""); - const [newTagColor, setNewTagColor] = useState("#ffffff"); + const [newTagColor, setNewTagColor] = useState(undefined); const [newTagBlur, setNewTagBlur] = useState(false); // Sync local state when the fetched setting arrives (the fetch is async and // completes after mount, so localTags would be empty without this sync). useEffect(() => { - setLocalTags( - Object.fromEntries( - Object.entries(originalSetting.tags).map(([name, meta]) => [ - name, - { color: tagColorToHex(meta.backgroundColor), blur: meta.blurContent }, - ]), - ), - ); + setLocalTags(Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)]))); }, [originalSetting.tags]); // All known tag names: union of saved entries and tags used in memos. @@ -85,13 +80,7 @@ const TagsSection = () => { ); const originalMetaMap = useMemo( - () => - Object.fromEntries( - Object.entries(originalSetting.tags).map(([name, meta]) => [ - name, - { color: tagColorToHex(meta.backgroundColor), blur: meta.blurContent }, - ]), - ), + () => Object.fromEntries(Object.entries(originalSetting.tags).map(([name, meta]) => [name, toLocalTagMeta(meta)])), [originalSetting.tags], ); const hasChanges = !isEqual(localTags, originalMetaMap); @@ -104,6 +93,10 @@ const TagsSection = () => { setLocalTags((prev) => ({ ...prev, [tagName]: { ...prev[tagName], blur } })); }; + const handleClearColor = (tagName: string) => { + setLocalTags((prev) => ({ ...prev, [tagName]: { ...prev[tagName], color: undefined } })); + }; + const handleRemoveTag = (tagName: string) => { setLocalTags((prev) => { const next = { ...prev }; @@ -125,7 +118,7 @@ const TagsSection = () => { } setLocalTags((prev) => ({ ...prev, [name]: { color: newTagColor, blur: newTagBlur } })); setNewTagName(""); - setNewTagColor("#ffffff"); + setNewTagColor(undefined); setNewTagBlur(false); }; @@ -134,7 +127,10 @@ const TagsSection = () => { const tags = Object.fromEntries( Object.entries(localTags).map(([name, meta]) => [ name, - create(InstanceSetting_TagMetadataSchema, { backgroundColor: hexToColor(meta.color), blurContent: meta.blur }), + create(InstanceSetting_TagMetadataSchema, { + blurContent: meta.blur, + ...(meta.color ? { backgroundColor: hexToColor(meta.color) } : {}), + }), ]), ); await updateSetting( @@ -171,9 +167,15 @@ const TagsSection = () => { handleColorChange(row.name, e.target.value)} /> + + {!localTags[row.name].color && ( + {t("setting.tags.using-default-color")} + )} ), }, @@ -224,9 +226,12 @@ const TagsSection = () => { setNewTagColor(e.target.value)} /> +