feat(mcp): refactor MCP server to standard protocol structure

- Replace PAT-only auth with optional auth supporting both PAT and JWT
  via auth.Authenticator.Authenticate(); unauthenticated requests see
  only public memos, matching REST API visibility semantics
- Inline auth middleware into mcp.go following fileserver pattern;
  remove auth_middleware.go
- Introduce memoJSON response type that correctly serialises store.Memo
  (including Payload.Tags and Payload.Property) without proto marshalling
- Add tools: list_memo_comments, create_memo_comment, list_tags
- Extend list_memos with state (NORMAL/ARCHIVED), order_by_pinned, and
  page parameters
- Extend update_memo with pinned and state parameters
- Extract #tags from content on create/update via regex to pre-populate
  Payload.Tags without requiring a full markdown service rebuild
- Add MCP Resources: memo://memos/{uid} template returns memo as
  Markdown with YAML frontmatter, allowing clients to read memos by URI
- Add MCP Prompts: capture (save a thought) and review (search + summarise)
This commit is contained in:
Johnny 2026-03-01 23:10:23 +08:00
parent 16576be111
commit 803d488a5f
6 changed files with 658 additions and 147 deletions

View File

@ -1,31 +0,0 @@
package mcp
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc {
authenticator := auth.NewAuthenticator(s, secret)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization"))
if token == "" {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"})
}
user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token)
if err != nil || user == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
}
ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId())
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
}
}

View File

@ -1,25 +1,36 @@
package mcp
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
type MCPService struct {
store *store.Store
secret string
store *store.Store
authenticator *auth.Authenticator
}
func NewMCPService(store *store.Store, secret string) *MCPService {
return &MCPService{store: store, secret: secret}
return &MCPService{
store: store,
authenticator: auth.NewAuthenticator(store, secret),
}
}
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0", mcpserver.WithToolCapabilities(false))
mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0",
mcpserver.WithToolCapabilities(false),
)
s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv)
s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
@ -27,6 +38,19 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"},
}))
mcpGroup.Use(newAuthMiddleware(s.store, s.secret))
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" {
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
if result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired token"})
}
ctx := auth.ApplyToContext(c.Request().Context(), result)
c.SetRequest(c.Request().WithContext(ctx))
}
return next(c)
}
})
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
}

View File

@ -0,0 +1,84 @@
package mcp
import (
"context"
"errors"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
)
func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) {
// capture — turns free-form user input into a structured create_memo call.
mcpSrv.AddPrompt(
mcp.NewPrompt("capture",
mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+
"Use this prompt when the user wants to quickly save something. "+
"The assistant will call create_memo with the provided content."),
mcp.WithArgument("content",
mcp.ArgumentDescription("The text to save as a memo"),
mcp.RequiredArgument(),
),
mcp.WithArgument("tags",
mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""),
),
),
s.handleCapturePrompt,
)
// review — surfaces existing memos on a topic for summarisation.
mcpSrv.AddPrompt(
mcp.NewPrompt("review",
mcp.WithPromptDescription("Search and review memos on a given topic. "+
"The assistant will call search_memos and summarise the results."),
mcp.WithArgument("topic",
mcp.ArgumentDescription("Topic or keyword to search for"),
mcp.RequiredArgument(),
),
),
s.handleReviewPrompt,
)
}
func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
content := req.Params.Arguments["content"]
if content == "" {
return nil, errors.New("content argument is required")
}
tags := req.Params.Arguments["tags"]
instruction := fmt.Sprintf(
"Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s",
content,
)
if tags != "" {
instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags)
}
return &mcp.GetPromptResult{
Description: "Capture a memo",
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}
func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
topic := req.Params.Arguments["topic"]
if topic == "" {
return nil, errors.New("topic argument is required")
}
instruction := fmt.Sprintf(
"Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.",
topic,
)
return &mcp.GetPromptResult{
Description: fmt.Sprintf("Review memos about %q", topic),
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}

View File

