mirror of https://github.com/usememos/memos.git
319 lines
9.0 KiB
Go
319 lines
9.0 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/lithammer/shortuuid/v4"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
mcpserver "github.com/mark3labs/mcp-go/server"
|
|
|
|
"github.com/usememos/memos/server/auth"
|
|
"github.com/usememos/memos/store"
|
|
)
|
|
|
|
// extractUserID returns the authenticated user's ID as reported by
// auth.GetUserID(ctx). A zero ID is treated as "no authenticated user" and
// turned into an "unauthenticated" error.
func extractUserID(ctx context.Context) (int32, error) {
	if userID := auth.GetUserID(ctx); userID != 0 {
		return userID, nil
	}
	return 0, errors.New("unauthenticated")
}
|
|
|
|
// marshalJSON encodes v with json.Marshal and returns the compact JSON text.
// On an encoding error it returns "" together with that error, unwrapped.
func marshalJSON(v any) (string, error) {
	encoded, err := json.Marshal(v)
	if err == nil {
		return string(encoded), nil
	}
	return "", err
}
|
|
|
|
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
|
|
listTool := mcp.NewTool("list_memos",
|
|
mcp.WithDescription("List the authenticated user's memos"),
|
|
mcp.WithNumber("page_size", mcp.Description("Max memos to return, default 20")),
|
|
mcp.WithString("filter", mcp.Description(`CEL filter expression, e.g. content.contains("keyword")`)),
|
|
)
|
|
mcpSrv.AddTool(listTool, s.handleListMemos)
|
|
|
|
getTool := mcp.NewTool("get_memo",
|
|
mcp.WithDescription("Get a single memo by resource name"),
|
|
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
|
)
|
|
mcpSrv.AddTool(getTool, s.handleGetMemo)
|
|
|
|
createTool := mcp.NewTool("create_memo",
|
|
mcp.WithDescription("Create a new memo"),
|
|
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content")),
|
|
mcp.WithString("visibility",
|
|
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
|
|
mcp.Description("Visibility: PRIVATE (default), PROTECTED, or PUBLIC"),
|
|
),
|
|
)
|
|
mcpSrv.AddTool(createTool, s.handleCreateMemo)
|
|
|
|
updateTool := mcp.NewTool("update_memo",
|
|
mcp.WithDescription("Update a memo's content or visibility"),
|
|
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
|
mcp.WithString("content", mcp.Description("New content (omit to leave unchanged)")),
|
|
mcp.WithString("visibility",
|
|
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
|
|
mcp.Description("New visibility (omit to leave unchanged)"),
|
|
),
|
|
)
|
|
mcpSrv.AddTool(updateTool, s.handleUpdateMemo)
|
|
|
|
deleteTool := mcp.NewTool("delete_memo",
|
|
mcp.WithDescription("Delete a memo"),
|
|
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
|
|
)
|
|
mcpSrv.AddTool(deleteTool, s.handleDeleteMemo)
|
|
|
|
searchTool := mcp.NewTool("search_memos",
|
|
mcp.WithDescription("Search memo content using a text query"),
|
|
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search in memo content")),
|
|
)
|
|
mcpSrv.AddTool(searchTool, s.handleSearchMemos)
|
|
}
|
|
|
|
// handleListMemos implements the list_memos tool.
//
// It returns, as a JSON array, up to page_size of the caller's own memos.
// Comments are excluded, and only rows with status store.Normal are included.
// page_size defaults to 20; non-positive values fall back to 20 and values
// above 100 are clamped to 100. The optional "filter" argument is a CEL
// expression passed through to the store as an additional filter. An
// expression the store rejects is reported as a tool error.
//
// Tool-level failures (unauthenticated, store errors) are returned as error
// results. Only a JSON encoding failure is returned as a Go error.
func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	userID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	pageSize := req.GetInt("page_size", 20)
	if pageSize <= 0 {
		pageSize = 20
	}
	if pageSize > 100 {
		pageSize = 100
	}

	rowStatus := store.Normal
	// The tool's output carries no "more results" marker or page token, so a
	// look-ahead row would only be fetched and thrown away. Ask the store for
	// exactly pageSize rows.
	limit := pageSize
	offset := 0
	find := &store.FindMemo{
		CreatorID:       &userID,
		ExcludeComments: true,
		RowStatus:       &rowStatus,
		Limit:           &limit,
		Offset:          &offset,
	}
	if filterExpr := req.GetString("filter", ""); filterExpr != "" {
		find.Filters = append(find.Filters, filterExpr)
	}

	memos, err := s.store.ListMemos(ctx, find)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
	}
	// A nil slice would encode as JSON null. Always hand the client an array,
	// even when there are no matches.
	if memos == nil {
		memos = []*store.Memo{}
	}

	out, err := marshalJSON(memos)
	if err != nil {
		return nil, err
	}
	return mcp.NewToolResultText(out), nil
}
|
|
|
|
// handleGetMemo implements the get_memo tool.
//
// The "name" argument must have the form "memos/<uid>". The memo is looked up
// by UID and returned as JSON. PRIVATE memos are visible only to their
// creator; any other visibility is returned to any authenticated caller.
// Validation, lookup and permission failures are returned as tool errors.
// Only a JSON encoding failure is returned as a Go error.
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	callerID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	resourceName := req.GetString("name", "")
	if resourceName == "" {
		return mcp.NewToolResultError("name is required"), nil
	}
	memoUID, ok := strings.CutPrefix(resourceName, "memos/")
	if !ok || memoUID == "" {
		return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
	}

	fetched, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
	switch {
	case err != nil:
		return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
	case fetched == nil:
		return mcp.NewToolResultError("memo not found"), nil
	case fetched.Visibility == store.Private && fetched.CreatorID != callerID:
		return mcp.NewToolResultError("permission denied"), nil
	}

	payload, err := marshalJSON(fetched)
	if err != nil {
		return nil, err
	}
	return mcp.NewToolResultText(payload), nil
}
|
|
|
|
// handleCreateMemo implements the create_memo tool.
//
// It creates a memo owned by the caller with the given non-empty "content".
// "visibility" defaults to PRIVATE and must be PRIVATE, PROTECTED or PUBLIC.
// The new memo gets a fresh shortuuid UID, and the stored memo is returned
// as JSON.
//
// NOTE(review): this handler does not set Payload on the new memo. Confirm
// whether anything derived from content (e.g. tags) is expected to be
// computed elsewhere.
func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	creatorID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	body := req.GetString("content", "")
	if body == "" {
		return mcp.NewToolResultError("content is required"), nil
	}

	vis := req.GetString("visibility", "PRIVATE")
	if vis != "PRIVATE" && vis != "PROTECTED" && vis != "PUBLIC" {
		return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil
	}

	created, err := s.store.CreateMemo(ctx, &store.Memo{
		UID:        shortuuid.New(),
		CreatorID:  creatorID,
		Content:    body,
		Visibility: store.Visibility(vis),
	})
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
	}

	payload, err := marshalJSON(created)
	if err != nil {
		return nil, err
	}
	return mcp.NewToolResultText(payload), nil
}
|
|
|
|
// handleUpdateMemo implements the update_memo tool.
//
// The "name" argument must have the form "memos/<uid>", and only the memo's
// creator may update it. "content" and "visibility" are both optional. An
// omitted or empty value leaves that field unchanged, so content cannot be
// cleared through this tool. Visibility must be PRIVATE, PROTECTED or PUBLIC.
//
// On success the memo is re-read from the store and returned as JSON. If
// neither field is supplied, no write is issued and the current memo is
// returned unchanged. Validation, lookup, permission and store failures are
// returned as tool errors. Only a JSON encoding failure is returned as a Go
// error.
func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	userID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	name := req.GetString("name", "")
	if name == "" {
		return mcp.NewToolResultError("name is required"), nil
	}
	uid, found := strings.CutPrefix(name, "memos/")
	if !found || uid == "" {
		return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
	}

	memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
	}
	if memo == nil {
		return mcp.NewToolResultError("memo not found"), nil
	}
	if memo.CreatorID != userID {
		return mcp.NewToolResultError("permission denied"), nil
	}

	update := &store.UpdateMemo{ID: memo.ID}
	changed := false
	if content := req.GetString("content", ""); content != "" {
		update.Content = &content
		changed = true
	}
	if vis := req.GetString("visibility", ""); vis != "" {
		switch vis {
		case "PRIVATE", "PROTECTED", "PUBLIC":
		default:
			return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil
		}
		v := store.Visibility(vis)
		update.Visibility = &v
		changed = true
	}

	result := memo
	// Skip the write when nothing was requested. An update with no fields set
	// does no useful work, and how a store driver handles an empty update is
	// not visible from here.
	if changed {
		if err := s.store.UpdateMemo(ctx, update); err != nil {
			return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil
		}
		updated, err := s.store.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
		if err != nil {
			return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
		}
		// GetMemo reports a missing row as (nil, nil), as checked above.
		// Without this check the tool would return the text "null" as a
		// successful result.
		if updated == nil {
			return mcp.NewToolResultError("memo not found after update"), nil
		}
		result = updated
	}

	out, err := marshalJSON(result)
	if err != nil {
		return nil, err
	}
	return mcp.NewToolResultText(out), nil
}
|
|
|
|
// handleDeleteMemo implements the delete_memo tool.
//
// The "name" argument must have the form "memos/<uid>", and only the memo's
// creator may delete it. On success the tool returns the text "memo deleted".
// Every failure is reported as a tool error; the Go error return is always nil.
//
// NOTE(review): only store.DeleteMemo is called here. This handler does not
// itself remove comments, relations or attachments tied to the memo. Confirm
// whether the store or another layer cleans those up.
func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	callerID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	resourceName := req.GetString("name", "")
	if resourceName == "" {
		return mcp.NewToolResultError("name is required"), nil
	}
	memoUID, ok := strings.CutPrefix(resourceName, "memos/")
	if !ok || memoUID == "" {
		return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
	}

	target, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
	switch {
	case err != nil:
		return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
	case target == nil:
		return mcp.NewToolResultError("memo not found"), nil
	case target.CreatorID != callerID:
		return mcp.NewToolResultError("permission denied"), nil
	}

	err = s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: target.ID})
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
	}
	return mcp.NewToolResultText("memo deleted"), nil
}
|
|
|
|
// handleSearchMemos implements the search_memos tool.
//
// It returns, as a JSON array, up to 50 of the caller's own memos whose
// content contains the non-empty "query" text. Comments are excluded, and
// only rows with status store.Normal are included. Store failures are
// returned as tool errors. Only a JSON encoding failure is returned as a Go
// error.
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	userID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	query := req.GetString("query", "")
	if query == "" {
		return mcp.NewToolResultError("query is required"), nil
	}

	rowStatus := store.Normal
	limit := 50
	offset := 0
	find := &store.FindMemo{
		// Scope to the caller with the structured CreatorID field, as
		// handleListMemos does, rather than a "creator_id == N" CEL string.
		// This keeps the ownership restriction independent of which fields
		// the filter engine supports.
		CreatorID:       &userID,
		ExcludeComments: true,
		RowStatus:       &rowStatus,
		Limit:           &limit,
		Offset:          &offset,
		// %q (strconv.Quote) produces a double-quoted literal with `"` and `\`
		// escaped, so the query cannot close the string and inject extra CEL.
		Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
	}

	memos, err := s.store.ListMemos(ctx, find)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil
	}
	// A nil slice would encode as JSON null. Always hand the client an array,
	// even when nothing matches.
	if memos == nil {
		memos = []*store.Memo{}
	}

	out, err := marshalJSON(memos)
	if err != nil {
		return nil, err
	}
	return mcp.NewToolResultText(out), nil
}
|