From 803d488a5f8f55477cd3ad4cc4cf0fac98901dd3 Mon Sep 17 00:00:00 2001 From: Johnny Date: Sun, 1 Mar 2026 23:10:23 +0800 Subject: [PATCH] feat(mcp): refactor MCP server to standard protocol structure - Replace PAT-only auth with optional auth supporting both PAT and JWT via auth.Authenticator.Authenticate(); unauthenticated requests see only public memos, matching REST API visibility semantics - Inline auth middleware into mcp.go following fileserver pattern; remove auth_middleware.go - Introduce memoJSON response type that correctly serialises store.Memo (including Payload.Tags and Payload.Property) without proto marshalling - Add tools: list_memo_comments, create_memo_comment, list_tags - Extend list_memos with state (NORMAL/ARCHIVED), order_by_pinned, and page parameters - Extend update_memo with pinned and state parameters - Extract #tags from content on create/update via regex to pre-populate Payload.Tags without requiring a full markdown service rebuild - Add MCP Resources: memo://memos/{uid} template returns memo as Markdown with YAML frontmatter, allowing clients to read memos by URI - Add MCP Prompts: capture (save a thought) and review (search + summarise) --- server/router/mcp/auth_middleware.go | 31 -- server/router/mcp/mcp.go | 34 +- server/router/mcp/prompts.go | 84 +++++ server/router/mcp/resources_memo.go | 85 +++++ server/router/mcp/tools_memo.go | 503 +++++++++++++++++++++------ server/router/mcp/tools_tag.go | 68 ++++ 6 files changed, 658 insertions(+), 147 deletions(-) delete mode 100644 server/router/mcp/auth_middleware.go create mode 100644 server/router/mcp/prompts.go create mode 100644 server/router/mcp/resources_memo.go create mode 100644 server/router/mcp/tools_tag.go diff --git a/server/router/mcp/auth_middleware.go b/server/router/mcp/auth_middleware.go deleted file mode 100644 index 02e9a2f2f..000000000 --- a/server/router/mcp/auth_middleware.go +++ /dev/null @@ -1,31 +0,0 @@ -package mcp - -import ( - "net/http" - - "github.com/labstack/echo/v5" - - "github.com/usememos/memos/server/auth" - "github.com/usememos/memos/store" -) - -func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc { - authenticator := auth.NewAuthenticator(s, secret) - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { - token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization")) - if token == "" { - return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"}) - } - - user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token) - if err != nil || user == nil { - return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"}) - } - - ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId()) - c.SetRequest(c.Request().WithContext(ctx)) - return next(c) - } - } -} diff --git a/server/router/mcp/mcp.go b/server/router/mcp/mcp.go index f6c42f940..f7fc77218 100644 --- a/server/router/mcp/mcp.go +++ b/server/router/mcp/mcp.go @@ -1,25 +1,36 @@ package mcp import ( + "net/http" + "github.com/labstack/echo/v5" "github.com/labstack/echo/v5/middleware" mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) type MCPService struct { - store *store.Store - secret string + store *store.Store + authenticator *auth.Authenticator } func NewMCPService(store *store.Store, secret string) *MCPService { - return &MCPService{store: store, secret: secret} + return &MCPService{ + store: store, + authenticator: auth.NewAuthenticator(store, secret), + } } func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { - mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0", mcpserver.WithToolCapabilities(false)) + mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0", + mcpserver.WithToolCapabilities(false), + ) s.registerMemoTools(mcpSrv) + s.registerTagTools(mcpSrv) + s.registerMemoResources(mcpSrv) + s.registerPrompts(mcpSrv) httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) @@ -27,6 +38,19 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{"*"}, })) - mcpGroup.Use(newAuthMiddleware(s.store, s.secret)) + mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) error { + authHeader := c.Request().Header.Get("Authorization") + if authHeader != "" { + result := s.authenticator.Authenticate(c.Request().Context(), authHeader) + if result == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired token"}) + } + ctx := auth.ApplyToContext(c.Request().Context(), result) + c.SetRequest(c.Request().WithContext(ctx)) + } + return next(c) + } + }) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) } diff --git a/server/router/mcp/prompts.go b/server/router/mcp/prompts.go new file mode 100644 index 000000000..2e05ccc91 --- /dev/null +++ b/server/router/mcp/prompts.go @@ -0,0 +1,84 @@ +package mcp + +import ( + "context" + "errors" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" +) + +func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) { + // capture — turns free-form user input into a structured create_memo call. + mcpSrv.AddPrompt( + mcp.NewPrompt("capture", + mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+ + "Use this prompt when the user wants to quickly save something. "+ + "The assistant will call create_memo with the provided content."), + mcp.WithArgument("content", + mcp.ArgumentDescription("The text to save as a memo"), + mcp.RequiredArgument(), + ), + mcp.WithArgument("tags", + mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""), + ), + ), + s.handleCapturePrompt, + ) + + // review — surfaces existing memos on a topic for summarisation. + mcpSrv.AddPrompt( + mcp.NewPrompt("review", + mcp.WithPromptDescription("Search and review memos on a given topic. "+ + "The assistant will call search_memos and summarise the results."), + mcp.WithArgument("topic", + mcp.ArgumentDescription("Topic or keyword to search for"), + mcp.RequiredArgument(), + ), + ), + s.handleReviewPrompt, + ) +} + +func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + content := req.Params.Arguments["content"] + if content == "" { + return nil, errors.New("content argument is required") + } + + tags := req.Params.Arguments["tags"] + instruction := fmt.Sprintf( + "Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s", + content, + ) + if tags != "" { + instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags) + } + + return &mcp.GetPromptResult{ + Description: "Capture a memo", + Messages: []mcp.PromptMessage{ + mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)), + }, + }, nil +} + +func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + topic := req.Params.Arguments["topic"] + if topic == "" { + return nil, errors.New("topic argument is required") + } + + instruction := fmt.Sprintf( + "Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.", + topic, + ) + + return &mcp.GetPromptResult{ + Description: fmt.Sprintf("Review memos about %q", topic), + Messages: []mcp.PromptMessage{ + mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)), + }, + }, nil +} diff --git a/server/router/mcp/resources_memo.go b/server/router/mcp/resources_memo.go new file mode 100644 index 000000000..b7a56ab3d --- /dev/null +++ b/server/router/mcp/resources_memo.go @@ -0,0 +1,85 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/pkg/errors" + + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" +) + +// Memo resource URI scheme: memo://memos/{uid} +// Clients can read any memo they have access to by URI without calling a tool. + +func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) { + mcpSrv.AddResourceTemplate( + mcp.NewResourceTemplate( + "memo://memos/{uid}", + "Memo", + mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."), + mcp.WithTemplateMIMEType("text/markdown"), + ), + s.handleReadMemoResource, + ) +} + +func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + userID := auth.GetUserID(ctx) + + // URI format: memo://memos/{uid} + uid := strings.TrimPrefix(req.Params.URI, "memo://memos/") + if uid == req.Params.URI || uid == "" { + return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/", req.Params.URI) + } + + memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) + if err != nil { + return nil, errors.Wrap(err, "failed to get memo") + } + if memo == nil { + return nil, errors.Errorf("memo not found: %s", uid) + } + if err := checkMemoAccess(memo, userID); err != nil { + return nil, err + } + + j := storeMemoToJSON(memo) + text := formatMemoMarkdown(j) + + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "text/markdown", + Text: text, + }, + }, nil +} + +// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header. +func formatMemoMarkdown(j memoJSON) string { + var sb strings.Builder + + sb.WriteString("---\n") + fmt.Fprintf(&sb, "name: %s\n", j.Name) + fmt.Fprintf(&sb, "creator: %s\n", j.Creator) + fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility) + fmt.Fprintf(&sb, "state: %s\n", j.State) + fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned) + if len(j.Tags) > 0 { + fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", ")) + } + fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime) + fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime) + if j.Parent != "" { + fmt.Fprintf(&sb, "parent: %s\n", j.Parent) + } + sb.WriteString("---\n\n") + sb.WriteString(j.Content) + + return sb.String() +} diff --git a/server/router/mcp/tools_memo.go b/server/router/mcp/tools_memo.go index 556c4f9a1..9179c92c2 100644 --- a/server/router/mcp/tools_memo.go +++ b/server/router/mcp/tools_memo.go @@ -3,22 +3,166 @@ package mcp import ( "context" "encoding/json" - "errors" "fmt" + "regexp" "strings" "github.com/lithammer/shortuuid/v4" "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/pkg/errors" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) +// tagRegexp matches #tag patterns in memo content. +// A tag must start with a letter and contain no whitespace or # characters. +var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`) + +// extractTags does a best-effort extraction of #tags from raw markdown content. +// It is used when creating or updating memos via MCP to pre-populate Payload.Tags. +// The full markdown service may later rebuild a more accurate payload. +func extractTags(content string) []string { + matches := tagRegexp.FindAllStringSubmatch(content, -1) + seen := make(map[string]struct{}, len(matches)) + tags := make([]string, 0, len(matches)) + for _, m := range matches { + tag := m[1] + if _, ok := seen[tag]; !ok { + seen[tag] = struct{}{} + tags = append(tags, tag) + } + } + return tags +} + +// buildPayload constructs a MemoPayload with tags extracted from content. +// Returns nil when no tags are found so the store omits the payload entirely. +func buildPayload(content string) *storepb.MemoPayload { + tags := extractTags(content) + if len(tags) == 0 { + return nil + } + return &storepb.MemoPayload{Tags: tags} +} + +// propertyJSON is the serialisable form of MemoPayload.Property. +type propertyJSON struct { + HasLink bool `json:"has_link"` + HasTaskList bool `json:"has_task_list"` + HasCode bool `json:"has_code"` + HasIncompleteTasks bool `json:"has_incomplete_tasks"` +} + +// memoJSON is the canonical response shape for all MCP memo results. +// It serialises correctly with standard encoding/json (no proto marshalling needed). +type memoJSON struct { + Name string `json:"name"` + Creator string `json:"creator"` + CreateTime int64 `json:"create_time"` + UpdateTime int64 `json:"update_time"` + Content string `json:"content,omitempty"` + Visibility string `json:"visibility"` + Tags []string `json:"tags"` + Pinned bool `json:"pinned"` + State string `json:"state"` + Property *propertyJSON `json:"property,omitempty"` + Parent string `json:"parent,omitempty"` +} + +func storeMemoToJSON(m *store.Memo) memoJSON { + j := memoJSON{ + Name: "memos/" + m.UID, + Creator: fmt.Sprintf("users/%d", m.CreatorID), + CreateTime: m.CreatedTs, + UpdateTime: m.UpdatedTs, + Content: m.Content, + Visibility: string(m.Visibility), + Pinned: m.Pinned, + State: string(m.RowStatus), + Tags: []string{}, + } + if m.Payload != nil { + if len(m.Payload.Tags) > 0 { + j.Tags = m.Payload.Tags + } + if p := m.Payload.Property; p != nil && (p.HasLink || p.HasTaskList || p.HasCode || p.HasIncompleteTasks) { + j.Property = &propertyJSON{ + HasLink: p.HasLink, + HasTaskList: p.HasTaskList, + HasCode: p.HasCode, + HasIncompleteTasks: p.HasIncompleteTasks, + } + } + } + if m.ParentUID != nil { + j.Parent = "memos/" + *m.ParentUID + } + return j +} + +// checkMemoAccess returns an error if the caller cannot read memo. +// userID == 0 means anonymous. +func checkMemoAccess(memo *store.Memo, userID int32) error { + switch memo.Visibility { + case store.Protected: + if userID == 0 { + return errors.New("permission denied") + } + case store.Private: + if memo.CreatorID != userID { + return errors.New("permission denied") + } + default: + // store.Public and any unknown visibility: allow + } + return nil +} + +// applyVisibilityFilter restricts find to memos the caller may see. +func applyVisibilityFilter(find *store.FindMemo, userID int32) { + if userID == 0 { + find.VisibilityList = []store.Visibility{store.Public} + } else { + find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID)) + } +} + +// parseMemoUID extracts the UID from a "memos/" resource name. +func parseMemoUID(name string) (string, error) { + uid, ok := strings.CutPrefix(name, "memos/") + if !ok || uid == "" { + return "", errors.Errorf(`memo name must be in the format "memos/", got %q`, name) + } + return uid, nil +} + +// parseVisibility validates a visibility string and returns the store constant. +func parseVisibility(s string) (store.Visibility, error) { + switch v := store.Visibility(s); v { + case store.Public, store.Protected, store.Private: + return v, nil + default: + return "", errors.Errorf("visibility must be PRIVATE, PROTECTED, or PUBLIC; got %q", s) + } +} + +// parseRowStatus validates a state string and returns the store constant. +func parseRowStatus(s string) (store.RowStatus, error) { + switch rs := store.RowStatus(s); rs { + case store.Normal, store.Archived: + return rs, nil + default: + return "", errors.Errorf("state must be NORMAL or ARCHIVED; got %q", s) + } +} + func extractUserID(ctx context.Context) (int32, error) { id := auth.GetUserID(ctx) if id == 0 { - return 0, errors.New("unauthenticated") + return 0, errors.New("unauthenticated: a personal access token is required") } return id, nil } @@ -32,58 +176,71 @@ func marshalJSON(v any) (string, error) { } func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) { - listTool := mcp.NewTool("list_memos", - mcp.WithDescription("List the authenticated user's memos"), - mcp.WithNumber("page_size", mcp.Description("Max memos to return, default 20")), - mcp.WithString("filter", mcp.Description(`CEL filter expression, e.g. content.contains("keyword")`)), - ) - mcpSrv.AddTool(listTool, s.handleListMemos) + mcpSrv.AddTool(mcp.NewTool("list_memos", + mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."), + mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")), + mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")), + mcp.WithString("state", + mcp.Enum("NORMAL", "ARCHIVED"), + mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"), + ), + mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")), + mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), + ), s.handleListMemos) - getTool := mcp.NewTool("get_memo", - mcp.WithDescription("Get a single memo by resource name"), + mcpSrv.AddTool(mcp.NewTool("get_memo", + mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), - ) - mcpSrv.AddTool(getTool, s.handleGetMemo) + ), s.handleGetMemo) - createTool := mcp.NewTool("create_memo", - mcp.WithDescription("Create a new memo"), - mcp.WithString("content", mcp.Required(), mcp.Description("Memo content")), + mcpSrv.AddTool(mcp.NewTool("create_memo", + mcp.WithDescription("Create a new memo. Requires authentication."), + mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")), mcp.WithString("visibility", mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), - mcp.Description("Visibility: PRIVATE (default), PROTECTED, or PUBLIC"), + mcp.Description("Visibility (default: PRIVATE)"), ), - ) - mcpSrv.AddTool(createTool, s.handleCreateMemo) + ), s.handleCreateMemo) - updateTool := mcp.NewTool("update_memo", - mcp.WithDescription("Update a memo's content or visibility"), + mcpSrv.AddTool(mcp.NewTool("update_memo", + mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), - mcp.WithString("content", mcp.Description("New content (omit to leave unchanged)")), + mcp.WithString("content", mcp.Description("New Markdown content")), mcp.WithString("visibility", mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), - mcp.Description("New visibility (omit to leave unchanged)"), + mcp.Description("New visibility"), ), - ) - mcpSrv.AddTool(updateTool, s.handleUpdateMemo) + mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")), + mcp.WithString("state", + mcp.Enum("NORMAL", "ARCHIVED"), + mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"), + ), + ), s.handleUpdateMemo) - deleteTool := mcp.NewTool("delete_memo", - mcp.WithDescription("Delete a memo"), + mcpSrv.AddTool(mcp.NewTool("delete_memo", + mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), - ) - mcpSrv.AddTool(deleteTool, s.handleDeleteMemo) + ), s.handleDeleteMemo) - searchTool := mcp.NewTool("search_memos", - mcp.WithDescription("Search memo content using a text query"), - mcp.WithString("query", mcp.Required(), mcp.Description("Text to search in memo content")), - ) - mcpSrv.AddTool(searchTool, s.handleSearchMemos) + mcpSrv.AddTool(mcp.NewTool("search_memos", + mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."), + mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")), + ), s.handleSearchMemos) + + mcpSrv.AddTool(mcp.NewTool("list_memo_comments", + mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."), + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), + ), s.handleListMemoComments) + + mcpSrv.AddTool(mcp.NewTool("create_memo_comment", + mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."), + mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)), + mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")), + ), s.handleCreateMemoComment) } func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + userID := auth.GetUserID(ctx) pageSize := req.GetInt("page_size", 20) if pageSize <= 0 { @@ -92,31 +249,54 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques if pageSize > 100 { pageSize = 100 } - filterExpr := req.GetString("filter", "") - - rowStatus := store.Normal - limitPlusOne := pageSize + 1 - zero := 0 - find := &store.FindMemo{ - CreatorID: &userID, - ExcludeComments: true, - RowStatus: &rowStatus, - Limit: &limitPlusOne, - Offset: &zero, + page := req.GetInt("page", 0) + if page < 0 { + page = 0 } - if filterExpr != "" { - find.Filters = append(find.Filters, filterExpr) + + var rowStatus *store.RowStatus + if state := req.GetString("state", "NORMAL"); state != "" { + rs, err := parseRowStatus(state) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + rowStatus = &rs + } + + limit := pageSize + 1 + offset := page * pageSize + find := &store.FindMemo{ + ExcludeComments: true, + RowStatus: rowStatus, + Limit: &limit, + Offset: &offset, + OrderByPinned: req.GetBool("order_by_pinned", false), + } + applyVisibilityFilter(find, userID) + if filter := req.GetString("filter", ""); filter != "" { + find.Filters = append(find.Filters, filter) } memos, err := s.store.ListMemos(ctx, find) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil } - if len(memos) == limitPlusOne { + + hasMore := len(memos) > pageSize + if hasMore { memos = memos[:pageSize] } - out, err := marshalJSON(memos) + results := make([]memoJSON, len(memos)) + for i, m := range memos { + results[i] = storeMemoToJSON(m) + } + + type listResponse struct { + Memos []memoJSON `json:"memos"` + HasMore bool `json:"has_more"` + } + out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore}) if err != nil { return nil, err } @@ -124,20 +304,13 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques } func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) + userID := auth.GetUserID(ctx) + + uid, err := parseMemoUID(req.GetString("name", "")) if err != nil { return mcp.NewToolResultError(err.Error()), nil } - name := req.GetString("name", "") - if name == "" { - return mcp.NewToolResultError("name is required"), nil - } - uid, found := strings.CutPrefix(name, "memos/") - if !found || uid == "" { - return mcp.NewToolResultError(`name must be in the format "memos/"`), nil - } - memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil @@ -145,11 +318,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) if memo == nil { return mcp.NewToolResultError("memo not found"), nil } - if memo.Visibility == store.Private && memo.CreatorID != userID { - return mcp.NewToolResultError("permission denied"), nil + if err := checkMemoAccess(memo, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil } - out, err := marshalJSON(memo) + out, err := marshalJSON(storeMemoToJSON(memo)) if err != nil { return nil, err } @@ -166,26 +339,23 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque if content == "" { return mcp.NewToolResultError("content is required"), nil } - - visibility := req.GetString("visibility", "PRIVATE") - switch visibility { - case "PRIVATE", "PROTECTED", "PUBLIC": - default: - return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil + visibility, err := parseVisibility(req.GetString("visibility", "PRIVATE")) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - create := &store.Memo{ + memo, err := s.store.CreateMemo(ctx, &store.Memo{ UID: shortuuid.New(), CreatorID: userID, Content: content, - Visibility: store.Visibility(visibility), - } - memo, err := s.store.CreateMemo(ctx, create) + Visibility: visibility, + Payload: buildPayload(content), + }) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil } - out, err := marshalJSON(memo) + out, err := marshalJSON(storeMemoToJSON(memo)) if err != nil { return nil, err } @@ -198,13 +368,9 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque return mcp.NewToolResultError(err.Error()), nil } - name := req.GetString("name", "") - if name == "" { - return mcp.NewToolResultError("name is required"), nil - } - uid, found := strings.CutPrefix(name, "memos/") - if !found || uid == "" { - return mcp.NewToolResultError(`name must be in the format "memos/"`), nil + uid, err := parseMemoUID(req.GetString("name", "")) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) @@ -219,17 +385,29 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque } update := &store.UpdateMemo{ID: memo.ID} - if content := req.GetString("content", ""); content != "" { - update.Content = &content + args := req.GetArguments() + + if v := req.GetString("content", ""); v != "" { + update.Content = &v + update.Payload = buildPayload(v) } - if vis := req.GetString("visibility", ""); vis != "" { - switch vis { - case "PRIVATE", "PROTECTED", "PUBLIC": - default: - return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil + if v := req.GetString("visibility", ""); v != "" { + vis, err := parseVisibility(v) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - v := store.Visibility(vis) - update.Visibility = &v + update.Visibility = &vis + } + if v := req.GetString("state", ""); v != "" { + rs, err := parseRowStatus(v) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + update.RowStatus = &rs + } + if _, ok := args["pinned"]; ok { + pinned := req.GetBool("pinned", false) + update.Pinned = &pinned } if err := s.store.UpdateMemo(ctx, update); err != nil { @@ -241,7 +419,7 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil } - out, err := marshalJSON(updated) + out, err := marshalJSON(storeMemoToJSON(updated)) if err != nil { return nil, err } @@ -254,13 +432,9 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque return mcp.NewToolResultError(err.Error()), nil } - name := req.GetString("name", "") - if name == "" { - return mcp.NewToolResultError("name is required"), nil - } - uid, found := strings.CutPrefix(name, "memos/") - if !found || uid == "" { - return mcp.NewToolResultError(`name must be in the format "memos/"`), nil + uid, err := parseMemoUID(req.GetString("name", "")) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) @@ -277,40 +451,147 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil } - return mcp.NewToolResultText("memo deleted"), nil + return mcp.NewToolResultText(`{"deleted":true}`), nil } func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID, err := extractUserID(ctx) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + userID := auth.GetUserID(ctx) query := req.GetString("query", "") if query == "" { return mcp.NewToolResultError("query is required"), nil } - rowStatus := store.Normal limit := 50 zero := 0 + rowStatus := store.Normal find := &store.FindMemo{ ExcludeComments: true, RowStatus: &rowStatus, Limit: &limit, Offset: &zero, - Filters: []string{ - fmt.Sprintf("creator_id == %d", userID), - fmt.Sprintf(`content.contains(%q)`, query), - }, + Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)}, } + applyVisibilityFilter(find, userID) memos, err := s.store.ListMemos(ctx, find) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil } - out, err := marshalJSON(memos) + results := make([]memoJSON, len(memos)) + for i, m := range memos { + results[i] = storeMemoToJSON(m) + } + out, err := marshalJSON(results) + if err != nil { + return nil, err + } + return mcp.NewToolResultText(out), nil +} + +func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := auth.GetUserID(ctx) + + uid, err := parseMemoUID(req.GetString("name", "")) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil + } + if parent == nil { + return mcp.NewToolResultError("memo not found"), nil + } + if err := checkMemoAccess(parent, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + relationType := store.MemoRelationComment + relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{ + RelatedMemoID: &parent.ID, + Type: &relationType, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil + } + if len(relations) == 0 { + out, _ := marshalJSON([]memoJSON{}) + return mcp.NewToolResultText(out), nil + } + + commentIDs := make([]int32, len(relations)) + for i, r := range relations { + commentIDs[i] = r.MemoID + } + + memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: commentIDs}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list comments: %v", err)), nil + } + + results := make([]memoJSON, 0, len(memos)) + for _, m := range memos { + if checkMemoAccess(m, userID) == nil { + results = append(results, storeMemoToJSON(m)) + } + } + out, err := marshalJSON(results) + if err != nil { + return nil, err + } + return mcp.NewToolResultText(out), nil +} + +func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID, err := extractUserID(ctx) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + uid, err := parseMemoUID(req.GetString("name", "")) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + content := req.GetString("content", "") + if content == "" { + return mcp.NewToolResultError("content is required"), nil + } + + parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil + } + if parent == nil { + return mcp.NewToolResultError("memo not found"), nil + } + if err := checkMemoAccess(parent, userID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + comment, err := s.store.CreateMemo(ctx, &store.Memo{ + UID: shortuuid.New(), + CreatorID: userID, + Content: content, + Visibility: parent.Visibility, + Payload: buildPayload(content), + ParentUID: &parent.UID, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil + } + + if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: comment.ID, + RelatedMemoID: parent.ID, + Type: store.MemoRelationComment, + }); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil + } + + out, err := marshalJSON(storeMemoToJSON(comment)) if err != nil { return nil, err } diff --git a/server/router/mcp/tools_tag.go b/server/router/mcp/tools_tag.go new file mode 100644 index 000000000..de00326d8 --- /dev/null +++ b/server/router/mcp/tools_tag.go @@ -0,0 +1,68 @@ +package mcp + +import ( + "context" + "fmt" + "sort" + + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" +) + +func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) { + mcpSrv.AddTool(mcp.NewTool("list_tags", + mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."), + ), s.handleListTags) +} + +type tagEntry struct { + Tag string `json:"tag"` + Count int `json:"count"` +} + +func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := auth.GetUserID(ctx) + + rowStatus := store.Normal + find := &store.FindMemo{ + ExcludeComments: true, + ExcludeContent: true, + RowStatus: &rowStatus, + } + applyVisibilityFilter(find, userID) + + memos, err := s.store.ListMemos(ctx, find) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil + } + + counts := make(map[string]int) + for _, m := range memos { + if m.Payload == nil { + continue + } + for _, tag := range m.Payload.Tags { + counts[tag]++ + } + } + + entries := make([]tagEntry, 0, len(counts)) + for tag, count := range counts { + entries = append(entries, tagEntry{Tag: tag, Count: count}) + } + sort.Slice(entries, func(i, j int) bool { + if entries[i].Count != entries[j].Count { + return entries[i].Count > entries[j].Count + } + return entries[i].Tag < entries[j].Tag + }) + + out, err := marshalJSON(entries) + if err != nil { + return nil, err + } + return mcp.NewToolResultText(out), nil +}