@ -0,0 +1,85 @@
package mcp
import (
"context"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// Memo resource URI scheme: memo://memos/{uid}
// Clients can read any memo they have access to by URI without calling a tool.
func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddResourceTemplate(
mcp.NewResourceTemplate(
"memo://memos/{uid}",
"Memo",
mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."),
mcp.WithTemplateMIMEType("text/markdown"),
),
s.handleReadMemoResource,
)
}
func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
userID := auth.GetUserID(ctx)
// URI format: memo://memos/{uid}
uid := strings.TrimPrefix(req.Params.URI, "memo://memos/")
if uid == req.Params.URI || uid == "" {
return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/<uid>", req.Params.URI)
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return nil, errors.Errorf("memo not found: %s", uid)
}
if err := checkMemoAccess(memo, userID); err != nil {
return nil, err
}
j := storeMemoToJSON(memo)
text := formatMemoMarkdown(j)
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "text/markdown",
Text: text,
},
}, nil
}
// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header.
func formatMemoMarkdown(j memoJSON) string {
var sb strings.Builder
sb.WriteString("---\n")
fmt.Fprintf(&sb, "name: %s\n", j.Name)
fmt.Fprintf(&sb, "creator: %s\n", j.Creator)
fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility)
fmt.Fprintf(&sb, "state: %s\n", j.State)
fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned)
if len(j.Tags) > 0 {
fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", "))
}
fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime)
fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime)
if j.Parent != "" {
fmt.Fprintf(&sb, "parent: %s\n", j.Parent)
}
sb.WriteString("---\n\n")
sb.WriteString(j.Content)
return sb.String()
}

View File

