From 45bdb34612a82a8b59e500780dc9143c66fde48b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 17 Aug 2025 21:52:32 +0800 Subject: [PATCH] feat: implement AI tag suggestion core logic and API - Add tag suggestion service with OpenAI integration - Add memo service API for tag recommendations - Implement workspace tag management endpoints - Add rate limiting and caching for AI requests Signed-off-by: Chao Liu --- plugin/ai/tag_suggestion.go | 164 ++++++++++++++++++++ proto/api/v1/memo_service.proto | 26 ++++ proto/api/v1/workspace_service.proto | 26 ++++ proto/store/workspace_setting.proto | 11 ++ server/router/api/v1/memo_service.go | 175 ++++++++++++++++++++++ server/router/api/v1/workspace_service.go | 24 +++ server/runner/memopayload/runner.go | 22 +++ store/workspace_setting.go | 14 ++ 8 files changed, 462 insertions(+) create mode 100644 plugin/ai/tag_suggestion.go diff --git a/plugin/ai/tag_suggestion.go b/plugin/ai/tag_suggestion.go new file mode 100644 index 000000000..c877bdc55 --- /dev/null +++ b/plugin/ai/tag_suggestion.go @@ -0,0 +1,164 @@ +package ai + +import ( + "bytes" + "context" + "fmt" + "regexp" + "strings" + "text/template" + "time" +) + +// defaultSystemPrompt contains the core instructions for tag recommendation. +const defaultSystemPrompt = `You are an AI assistant that helps users organize their notes by suggesting relevant tags. + +Your task is to analyze the note content and suggest 3-5 tags that would help categorize and find this note later. + +Guidelines: +- Use the same language as the note content for tag names +- Suggest specific, descriptive tags rather than generic ones +- Focus on the main topics, concepts, and keywords in the content +- If the note mentions specific people, places, projects, or tools, consider including them as tags +- Keep tags concise and practical for search and organization + +Output format: Provide your suggestions as a list in this format: +[tag1](reason for this tag) [tag2](reason for this tag) [tag3](reason for this tag) + +Example: +Note: "Meeting notes from the Q3 planning session. Discussed new mobile app features including dark mode and social login." +Output: [meeting-notes](this is a record of a meeting) [Q3-planning](relates to Q3 quarter planning) [mobile-app](discusses mobile application features) [product-features](about new product functionality)` + +// userMessageTemplate contains only the user data to be analyzed. +const userMessageTemplate = `{{if .ExistingTags}}Existing Tags: {{.ExistingTags}} + +{{end}}Note Content: +{{.NoteContent}}` + +// TagSuggestionRequest represents a tag suggestion request +type TagSuggestionRequest struct { + Content string // The memo content to analyze + UserTags []string // User's frequently used tags (optional) + ExistingTags []string // Tags already in the memo (optional) + SystemPrompt string // Custom system prompt (optional, uses default if empty) +} + +// TagSuggestion represents a single tag suggestion with reason +type TagSuggestion struct { + Tag string + Reason string +} + +// TagSuggestionResponse represents the response from tag suggestion +type TagSuggestionResponse struct { + Tags []TagSuggestion +} + +// GetDefaultSystemPrompt returns the default system prompt for tag recommendation +func GetDefaultSystemPrompt() string { + return defaultSystemPrompt +} + +// SuggestTags suggests tags for memo content using AI +func (c *Client) SuggestTags(ctx context.Context, req *TagSuggestionRequest) (*TagSuggestionResponse, error) { + // Validate request + if req == nil { + return nil, fmt.Errorf("request cannot be nil") + } + + if strings.TrimSpace(req.Content) == "" { + return nil, fmt.Errorf("content cannot be empty") + } + + // Prepare user tags context + userTagsContext := "" + if len(req.UserTags) > 0 { + topTags := req.UserTags + if len(topTags) > 20 { + topTags = topTags[:20] + } + userTagsContext = strings.Join(topTags, ", ") + } + + // Create user message with user data only + userTmpl, err := template.New("userMessage").Parse(userMessageTemplate) + if err != nil { + return nil, fmt.Errorf("failed to parse user message template: %w", err) + } + + var userMsgBuf bytes.Buffer + err = userTmpl.Execute(&userMsgBuf, map[string]string{ + "ExistingTags": userTagsContext, + "NoteContent": req.Content, + }) + if err != nil { + return nil, fmt.Errorf("failed to execute user message template: %w", err) + } + + // Use custom system prompt if provided, otherwise use default + promptToUse := defaultSystemPrompt + if req.SystemPrompt != "" { + promptToUse = req.SystemPrompt + } + + // Make AI request with separated system and user messages + chatReq := &ChatRequest{ + Messages: []Message{ + {Role: "system", Content: promptToUse}, + {Role: "user", Content: userMsgBuf.String()}, + }, + MaxTokens: 8192, + Temperature: 0.8, + Timeout: 15 * time.Second, + } + + response, err := c.Chat(ctx, chatReq) + if err != nil { + return nil, fmt.Errorf("failed to get AI response for tag suggestion: %w", err) + } + + tags := c.parseTagResponse(response.Content) + + // Validate that we got some meaningful response + if len(tags) == 0 { + return nil, fmt.Errorf("AI returned no valid tag suggestions") + } + + return &TagSuggestionResponse{ + Tags: tags, + }, nil +} + +// parseTagResponse parses AI response for [tag](reason) patterns +func (c *Client) parseTagResponse(responseText string) []TagSuggestion { + tags := make([]TagSuggestion, 0) + + // Match [tag](reason) format using regex across response + pattern := `\[([^\]]+)\]\(([^)]+)\)` + re := regexp.MustCompile(pattern) + matches := re.FindAllStringSubmatch(responseText, -1) + + for _, match := range matches { + if len(match) >= 3 { + tag := strings.TrimSpace(match[1]) + reason := strings.TrimSpace(match[2]) + + // Remove # prefix if AI included it + tag = strings.TrimPrefix(tag, "#") + + // Clean and validate tag + if tag != "" && len(tag) <= 100 { + // Limit reason length + if len(reason) > 100 { + reason = reason[:100] + "..." + } + tags = append(tags, TagSuggestion{ + Tag: tag, + Reason: reason, + }) + } + } + } + + return tags +} diff --git a/proto/api/v1/memo_service.proto b/proto/api/v1/memo_service.proto index 3a9bb4612..12f068c7d 100644 --- a/proto/api/v1/memo_service.proto +++ b/proto/api/v1/memo_service.proto @@ -120,6 +120,13 @@ service MemoService { option (google.api.http) = {delete: "/api/v1/{name=reactions/*}"}; option (google.api.method_signature) = "name"; } + // SuggestMemoTags suggests tags for memo content. + rpc SuggestMemoTags(SuggestMemoTagsRequest) returns (SuggestMemoTagsResponse) { + option (google.api.http) = { + post: "/api/v1/memos:suggest-tags" + body: "*" + }; + } } enum Visibility { @@ -577,3 +584,22 @@ message DeleteMemoReactionRequest { (google.api.resource_reference) = {type: "memos.api.v1/Reaction"} ]; } + +message SuggestMemoTagsRequest { + // Required. The content of the memo for tag suggestion. + string content = 1 [(google.api.field_behavior) = REQUIRED]; + // Optional. The existing tags for the memo. + repeated string existing_tags = 2 [(google.api.field_behavior) = OPTIONAL]; +} + +message TagSuggestion { + // The suggested tag name. + string tag = 1 [(google.api.field_behavior) = REQUIRED]; + // The reason why this tag is recommended. + string reason = 2 [(google.api.field_behavior) = REQUIRED]; +} + +message SuggestMemoTagsResponse { + // The suggested tags with reasons for the memo. + repeated TagSuggestion suggested_tags = 1; +} diff --git a/proto/api/v1/workspace_service.proto b/proto/api/v1/workspace_service.proto index bb44609a7..75737ae12 100644 --- a/proto/api/v1/workspace_service.proto +++ b/proto/api/v1/workspace_service.proto @@ -30,6 +30,11 @@ service WorkspaceService { }; option (google.api.method_signature) = "setting,update_mask"; } + + // Gets the default system prompt for AI tag recommendations. + rpc GetDefaultTagRecommendationPrompt(GetDefaultTagRecommendationPromptRequest) returns (GetDefaultTagRecommendationPromptResponse) { + option (google.api.http) = {get: "/api/v1/workspace/ai/tag-recommendation/default-prompt"}; + } } // Workspace profile message containing basic workspace information. @@ -186,6 +191,18 @@ message WorkspaceSetting { string model = 4; // timeout_seconds is the timeout for AI requests in seconds. int32 timeout_seconds = 5; + // tag_recommendation contains tag recommendation specific settings. + TagRecommendationConfig tag_recommendation = 6; + } + + // Tag recommendation configuration. + message TagRecommendationConfig { + // enabled controls whether tag recommendation is enabled. + bool enabled = 1; + // system_prompt is the custom system prompt for tag recommendation. + string system_prompt = 2; + // requests_per_minute is the rate limit for tag recommendation requests. + int32 requests_per_minute = 3; } } @@ -207,3 +224,12 @@ message UpdateWorkspaceSettingRequest { // The list of fields to update. google.protobuf.FieldMask update_mask = 2 [(google.api.field_behavior) = OPTIONAL]; } + +// Request message for GetDefaultTagRecommendationPrompt method. +message GetDefaultTagRecommendationPromptRequest {} + +// Response message for GetDefaultTagRecommendationPrompt method. +message GetDefaultTagRecommendationPromptResponse { + // The default system prompt for tag recommendation. + string system_prompt = 1; +} diff --git a/proto/store/workspace_setting.proto b/proto/store/workspace_setting.proto index a4ceba80c..272522f79 100644 --- a/proto/store/workspace_setting.proto +++ b/proto/store/workspace_setting.proto @@ -130,4 +130,15 @@ message WorkspaceAISetting { string model = 4; // timeout_seconds is the timeout for AI requests in seconds. int32 timeout_seconds = 5; + // tag_recommendation contains tag recommendation specific settings. + TagRecommendationConfig tag_recommendation = 6; +} + +message TagRecommendationConfig { + // enabled controls whether tag recommendation is enabled. + bool enabled = 1; + // system_prompt is the custom system prompt for tag recommendation. + string system_prompt = 2; + // requests_per_minute is the rate limit for tag recommendation requests. + int32 requests_per_minute = 3; } diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 57f6bfab8..0fc6e5248 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "log/slog" + "sort" "strings" + "sync" "time" "unicode/utf8" @@ -19,6 +21,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" + "github.com/usememos/memos/plugin/ai" "github.com/usememos/memos/plugin/webhook" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" @@ -26,6 +29,47 @@ import ( "github.com/usememos/memos/store" ) +// tagRecommendationRateLimit tracks user request counts for rate limiting +type tagRecommendationRateLimit struct { + mu sync.RWMutex + requests map[int32][]time.Time // userID -> request times +} + +var tagRateLimit = &tagRecommendationRateLimit{ + requests: make(map[int32][]time.Time), +} + +// checkRateLimit checks if user has exceeded rate limit +func (rl *tagRecommendationRateLimit) checkRateLimit(userID int32, maxRequestsPerMinute int32) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + oneMinuteAgo := now.Add(-time.Minute) + + // Get user's request history + requests := rl.requests[userID] + + // Remove requests older than 1 minute + var recentRequests []time.Time + for _, reqTime := range requests { + if reqTime.After(oneMinuteAgo) { + recentRequests = append(recentRequests, reqTime) + } + } + + // Check if user has exceeded rate limit + if int32(len(recentRequests)) >= maxRequestsPerMinute { + return false // Rate limit exceeded + } + + // Add current request + recentRequests = append(recentRequests, now) + rl.requests[userID] = recentRequests + + return true // Within rate limit +} + func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) { user, err := s.GetCurrentUser(ctx) if err != nil { @@ -905,3 +949,134 @@ func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo) return nil } + +func (s *APIV1Service) SuggestMemoTags(ctx context.Context, request *v1pb.SuggestMemoTagsRequest) (*v1pb.SuggestMemoTagsResponse, error) { + // Validate request + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "request cannot be nil") + } + + if strings.TrimSpace(request.Content) == "" { + return nil, status.Errorf(codes.InvalidArgument, "content cannot be empty") + } + + user, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user") + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "authentication required") + } + + // Validate content length (minimum 15 characters as specified in design) + if utf8.RuneCountInString(strings.TrimSpace(request.Content)) < 15 { + return nil, status.Errorf(codes.InvalidArgument, "content too short for tag recommendation (minimum 15 characters)") + } + + // Get user's existing tags from statistics + userStats, err := s.GetUserStats(ctx, &v1pb.GetUserStatsRequest{ + Name: fmt.Sprintf("users/%d", user.ID), + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user stats: %v", err) + } + + // Extract existing tags from user's history, sorted by frequency + existingTags := make([]string, 0, len(userStats.TagCount)) + for tag := range userStats.TagCount { + existingTags = append(existingTags, tag) + } + + // Sort tags by frequency (most used first) + sort.Slice(existingTags, func(i, j int) bool { + return userStats.TagCount[existingTags[i]] > userStats.TagCount[existingTags[j]] + }) + + // Extract existing tags from memo content and combine with request existing tags + existingTagSet := make(map[string]bool) + // Add tags from request + for _, tag := range request.ExistingTags { + existingTagSet[tag] = true + } + // Add tags from content (extract #tag patterns) + existingContentTags := memopayload.ExtractTagsFromContent(request.Content) + for _, tag := range existingContentTags { + existingTagSet[tag] = true + } + + // Try AI-based recommendation first + // Load AI configuration from database first, fallback to environment variables + aiSetting, err := s.Store.GetWorkspaceAISetting(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get AI setting: %v", err) + } + + // Check if tag recommendation is enabled + if !aiSetting.EnableAi || aiSetting.TagRecommendation == nil || !aiSetting.TagRecommendation.Enabled { + return &v1pb.SuggestMemoTagsResponse{ + SuggestedTags: []*v1pb.TagSuggestion{}, + }, nil + } + + // Check rate limit + currentUser, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err) + } + + maxRequestsPerMinute := aiSetting.TagRecommendation.RequestsPerMinute + if maxRequestsPerMinute <= 0 { + maxRequestsPerMinute = 10 // Default rate limit + } + + if !tagRateLimit.checkRateLimit(currentUser.ID, maxRequestsPerMinute) { + return nil, status.Errorf(codes.ResourceExhausted, "标签推荐请求频率超限,每分钟最多 %d 次", maxRequestsPerMinute) + } + + aiConfig := ai.LoadConfigFromDatabase(aiSetting) + if !aiConfig.IsConfigured() { + // Fallback to environment variables if database config is not complete + aiConfig = aiConfig.MergeWithEnv() + } + + if aiConfig.IsConfigured() { + aiClient, err := ai.NewClient(aiConfig) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create AI client: %v", err) + } + + aiRequest := &ai.TagSuggestionRequest{ + Content: request.Content, + UserTags: existingTags, + ExistingTags: request.ExistingTags, + SystemPrompt: aiSetting.TagRecommendation.SystemPrompt, + } + + aiResponse, err := aiClient.SuggestTags(ctx, aiRequest) + if err != nil { + // Log error - no fallback since we removed simple algorithm + slog.Warn("AI tag suggestion failed, returning empty list", "error", err) + } else { + // Filter out existing tags and convert to proto format + filteredTags := make([]*v1pb.TagSuggestion, 0) + for _, tagSuggestion := range aiResponse.Tags { + if !existingTagSet[tagSuggestion.Tag] && len(filteredTags) < 5 { + filteredTags = append(filteredTags, &v1pb.TagSuggestion{ + Tag: tagSuggestion.Tag, + Reason: tagSuggestion.Reason, + }) + } + } + if len(filteredTags) > 0 { + return &v1pb.SuggestMemoTagsResponse{ + SuggestedTags: filteredTags, + }, nil + } + } + } + + // No AI configured - return empty suggestions + return &v1pb.SuggestMemoTagsResponse{ + SuggestedTags: []*v1pb.TagSuggestion{}, + }, nil +} diff --git a/server/router/api/v1/workspace_service.go b/server/router/api/v1/workspace_service.go index 9c1c4c95a..10bd688b2 100644 --- a/server/router/api/v1/workspace_service.go +++ b/server/router/api/v1/workspace_service.go @@ -8,6 +8,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/usememos/memos/plugin/ai" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" @@ -308,6 +309,14 @@ func convertWorkspaceAISettingFromStore(setting *storepb.WorkspaceAISetting) *v1 TimeoutSeconds: setting.TimeoutSeconds, } + if setting.TagRecommendation != nil { + result.TagRecommendation = &v1pb.WorkspaceSetting_TagRecommendationConfig{ + Enabled: setting.TagRecommendation.Enabled, + SystemPrompt: setting.TagRecommendation.SystemPrompt, + RequestsPerMinute: setting.TagRecommendation.RequestsPerMinute, + } + } + return result } @@ -330,6 +339,14 @@ func convertWorkspaceAISettingToStore(setting *v1pb.WorkspaceSetting_AiSetting) TimeoutSeconds: setting.TimeoutSeconds, } + if setting.TagRecommendation != nil { + result.TagRecommendation = &storepb.TagRecommendationConfig{ + Enabled: setting.TagRecommendation.Enabled, + SystemPrompt: setting.TagRecommendation.SystemPrompt, + RequestsPerMinute: setting.TagRecommendation.RequestsPerMinute, + } + } + return result } @@ -354,3 +371,10 @@ func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) ownerCache = convertUserFromStore(user) return ownerCache, nil } + +// GetDefaultTagRecommendationPrompt returns the default system prompt for AI tag recommendations. +func (s *APIV1Service) GetDefaultTagRecommendationPrompt(ctx context.Context, _ *v1pb.GetDefaultTagRecommendationPromptRequest) (*v1pb.GetDefaultTagRecommendationPromptResponse, error) { + return &v1pb.GetDefaultTagRecommendationPromptResponse{ + SystemPrompt: ai.GetDefaultSystemPrompt(), + }, nil +} diff --git a/server/runner/memopayload/runner.go b/server/runner/memopayload/runner.go index 141110d1d..7f04d05df 100644 --- a/server/runner/memopayload/runner.go +++ b/server/runner/memopayload/runner.go @@ -109,6 +109,28 @@ func RebuildMemoPayload(memo *store.Memo) error { return nil } +// ExtractTagsFromContent extracts tags from content string using the same logic as RebuildMemoPayload +// This function is exported for use in other packages (e.g., for tag recommendations) +func ExtractTagsFromContent(content string) []string { + nodes, err := parser.Parse(tokenizer.Tokenize(content)) + if err != nil { + return []string{} + } + + tags := []string{} + TraverseASTNodes(nodes, func(node ast.Node) { + switch n := node.(type) { + case *ast.Tag: + tag := n.Content + if !slices.Contains(tags, tag) { + tags = append(tags, tag) + } + } + }) + + return tags +} + func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) { for _, node := range nodes { fn(node) diff --git a/store/workspace_setting.go b/store/workspace_setting.go index 64612b962..f2457d4d1 100644 --- a/store/workspace_setting.go +++ b/store/workspace_setting.go @@ -238,6 +238,15 @@ func (s *Store) GetWorkspaceAISetting(ctx context.Context) (*storepb.WorkspaceAI workspaceAISetting.TimeoutSeconds = defaultAITimeoutSeconds } + // Set default tag recommendation config if not configured + if workspaceAISetting.TagRecommendation == nil { + workspaceAISetting.TagRecommendation = &storepb.TagRecommendationConfig{ + Enabled: workspaceAISetting.EnableAi, + SystemPrompt: "", + RequestsPerMinute: 10, + } + } + s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_AI.String(), &storepb.WorkspaceSetting{ Key: storepb.WorkspaceSettingKey_AI, Value: &storepb.WorkspaceSetting_AiSetting{AiSetting: workspaceAISetting}, @@ -267,6 +276,11 @@ func loadAISettingFromEnv() *storepb.WorkspaceAISetting { ApiKey: apiKey, Model: model, TimeoutSeconds: timeoutSeconds, + TagRecommendation: &storepb.TagRecommendationConfig{ + Enabled: enableAI, + SystemPrompt: "", + RequestsPerMinute: 10, + }, } }