From d3f6e8ee31e7c1942ac99f71d5ca30c3ffa410b1 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 31 Mar 2026 00:12:28 +0800 Subject: [PATCH] chore: harden MCP access control and origin validation --- server/router/mcp/README.md | 20 +- server/router/mcp/access.go | 113 +++++++++++ server/router/mcp/mcp.go | 18 +- server/router/mcp/mcp_test.go | 275 ++++++++++++++++++++++++++ server/router/mcp/tools_attachment.go | 20 +- server/router/mcp/tools_memo.go | 39 +--- server/router/mcp/tools_relation.go | 31 ++- server/router/mcp/tools_tag.go | 2 +- 8 files changed, 452 insertions(+), 66 deletions(-) create mode 100644 server/router/mcp/access.go create mode 100644 server/router/mcp/mcp_test.go diff --git a/server/router/mcp/README.md b/server/router/mcp/README.md index 86cb16b1e..78feb5732 100644 --- a/server/router/mcp/README.md +++ b/server/router/mcp/README.md @@ -7,6 +7,7 @@ This package implements a [Model Context Protocol (MCP)](https://modelcontextpro ``` POST /mcp (tool calls, initialize) GET /mcp (optional SSE stream for server-to-client messages) +DELETE /mcp (optional session termination) ``` Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26). @@ -24,13 +25,22 @@ The server advertises the following MCP capabilities: ## Authentication -Every request must include a Personal Access Token (PAT): +Public reads can be used without authentication. Personal Access Tokens (PATs) or short-lived JWT session tokens are required for: + +- Reading non-public memos or attachments +- Any tool that mutates data + +When authenticating, send a Bearer token: ``` Authorization: Bearer ``` -PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests without a valid token receive `HTTP 401`. +PATs are long-lived tokens created in Settings → My Account → Access Tokens. Short-lived JWT session tokens are also accepted. Requests with an invalid token receive `HTTP 401`. + +## Origin Validation + +For Streamable HTTP safety, requests with an `Origin` header must be same-origin with the current request host or match the configured `instance-url`. Requests without an `Origin` header, such as desktop MCP clients and CLI tools, are allowed. ## Tools @@ -60,15 +70,15 @@ PATs are long-lived tokens created in Settings → My Account → Access Tokens. | `list_attachments` | List user's attachments | — | `page_size`, `page`, `memo` | | `get_attachment` | Get attachment metadata | `name` | — | | `delete_attachment` | Delete an attachment | `name` | — | -| `link_attachment_to_memo` | Link attachment to memo | `name`, `memo` | — | +| `link_attachment_to_memo` | Link attachment to a memo you own | `name`, `memo` | — | ### Relation Tools | Tool | Description | Required params | Optional params | |---|---|---|---| | `list_memo_relations` | List relations (refs + comments) | `name` | `type` | -| `create_memo_relation` | Create a reference relation | `name`, `related_memo` | — | -| `delete_memo_relation` | Delete a reference relation | `name`, `related_memo` | — | +| `create_memo_relation` | Create a reference relation from a memo you own to a memo you can read | `name`, `related_memo` | — | +| `delete_memo_relation` | Delete a reference relation from a memo you own | `name`, `related_memo` | — | ### Reaction Tools diff --git a/server/router/mcp/access.go b/server/router/mcp/access.go new file mode 100644 index 000000000..0e950b228 --- /dev/null +++ b/server/router/mcp/access.go @@ -0,0 +1,113 @@ +package mcp + +import ( + "context" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +// checkMemoAccess returns an error if the caller cannot read the memo. +// userID == 0 means anonymous. +func checkMemoAccess(memo *store.Memo, userID int32) error { + if memo.RowStatus == store.Archived && memo.CreatorID != userID { + return errors.New("permission denied") + } + + switch memo.Visibility { + case store.Protected: + if userID == 0 { + return errors.New("permission denied") + } + case store.Private: + if memo.CreatorID != userID { + return errors.New("permission denied") + } + default: + // store.Public and any unknown visibility: allow. + } + return nil +} + +func checkMemoOwnership(memo *store.Memo, userID int32) error { + if memo.CreatorID != userID { + return errors.New("permission denied") + } + return nil +} + +// applyVisibilityFilter restricts find to memos the caller may see. +func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) { + if rowStatus != nil && *rowStatus == store.Archived { + if userID == 0 { + impossibleCreatorID := int32(-1) + find.CreatorID = &impossibleCreatorID + return + } + find.CreatorID = &userID + return + } + if userID == 0 { + find.VisibilityList = []store.Visibility{store.Public} + return + } + find.Filters = append(find.Filters, "creator_id == "+itoa32(userID)+` || visibility in ["PUBLIC", "PROTECTED"]`) +} + +func (s *MCPService) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment, userID int32) error { + if attachment.CreatorID == userID { + return nil + } + if attachment.MemoID == nil { + return errors.New("permission denied") + } + + memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID}) + if err != nil { + return errors.Wrap(err, "failed to get linked memo") + } + if memo == nil { + return errors.New("linked memo not found") + } + return checkMemoAccess(memo, userID) +} + +func (s *MCPService) isAllowedOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + + originURL, err := url.Parse(origin) + if err != nil || originURL.Scheme == "" || originURL.Host == "" { + return false + } + + if sameOriginHost(originURL.Host, r.Host) { + return true + } + + if s.profile.InstanceURL == "" { + return false + } + + instanceURL, err := url.Parse(s.profile.InstanceURL) + if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" { + return false + } + + return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && sameOriginHost(originURL.Host, instanceURL.Host) +} + +func sameOriginHost(a, b string) bool { + return strings.EqualFold(a, b) +} + +func itoa32(v int32) string { + return strconv.FormatInt(int64(v), 10) +} diff --git a/server/router/mcp/mcp.go b/server/router/mcp/mcp.go index dc499487c..93dcb6c82 100644 --- a/server/router/mcp/mcp.go +++ b/server/router/mcp/mcp.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/labstack/echo/v5" - "github.com/labstack/echo/v5/middleware" mcpserver "github.com/mark3labs/mcp-go/server" "github.com/usememos/memos/internal/profile" @@ -44,11 +43,22 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) mcpGroup := echoServer.Group("") - mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: []string{"*"}, - })) mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { + if !s.isAllowedOrigin(c.Request()) { + return c.JSON(http.StatusForbidden, map[string]string{"message": "invalid origin"}) + } + if origin := c.Request().Header.Get("Origin"); origin != "" { + headers := c.Response().Header() + headers.Set("Vary", "Origin") + headers.Set("Access-Control-Allow-Origin", origin) + headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID") + headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + if c.Request().Method == http.MethodOptions { + return c.NoContent(http.StatusNoContent) + } + } + authHeader := c.Request().Header.Get("Authorization") if authHeader != "" { result := s.authenticator.Authenticate(c.Request().Context(), authHeader) diff --git a/server/router/mcp/mcp_test.go b/server/router/mcp/mcp_test.go new file mode 100644 index 000000000..a4dd1c489 --- /dev/null +++ b/server/router/mcp/mcp_test.go @@ -0,0 +1,275 @@ +package mcp + +import ( + "context" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/lithammer/shortuuid/v4" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/internal/profile" + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" + teststore "github.com/usememos/memos/store/test" +) + +type testMCPService struct { + service *MCPService + store *store.Store +} + +func newTestMCPService(t *testing.T) *testMCPService { + t.Helper() + + ctx := context.Background() + stores := teststore.NewTestingStore(ctx, t) + t.Cleanup(func() { + require.NoError(t, stores.Close()) + }) + + svc := NewMCPService(&profile.Profile{ + Driver: "sqlite", + InstanceURL: "https://notes.example.com", + }, stores, "test-secret") + return &testMCPService{ + service: svc, + store: stores, + } +} + +func (s *testMCPService) createUser(t *testing.T, username string) *store.User { + t.Helper() + + user, err := s.store.CreateUser(context.Background(), &store.User{ + Username: username, + Role: store.RoleUser, + Email: username + "@example.com", + }) + require.NoError(t, err) + return user +} + +func (s *testMCPService) createMemo(t *testing.T, creatorID int32, visibility store.Visibility, content string) *store.Memo { + t.Helper() + + memo, err := s.store.CreateMemo(context.Background(), &store.Memo{ + UID: shortuuid.New(), + CreatorID: creatorID, + RowStatus: store.Normal, + Visibility: visibility, + Content: content, + }) + require.NoError(t, err) + return memo +} + +func (s *testMCPService) archiveMemo(t *testing.T, memoID int32) { + t.Helper() + + rowStatus := store.Archived + require.NoError(t, s.store.UpdateMemo(context.Background(), &store.UpdateMemo{ + ID: memoID, + RowStatus: &rowStatus, + })) +} + +func (s *testMCPService) createAttachment(t *testing.T, creatorID int32, memoID *int32) *store.Attachment { + t.Helper() + + attachment, err := s.store.CreateAttachment(context.Background(), &store.Attachment{ + UID: shortuuid.New(), + CreatorID: creatorID, + Filename: "note.txt", + Type: "text/plain", + Size: 4, + StorageType: storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED, + Reference: "db://attachment/note.txt", + MemoID: memoID, + }) + require.NoError(t, err) + return attachment +} + +func withUser(ctx context.Context, userID int32) context.Context { + return context.WithValue(ctx, auth.UserIDContextKey, userID) +} + +func toolRequest(name string, arguments map[string]any) mcp.CallToolRequest { + return mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: name, + Arguments: arguments, + }, + } +} + +func firstText(t *testing.T, result *mcp.CallToolResult) string { + t.Helper() + require.NotEmpty(t, result.Content) + text, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok) + return text.Text +} + +func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + other := ts.createUser(t, "other") + + memo := ts.createMemo(t, owner.ID, store.Public, "archived") + ts.archiveMemo(t, memo.ID) + + ctx := withUser(context.Background(), other.ID) + result, err := ts.service.handleGetMemo(ctx, toolRequest("get_memo", map[string]any{ + "name": "memos/" + memo.UID, + })) + require.NoError(t, err) + require.True(t, result.IsError) + require.Contains(t, firstText(t, result), "permission denied") + + _, err = ts.service.handleReadMemoResource(ctx, mcp.ReadResourceRequest{ + Params: mcp.ReadResourceParams{ + URI: "memo://memos/" + memo.UID, + }, + }) + require.ErrorContains(t, err, "permission denied") +} + +func TestHandleListMemosArchivedOnlyReturnsCreatorMemos(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + other := ts.createUser(t, "other") + + ownerMemo := ts.createMemo(t, owner.ID, store.Public, "owner archived") + ts.archiveMemo(t, ownerMemo.ID) + otherMemo := ts.createMemo(t, other.ID, store.Public, "other archived") + ts.archiveMemo(t, otherMemo.ID) + + result, err := ts.service.handleListMemos(withUser(context.Background(), owner.ID), toolRequest("list_memos", map[string]any{ + "state": "ARCHIVED", + })) + require.NoError(t, err) + require.False(t, result.IsError) + + var payload struct { + Memos []memoJSON `json:"memos"` + } + require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &payload)) + require.Len(t, payload.Memos, 1) + require.Equal(t, "memos/"+ownerMemo.UID, payload.Memos[0].Name) + + anonResult, err := ts.service.handleListMemos(context.Background(), toolRequest("list_memos", map[string]any{ + "state": "ARCHIVED", + })) + require.NoError(t, err) + require.NoError(t, json.Unmarshal([]byte(firstText(t, anonResult)), &payload)) + require.Empty(t, payload.Memos) +} + +func TestHandleListMemoRelationsFiltersUnreadableTargets(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + privateUser := ts.createUser(t, "private-user") + publicUser := ts.createUser(t, "public-user") + + source := ts.createMemo(t, owner.ID, store.Public, "source") + privateTarget := ts.createMemo(t, privateUser.ID, store.Private, "private") + publicTarget := ts.createMemo(t, publicUser.ID, store.Public, "public") + + _, err := ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{ + MemoID: source.ID, + RelatedMemoID: privateTarget.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + _, err = ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{ + MemoID: source.ID, + RelatedMemoID: publicTarget.ID, + Type: store.MemoRelationReference, + }) + require.NoError(t, err) + + result, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{ + "name": "memos/" + source.UID, + })) + require.NoError(t, err) + require.False(t, result.IsError) + + var relations []relationJSON + require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &relations)) + require.Len(t, relations, 1) + require.Equal(t, "memos/"+publicTarget.UID, relations[0].RelatedMemo) + + denied, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{ + "name": "memos/" + privateTarget.UID, + })) + require.NoError(t, err) + require.True(t, denied.IsError) + require.Contains(t, firstText(t, denied), "permission denied") +} + +func TestHandleLinkAttachmentToMemoRequiresMemoOwnership(t *testing.T) { + ts := newTestMCPService(t) + attachmentOwner := ts.createUser(t, "attachment-owner") + memoOwner := ts.createUser(t, "memo-owner") + + attachment := ts.createAttachment(t, attachmentOwner.ID, nil) + memo := ts.createMemo(t, memoOwner.ID, store.Public, "target") + + result, err := ts.service.handleLinkAttachmentToMemo(withUser(context.Background(), attachmentOwner.ID), toolRequest("link_attachment_to_memo", map[string]any{ + "name": "attachments/" + attachment.UID, + "memo": "memos/" + memo.UID, + })) + require.NoError(t, err) + require.True(t, result.IsError) + require.Contains(t, firstText(t, result), "permission denied") +} + +func TestHandleGetAttachmentDeniesArchivedLinkedMemoToNonCreator(t *testing.T) { + ts := newTestMCPService(t) + owner := ts.createUser(t, "owner") + other := ts.createUser(t, "other") + + memo := ts.createMemo(t, owner.ID, store.Public, "memo") + ts.archiveMemo(t, memo.ID) + attachment := ts.createAttachment(t, owner.ID, &memo.ID) + + result, err := ts.service.handleGetAttachment(withUser(context.Background(), other.ID), toolRequest("get_attachment", map[string]any{ + "name": "attachments/" + attachment.UID, + })) + require.NoError(t, err) + require.True(t, result.IsError) + require.Contains(t, firstText(t, result), "permission denied") +} + +func TestIsAllowedOrigin(t *testing.T) { + ts := newTestMCPService(t) + + t.Run("allow missing origin", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil) + require.True(t, ts.service.isAllowedOrigin(req)) + }) + + t.Run("allow same origin as request host", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil) + req.Header.Set("Origin", "http://localhost:5230") + require.True(t, ts.service.isAllowedOrigin(req)) + }) + + t.Run("allow configured instance origin", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://127.0.0.1:5230/mcp", nil) + req.Host = "127.0.0.1:5230" + req.Header.Set("Origin", "https://notes.example.com") + require.True(t, ts.service.isAllowedOrigin(req)) + }) + + t.Run("reject cross origin", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil) + req.Header.Set("Origin", "https://evil.example.com") + require.False(t, ts.service.isAllowedOrigin(req)) + }) +} diff --git a/server/router/mcp/tools_attachment.go b/server/router/mcp/tools_attachment.go index 4bcd098a6..2e2b3f571 100644 --- a/server/router/mcp/tools_attachment.go +++ b/server/router/mcp/tools_attachment.go @@ -216,21 +216,8 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe return mcp.NewToolResultError("attachment not found"), nil } - // Check access: creator can always access; linked memo visibility applies otherwise. - if attachment.CreatorID != userID { - if attachment.MemoID != nil { - memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID}) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get linked memo: %v", err)), nil - } - if memo != nil { - if err := checkMemoAccess(memo, userID); err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - } - } else { - return mcp.NewToolResultError("permission denied"), nil - } + if err := s.checkAttachmentAccess(ctx, attachment, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } result, err := storeAttachmentToJSON(ctx, s.store, attachment) @@ -302,6 +289,9 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal if memo == nil { return mcp.NewToolResultError("memo not found"), nil } + if err := checkMemoOwnership(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{ ID: attachment.ID, diff --git a/server/router/mcp/tools_memo.go b/server/router/mcp/tools_memo.go index 2fab7fbe5..2e106c805 100644 --- a/server/router/mcp/tools_memo.go +++ b/server/router/mcp/tools_memo.go @@ -168,33 +168,6 @@ func storeMemoToJSONWithUsernames(m *store.Memo, usernamesByID map[int32]string) return j, nil } -// checkMemoAccess returns an error if the caller cannot read memo. -// userID == 0 means anonymous. -func checkMemoAccess(memo *store.Memo, userID int32) error { - switch memo.Visibility { - case store.Protected: - if userID == 0 { - return errors.New("permission denied") - } - case store.Private: - if memo.CreatorID != userID { - return errors.New("permission denied") - } - default: - // store.Public and any unknown visibility: allow - } - return nil -} - -// applyVisibilityFilter restricts find to memos the caller may see. -func applyVisibilityFilter(find *store.FindMemo, userID int32) { - if userID == 0 { - find.VisibilityList = []store.Visibility{store.Public} - } else { - find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID)) - } -} - // parseMemoUID extracts the UID from a "memos/" resource name. func parseMemoUID(name string) (string, error) { uid, ok := strings.CutPrefix(name, "memos/") @@ -337,7 +310,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques Offset: &offset, OrderByPinned: req.GetBool("order_by_pinned", false), } - applyVisibilityFilter(find, userID) + applyVisibilityFilter(find, userID, rowStatus) if filter := req.GetString("filter", ""); filter != "" { find.Filters = append(find.Filters, filter) } @@ -465,8 +438,8 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque if memo == nil { return mcp.NewToolResultError("memo not found"), nil } - if memo.CreatorID != userID { - return mcp.NewToolResultError("permission denied"), nil + if err := checkMemoOwnership(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } update := &store.UpdateMemo{ID: memo.ID} @@ -533,8 +506,8 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque if memo == nil { return mcp.NewToolResultError("memo not found"), nil } - if memo.CreatorID != userID { - return mcp.NewToolResultError("permission denied"), nil + if err := checkMemoOwnership(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { @@ -561,7 +534,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ Offset: &zero, Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)}, } - applyVisibilityFilter(find, userID) + applyVisibilityFilter(find, userID, find.RowStatus) memos, err := s.store.ListMemos(ctx, find) if err != nil { diff --git a/server/router/mcp/tools_relation.go b/server/router/mcp/tools_relation.go index 773f63eb3..6a7e886ab 100644 --- a/server/router/mcp/tools_relation.go +++ b/server/router/mcp/tools_relation.go @@ -7,6 +7,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) @@ -40,6 +41,8 @@ func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) { } func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := auth.GetUserID(ctx) + uid, err := parseMemoUID(req.GetString("name", "")) if err != nil { return mcp.NewToolResultError(err.Error()), nil @@ -52,6 +55,9 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo if memo == nil { return mcp.NewToolResultError("memo not found"), nil } + if err := checkMemoAccess(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } find := &store.FindMemoRelation{ MemoIDList: []int32{memo.ID}, @@ -85,21 +91,24 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memos: %v", err)), nil } - uidByID := make(map[int32]string, len(memos)) + memoByID := make(map[int32]*store.Memo, len(memos)) for _, m := range memos { - uidByID[m.ID] = m.UID + memoByID[m.ID] = m } results := make([]relationJSON, 0, len(relations)) for _, r := range relations { - memoUID, ok1 := uidByID[r.MemoID] - relatedUID, ok2 := uidByID[r.RelatedMemoID] + srcMemo, ok1 := memoByID[r.MemoID] + relatedMemo, ok2 := memoByID[r.RelatedMemoID] if !ok1 || !ok2 { continue } + if checkMemoAccess(srcMemo, userID) != nil || checkMemoAccess(relatedMemo, userID) != nil { + continue + } results = append(results, relationJSON{ - Memo: "memos/" + memoUID, - RelatedMemo: "memos/" + relatedUID, + Memo: "memos/" + srcMemo.UID, + RelatedMemo: "memos/" + relatedMemo.UID, Type: string(r.Type), }) } @@ -133,7 +142,7 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if srcMemo.CreatorID != userID { + if err := checkMemoOwnership(srcMemo, userID); err != nil { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } @@ -144,6 +153,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT if dstMemo == nil { return mcp.NewToolResultError("related memo not found"), nil } + if err := checkMemoAccess(dstMemo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ MemoID: srcMemo.ID, @@ -187,7 +199,7 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT if srcMemo == nil { return mcp.NewToolResultError("source memo not found"), nil } - if srcMemo.CreatorID != userID { + if err := checkMemoOwnership(srcMemo, userID); err != nil { return mcp.NewToolResultError("permission denied: must own the source memo"), nil } @@ -198,6 +210,9 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT if dstMemo == nil { return mcp.NewToolResultError("related memo not found"), nil } + if err := checkMemoAccess(dstMemo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } refType := store.MemoRelationReference if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ diff --git a/server/router/mcp/tools_tag.go b/server/router/mcp/tools_tag.go index 55fabd632..ded3c5849 100644 --- a/server/router/mcp/tools_tag.go +++ b/server/router/mcp/tools_tag.go @@ -32,7 +32,7 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) ExcludeContent: true, RowStatus: &rowStatus, } - applyVisibilityFilter(find, userID) + applyVisibilityFilter(find, userID, find.RowStatus) memos, err := s.store.ListMemos(ctx, find) if err != nil {