@ -3,22 +3,166 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
"github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// tagRegexp matches #tag patterns in memo content.
// A tag must start with a letter and contain no whitespace or # characters.
var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`)
// extractTags does a best-effort extraction of #tags from raw markdown content.
// It is used when creating or updating memos via MCP to pre-populate Payload.Tags.
// The full markdown service may later rebuild a more accurate payload.
func extractTags(content string) []string {
matches := tagRegexp.FindAllStringSubmatch(content, -1)
seen := make(map[string]struct{}, len(matches))
tags := make([]string, 0, len(matches))
for _, m := range matches {
tag := m[1]
if _, ok := seen[tag]; !ok {
seen[tag] = struct{}{}
tags = append(tags, tag)
}
}
return tags
}
// buildPayload constructs a MemoPayload with tags extracted from content.
// Returns nil when no tags are found so the store omits the payload entirely.
func buildPayload(content string) *storepb.MemoPayload {
tags := extractTags(content)
if len(tags) == 0 {
return nil
}
return &storepb.MemoPayload{Tags: tags}
}
// propertyJSON is the serialisable form of MemoPayload.Property.
type propertyJSON struct {
HasLink bool `json:"has_link"`
HasTaskList bool `json:"has_task_list"`
HasCode bool `json:"has_code"`
HasIncompleteTasks bool `json:"has_incomplete_tasks"`
}
// memoJSON is the canonical response shape for all MCP memo results.
// It serialises correctly with standard encoding/json (no proto marshalling needed).
type memoJSON struct {
Name string `json:"name"`
Creator string `json:"creator"`
CreateTime int64 `json:"create_time"`
UpdateTime int64 `json:"update_time"`
Content string `json:"content,omitempty"`
Visibility string `json:"visibility"`
Tags []string `json:"tags"`
Pinned bool `json:"pinned"`
State string `json:"state"`
Property *propertyJSON `json:"property,omitempty"`
Parent string `json:"parent,omitempty"`
}
func storeMemoToJSON(m *store.Memo) memoJSON {
j := memoJSON{
Name: "memos/" + m.UID,
Creator: fmt.Sprintf("users/%d", m.CreatorID),
CreateTime: m.CreatedTs,
UpdateTime: m.UpdatedTs,
Content: m.Content,
Visibility: string(m.Visibility),
Pinned: m.Pinned,
State: string(m.RowStatus),
Tags: []string{},
}
if m.Payload != nil {
if len(m.Payload.Tags) > 0 {
j.Tags = m.Payload.Tags
}
if p := m.Payload.Property; p != nil && (p.HasLink || p.HasTaskList || p.HasCode || p.HasIncompleteTasks) {
j.Property = &propertyJSON{
HasLink: p.HasLink,
HasTaskList: p.HasTaskList,
HasCode: p.HasCode,
HasIncompleteTasks: p.HasIncompleteTasks,
}
}
}
if m.ParentUID != nil {
j.Parent = "memos/" + *m.ParentUID
}
return j
}
// checkMemoAccess returns an error if the caller cannot read memo.
// userID == 0 means anonymous.
func checkMemoAccess(memo *store.Memo, userID int32) error {
switch memo.Visibility {
case store.Protected:
if userID == 0 {
return errors.New("permission denied")
}
case store.Private:
if memo.CreatorID != userID {
return errors.New("permission denied")
}
default:
// store.Public and any unknown visibility: allow
}
return nil
}
// applyVisibilityFilter restricts find to memos the caller may see.
func applyVisibilityFilter(find *store.FindMemo, userID int32) {
if userID == 0 {
find.VisibilityList = []store.Visibility{store.Public}
} else {
find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID))
}
}
// parseMemoUID extracts the UID from a "memos/<uid>" resource name.
func parseMemoUID(name string) (string, error) {
uid, ok := strings.CutPrefix(name, "memos/")
if !ok || uid == "" {
return "", errors.Errorf(`memo name must be in the format "memos/<uid>", got %q`, name)
}
return uid, nil
}
// parseVisibility validates a visibility string and returns the store constant.
func parseVisibility(s string) (store.Visibility, error) {
switch v := store.Visibility(s); v {
case store.Public, store.Protected, store.Private:
return v, nil
default:
return "", errors.Errorf("visibility must be PRIVATE, PROTECTED, or PUBLIC; got %q", s)
}
}
// parseRowStatus validates a state string and returns the store constant.
func parseRowStatus(s string) (store.RowStatus, error) {
switch rs := store.RowStatus(s); rs {
case store.Normal, store.Archived:
return rs, nil
default:
return "", errors.Errorf("state must be NORMAL or ARCHIVED; got %q", s)
}
}
func extractUserID(ctx context.Context) (int32, error) {
id := auth.GetUserID(ctx)
if id == 0 {
return 0, errors.New("unauthenticated")
return 0, errors.New("unauthenticated: a personal access token is required")
}
return id, nil
}
@ -32,58 +176,71 @@ func marshalJSON(v any) (string, error) {
}
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
listTool := mcp.NewTool("list_memos",
mcp.WithDescription("List the authenticated user's memos"),
mcp.WithNumber("page_size", mcp.Description("Max memos to return, default 20")),
mcp.WithString("filter", mcp.Description(`CEL filter expression, e.g. content.contains("keyword")`)),
)
mcpSrv.AddTool(listTool, s.handleListMemos)
mcpSrv.AddTool(mcp.NewTool("list_memos",
mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."),
mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
),
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
), s.handleListMemos)
getTool := mcp.NewTool("get_memo",
mcp.WithDescription("Get a single memo by resource name"),
mcpSrv.AddTool(mcp.NewTool("get_memo",
mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)
mcpSrv.AddTool(getTool, s.handleGetMemo)
), s.handleGetMemo)
createTool := mcp.NewTool("create_memo",
mcp.WithDescription("Create a new memo"),
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content")),
mcpSrv.AddTool(mcp.NewTool("create_memo",
mcp.WithDescription("Create a new memo. Requires authentication."),
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("Visibility: PRIVATE (default), PROTECTED, or PUBLIC"),
mcp.Description("Visibility (default: PRIVATE)"),
),
)
mcpSrv.AddTool(createTool, s.handleCreateMemo)
), s.handleCreateMemo)
updateTool := mcp.NewTool("update_memo",
mcp.WithDescription("Update a memo's content or visibility"),
mcpSrv.AddTool(mcp.NewTool("update_memo",
mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Description("New content (omit to leave unchanged)")),
mcp.WithString("content", mcp.Description("New Markdown content")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("New visibility (omit to leave unchanged)"),
mcp.Description("New visibility"),
),
)
mcpSrv.AddTool(updateTool, s.handleUpdateMemo)
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
),
), s.handleUpdateMemo)
deleteTool := mcp.NewTool("delete_memo",
mcp.WithDescription("Delete a memo"),
mcpSrv.AddTool(mcp.NewTool("delete_memo",
mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)
mcpSrv.AddTool(deleteTool, s.handleDeleteMemo)
), s.handleDeleteMemo)
searchTool := mcp.NewTool("search_memos",
mcp.WithDescription("Search memo content using a text query"),
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search in memo content")),
)
mcpSrv.AddTool(searchTool, s.handleSearchMemos)
mcpSrv.AddTool(mcp.NewTool("search_memos",
mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."),
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
), s.handleSearchMemos)
mcpSrv.AddTool(mcp.NewTool("list_memo_comments",
mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
), s.handleListMemoComments)
mcpSrv.AddTool(mcp.NewTool("create_memo_comment",
mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
), s.handleCreateMemoComment)
}
func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
userID := auth.GetUserID(ctx)
pageSize := req.GetInt("page_size", 20)
if pageSize <= 0 {
@ -92,31 +249,54 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
if pageSize > 100 {
pageSize = 100
}
filterExpr := req.GetString("filter", "")
rowStatus := store.Normal
limitPlusOne := pageSize + 1
zero := 0
find := &store.FindMemo{
CreatorID: &userID,
ExcludeComments: true,
RowStatus: &rowStatus,
Limit: &limitPlusOne,
Offset: &zero,
page := req.GetInt("page", 0)
if page < 0 {
page = 0
}
if filterExpr != "" {
find.Filters = append(find.Filters, filterExpr)
var rowStatus *store.RowStatus
if state := req.GetString("state", "NORMAL"); state != "" {
rs, err := parseRowStatus(state)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
rowStatus = &rs
}
limit := pageSize + 1
offset := page * pageSize
find := &store.FindMemo{
ExcludeComments: true,
RowStatus: rowStatus,
Limit: &limit,
Offset: &offset,
OrderByPinned: req.GetBool("order_by_pinned", false),
}
applyVisibilityFilter(find, userID)
if filter := req.GetString("filter", ""); filter != "" {
find.Filters = append(find.Filters, filter)
}
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
}
if len(memos) == limitPlusOne {
hasMore := len(memos) > pageSize
if hasMore {
memos = memos[:pageSize]
}
out, err := marshalJSON(memos)
results := make([]memoJSON, len(memos))
for i, m := range memos {
results[i] = storeMemoToJSON(m)
}
type listResponse struct {
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore})
if err != nil {
return nil, err
}
@ -124,20 +304,13 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
}
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
userID := auth.GetUserID(ctx)
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
name := req.GetString("name", "")
if name == "" {
return mcp.NewToolResultError("name is required"), nil
}
uid, found := strings.CutPrefix(name, "memos/")
if !found || uid == "" {
return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
@ -145,11 +318,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest)
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if memo.Visibility == store.Private && memo.CreatorID != userID {
return mcp.NewToolResultError("permission denied"), nil
if err := checkMemoAccess(memo, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
out, err := marshalJSON(memo)
out, err := marshalJSON(storeMemoToJSON(memo))
if err != nil {
return nil, err
}
@ -166,26 +339,23 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque
if content == "" {
return mcp.NewToolResultError("content is required"), nil
}
visibility := req.GetString("visibility", "PRIVATE")
switch visibility {
case "PRIVATE", "PROTECTED", "PUBLIC":
default:
return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil
visibility, err := parseVisibility(req.GetString("visibility", "PRIVATE"))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
create := &store.Memo{
memo, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: store.Visibility(visibility),
}
memo, err := s.store.CreateMemo(ctx, create)
Visibility: visibility,
Payload: buildPayload(content),
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
}
out, err := marshalJSON(memo)
out, err := marshalJSON(storeMemoToJSON(memo))
if err != nil {
return nil, err
}
@ -198,13 +368,9 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil
}
name := req.GetString("name", "")
if name == "" {
return mcp.NewToolResultError("name is required"), nil
}
uid, found := strings.CutPrefix(name, "memos/")
if !found || uid == "" {
return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
@ -219,17 +385,29 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
}
update := &store.UpdateMemo{ID: memo.ID}
if content := req.GetString("content", ""); content != "" {
update.Content = &content
args := req.GetArguments()
if v := req.GetString("content", ""); v != "" {
update.Content = &v
update.Payload = buildPayload(v)
}
if vis := req.GetString("visibility", ""); vis != "" {
switch vis {
case "PRIVATE", "PROTECTED", "PUBLIC":
default:
return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil
if v := req.GetString("visibility", ""); v != "" {
vis, err := parseVisibility(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
v := store.Visibility(vis)
update.Visibility = &v
update.Visibility = &vis
}
if v := req.GetString("state", ""); v != "" {
rs, err := parseRowStatus(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update.RowStatus = &rs
}
if _, ok := args["pinned"]; ok {
pinned := req.GetBool("pinned", false)
update.Pinned = &pinned
}
if err := s.store.UpdateMemo(ctx, update); err != nil {
@ -241,7 +419,7 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
}
out, err := marshalJSON(updated)
out, err := marshalJSON(storeMemoToJSON(updated))
if err != nil {
return nil, err
}
@ -254,13 +432,9 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil
}
name := req.GetString("name", "")
if name == "" {
return mcp.NewToolResultError("name is required"), nil
}
uid, found := strings.CutPrefix(name, "memos/")
if !found || uid == "" {
return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
@ -277,40 +451,147 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
}
return mcp.NewToolResultText("memo deleted"), nil
return mcp.NewToolResultText(`{"deleted":true}`), nil
}
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
userID := auth.GetUserID(ctx)
query := req.GetString("query", "")
if query == "" {
return mcp.NewToolResultError("query is required"), nil
}
rowStatus := store.Normal
limit := 50
zero := 0
rowStatus := store.Normal
find := &store.FindMemo{
ExcludeComments: true,
RowStatus: &rowStatus,
Limit: &limit,
Offset: &zero,
Filters: []string{
fmt.Sprintf("creator_id == %d", userID),
fmt.Sprintf(`content.contains(%q)`, query),
},
Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
}
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil
}
out, err := marshalJSON(memos)
results := make([]memoJSON, len(memos))
for i, m := range memos {
results[i] = storeMemoToJSON(m)
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if parent == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(parent, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
relationType := store.MemoRelationComment
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &parent.ID,
Type: &relationType,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
}
if len(relations) == 0 {
out, _ := marshalJSON([]memoJSON{})
return mcp.NewToolResultText(out), nil
}
commentIDs := make([]int32, len(relations))
for i, r := range relations {
commentIDs[i] = r.MemoID
}
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: commentIDs})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list comments: %v", err)), nil
}
results := make([]memoJSON, 0, len(memos))
for _, m := range memos {
if checkMemoAccess(m, userID) == nil {
results = append(results, storeMemoToJSON(m))
}
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
content := req.GetString("content", "")
if content == "" {
return mcp.NewToolResultError("content is required"), nil
}
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if parent == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(parent, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
comment, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: parent.Visibility,
Payload: buildPayload(content),
ParentUID: &parent.UID,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil
}
if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: comment.ID,
RelatedMemoID: parent.ID,
Type: store.MemoRelationComment,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
}
out, err := marshalJSON(storeMemoToJSON(comment))
if err != nil {
return nil, err
}

View File

@ -0,0 +1,68 @@
package mcp
import (
"context"
"fmt"
"sort"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
), s.handleListTags)
}
type tagEntry struct {
Tag string `json:"tag"`
Count int `json:"count"`
}
func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
rowStatus := store.Normal
find := &store.FindMemo{
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &rowStatus,
}
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
}
counts := make(map[string]int)
for _, m := range memos {
if m.Payload == nil {
continue
}
for _, tag := range m.Payload.Tags {
counts[tag]++
}
}
entries := make([]tagEntry, 0, len(counts))
for tag, count := range counts {
entries = append(entries, tagEntry{Tag: tag, Count: count})
}
sort.Slice(entries, func(i, j int) bool {
if entries[i].Count != entries[j].Count {
return entries[i].Count > entries[j].Count
}
return entries[i].Tag < entries[j].Tag
})
out, err := marshalJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}