feat: implement AI tag suggestion core logic and API

- Add tag suggestion service with OpenAI integration
- Add memo service API for tag recommendations
- Implement workspace tag management endpoints
- Add rate limiting and caching for AI requests

Signed-off-by: Chao Liu <chaoliu719@gmail.com>
This commit is contained in:
Chao Liu 2025-08-17 21:52:32 +08:00 committed by ChaoLiu
parent 982ebb5876
commit 45bdb34612
8 changed files with 462 additions and 0 deletions

164
plugin/ai/tag_suggestion.go Normal file
View File

@ -0,0 +1,164 @@
package ai
import (
"bytes"
"context"
"fmt"
"regexp"
"strings"
"text/template"
"time"
)
// defaultSystemPrompt contains the core instructions for tag recommendation.
const defaultSystemPrompt = `You are an AI assistant that helps users organize their notes by suggesting relevant tags.
Your task is to analyze the note content and suggest 3-5 tags that would help categorize and find this note later.
Guidelines:
- Use the same language as the note content for tag names
- Suggest specific, descriptive tags rather than generic ones
- Focus on the main topics, concepts, and keywords in the content
- If the note mentions specific people, places, projects, or tools, consider including them as tags
- Keep tags concise and practical for search and organization
Output format: Provide your suggestions as a list in this format:
[tag1](reason for this tag) [tag2](reason for this tag) [tag3](reason for this tag)
Example:
Note: "Meeting notes from the Q3 planning session. Discussed new mobile app features including dark mode and social login."
Output: [meeting-notes](this is a record of a meeting) [Q3-planning](relates to Q3 quarter planning) [mobile-app](discusses mobile application features) [product-features](about new product functionality)`
// userMessageTemplate contains only the user data to be analyzed.
const userMessageTemplate = `{{if .ExistingTags}}Existing Tags: {{.ExistingTags}}
{{end}}Note Content:
{{.NoteContent}}`
// TagSuggestionRequest represents a tag suggestion request
type TagSuggestionRequest struct {
Content string // The memo content to analyze
UserTags []string // User's frequently used tags (optional)
ExistingTags []string // Tags already in the memo (optional)
SystemPrompt string // Custom system prompt (optional, uses default if empty)
}
// TagSuggestion represents a single tag suggestion with reason
type TagSuggestion struct {
Tag string
Reason string
}
// TagSuggestionResponse represents the response from tag suggestion
type TagSuggestionResponse struct {
Tags []TagSuggestion
}
// GetDefaultSystemPrompt returns the default system prompt for tag recommendation
func GetDefaultSystemPrompt() string {
return defaultSystemPrompt
}
// SuggestTags suggests tags for memo content using AI
func (c *Client) SuggestTags(ctx context.Context, req *TagSuggestionRequest) (*TagSuggestionResponse, error) {
// Validate request
if req == nil {
return nil, fmt.Errorf("request cannot be nil")
}
if strings.TrimSpace(req.Content) == "" {
return nil, fmt.Errorf("content cannot be empty")
}
// Prepare user tags context
userTagsContext := ""
if len(req.UserTags) > 0 {
topTags := req.UserTags
if len(topTags) > 20 {
topTags = topTags[:20]
}
userTagsContext = strings.Join(topTags, ", ")
}
// Create user message with user data only
userTmpl, err := template.New("userMessage").Parse(userMessageTemplate)
if err != nil {
return nil, fmt.Errorf("failed to parse user message template: %w", err)
}
var userMsgBuf bytes.Buffer
err = userTmpl.Execute(&userMsgBuf, map[string]string{
"ExistingTags": userTagsContext,
"NoteContent": req.Content,
})
if err != nil {
return nil, fmt.Errorf("failed to execute user message template: %w", err)
}
// Use custom system prompt if provided, otherwise use default
promptToUse := defaultSystemPrompt
if req.SystemPrompt != "" {
promptToUse = req.SystemPrompt
}
// Make AI request with separated system and user messages
chatReq := &ChatRequest{
Messages: []Message{
{Role: "system", Content: promptToUse},
{Role: "user", Content: userMsgBuf.String()},
},
MaxTokens: 8192,
Temperature: 0.8,
Timeout: 15 * time.Second,
}
response, err := c.Chat(ctx, chatReq)
if err != nil {
return nil, fmt.Errorf("failed to get AI response for tag suggestion: %w", err)
}
tags := c.parseTagResponse(response.Content)
// Validate that we got some meaningful response
if len(tags) == 0 {
return nil, fmt.Errorf("AI returned no valid tag suggestions")
}
return &TagSuggestionResponse{
Tags: tags,
}, nil
}
// parseTagResponse parses AI response for [tag](reason) patterns
func (c *Client) parseTagResponse(responseText string) []TagSuggestion {
tags := make([]TagSuggestion, 0)
// Match [tag](reason) format using regex across response
pattern := `\[([^\]]+)\]\(([^)]+)\)`
re := regexp.MustCompile(pattern)
matches := re.FindAllStringSubmatch(responseText, -1)
for _, match := range matches {
if len(match) >= 3 {
tag := strings.TrimSpace(match[1])
reason := strings.TrimSpace(match[2])
// Remove # prefix if AI included it
tag = strings.TrimPrefix(tag, "#")
// Clean and validate tag
if tag != "" && len(tag) <= 100 {
// Limit reason length
if len(reason) > 100 {
reason = reason[:100] + "..."
}
tags = append(tags, TagSuggestion{
Tag: tag,
Reason: reason,
})
}
}
}
return tags
}

View File

@ -120,6 +120,13 @@ service MemoService {
option (google.api.http) = {delete: "/api/v1/{name=reactions/*}"};
option (google.api.method_signature) = "name";
}
// SuggestMemoTags suggests tags for memo content.
rpc SuggestMemoTags(SuggestMemoTagsRequest) returns (SuggestMemoTagsResponse) {
option (google.api.http) = {
post: "/api/v1/memos:suggest-tags"
body: "*"
};
}
}
enum Visibility {
@ -577,3 +584,22 @@ message DeleteMemoReactionRequest {
(google.api.resource_reference) = {type: "memos.api.v1/Reaction"}
];
}
message SuggestMemoTagsRequest {
// Required. The content of the memo for tag suggestion.
string content = 1 [(google.api.field_behavior) = REQUIRED];
// Optional. The existing tags for the memo.
repeated string existing_tags = 2 [(google.api.field_behavior) = OPTIONAL];
}
message TagSuggestion {
// The suggested tag name.
string tag = 1 [(google.api.field_behavior) = REQUIRED];
// The reason why this tag is recommended.
string reason = 2 [(google.api.field_behavior) = REQUIRED];
}
message SuggestMemoTagsResponse {
// The suggested tags with reasons for the memo.
repeated TagSuggestion suggested_tags = 1;
}

View File

@ -30,6 +30,11 @@ service WorkspaceService {
};
option (google.api.method_signature) = "setting,update_mask";
}
// Gets the default system prompt for AI tag recommendations.
rpc GetDefaultTagRecommendationPrompt(GetDefaultTagRecommendationPromptRequest) returns (GetDefaultTagRecommendationPromptResponse) {
option (google.api.http) = {get: "/api/v1/workspace/ai/tag-recommendation/default-prompt"};
}
}
// Workspace profile message containing basic workspace information.
@ -186,6 +191,18 @@ message WorkspaceSetting {
string model = 4;
// timeout_seconds is the timeout for AI requests in seconds.
int32 timeout_seconds = 5;
// tag_recommendation contains tag recommendation specific settings.
TagRecommendationConfig tag_recommendation = 6;
}
// Tag recommendation configuration.
message TagRecommendationConfig {
// enabled controls whether tag recommendation is enabled.
bool enabled = 1;
// system_prompt is the custom system prompt for tag recommendation.
string system_prompt = 2;
// requests_per_minute is the rate limit for tag recommendation requests.
int32 requests_per_minute = 3;
}
}
@ -207,3 +224,12 @@ message UpdateWorkspaceSettingRequest {
// The list of fields to update.
google.protobuf.FieldMask update_mask = 2 [(google.api.field_behavior) = OPTIONAL];
}
// Request message for GetDefaultTagRecommendationPrompt method.
message GetDefaultTagRecommendationPromptRequest {}
// Response message for GetDefaultTagRecommendationPrompt method.
message GetDefaultTagRecommendationPromptResponse {
// The default system prompt for tag recommendation.
string system_prompt = 1;
}

View File

@ -130,4 +130,15 @@ message WorkspaceAISetting {
string model = 4;
// timeout_seconds is the timeout for AI requests in seconds.
int32 timeout_seconds = 5;
// tag_recommendation contains tag recommendation specific settings.
TagRecommendationConfig tag_recommendation = 6;
}
message TagRecommendationConfig {
// enabled controls whether tag recommendation is enabled.
bool enabled = 1;
// system_prompt is the custom system prompt for tag recommendation.
string system_prompt = 2;
// requests_per_minute is the rate limit for tag recommendation requests.
int32 requests_per_minute = 3;
}

View File

@ -4,7 +4,9 @@ import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"sync"
"time"
"unicode/utf8"
@ -19,6 +21,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/plugin/ai"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
@ -26,6 +29,47 @@ import (
"github.com/usememos/memos/store"
)
// tagRecommendationRateLimit tracks user request counts for rate limiting
type tagRecommendationRateLimit struct {
mu sync.RWMutex
requests map[int32][]time.Time // userID -> request times
}
var tagRateLimit = &tagRecommendationRateLimit{
requests: make(map[int32][]time.Time),
}
// checkRateLimit checks if user has exceeded rate limit
func (rl *tagRecommendationRateLimit) checkRateLimit(userID int32, maxRequestsPerMinute int32) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
oneMinuteAgo := now.Add(-time.Minute)
// Get user's request history
requests := rl.requests[userID]
// Remove requests older than 1 minute
var recentRequests []time.Time
for _, reqTime := range requests {
if reqTime.After(oneMinuteAgo) {
recentRequests = append(recentRequests, reqTime)
}
}
// Check if user has exceeded rate limit
if int32(len(recentRequests)) >= maxRequestsPerMinute {
return false // Rate limit exceeded
}
// Add current request
recentRequests = append(recentRequests, now)
rl.requests[userID] = recentRequests
return true // Within rate limit
}
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
@ -905,3 +949,134 @@ func (*APIV1Service) parseMemoOrderBy(orderBy string, memoFind *store.FindMemo)
return nil
}
func (s *APIV1Service) SuggestMemoTags(ctx context.Context, request *v1pb.SuggestMemoTagsRequest) (*v1pb.SuggestMemoTagsResponse, error) {
// Validate request
if request == nil {
return nil, status.Errorf(codes.InvalidArgument, "request cannot be nil")
}
if strings.TrimSpace(request.Content) == "" {
return nil, status.Errorf(codes.InvalidArgument, "content cannot be empty")
}
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
}
// Validate content length (minimum 15 characters as specified in design)
if utf8.RuneCountInString(strings.TrimSpace(request.Content)) < 15 {
return nil, status.Errorf(codes.InvalidArgument, "content too short for tag recommendation (minimum 15 characters)")
}
// Get user's existing tags from statistics
userStats, err := s.GetUserStats(ctx, &v1pb.GetUserStatsRequest{
Name: fmt.Sprintf("users/%d", user.ID),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user stats: %v", err)
}
// Extract existing tags from user's history, sorted by frequency
existingTags := make([]string, 0, len(userStats.TagCount))
for tag := range userStats.TagCount {
existingTags = append(existingTags, tag)
}
// Sort tags by frequency (most used first)
sort.Slice(existingTags, func(i, j int) bool {
return userStats.TagCount[existingTags[i]] > userStats.TagCount[existingTags[j]]
})
// Extract existing tags from memo content and combine with request existing tags
existingTagSet := make(map[string]bool)
// Add tags from request
for _, tag := range request.ExistingTags {
existingTagSet[tag] = true
}
// Add tags from content (extract #tag patterns)
existingContentTags := memopayload.ExtractTagsFromContent(request.Content)
for _, tag := range existingContentTags {
existingTagSet[tag] = true
}
// Try AI-based recommendation first
// Load AI configuration from database first, fallback to environment variables
aiSetting, err := s.Store.GetWorkspaceAISetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get AI setting: %v", err)
}
// Check if tag recommendation is enabled
if !aiSetting.EnableAi || aiSetting.TagRecommendation == nil || !aiSetting.TagRecommendation.Enabled {
return &v1pb.SuggestMemoTagsResponse{
SuggestedTags: []*v1pb.TagSuggestion{},
}, nil
}
// Check rate limit
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
maxRequestsPerMinute := aiSetting.TagRecommendation.RequestsPerMinute
if maxRequestsPerMinute <= 0 {
maxRequestsPerMinute = 10 // Default rate limit
}
if !tagRateLimit.checkRateLimit(currentUser.ID, maxRequestsPerMinute) {
return nil, status.Errorf(codes.ResourceExhausted, "标签推荐请求频率超限,每分钟最多 %d 次", maxRequestsPerMinute)
}
aiConfig := ai.LoadConfigFromDatabase(aiSetting)
if !aiConfig.IsConfigured() {
// Fallback to environment variables if database config is not complete
aiConfig = aiConfig.MergeWithEnv()
}
if aiConfig.IsConfigured() {
aiClient, err := ai.NewClient(aiConfig)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create AI client: %v", err)
}
aiRequest := &ai.TagSuggestionRequest{
Content: request.Content,
UserTags: existingTags,
ExistingTags: request.ExistingTags,
SystemPrompt: aiSetting.TagRecommendation.SystemPrompt,
}
aiResponse, err := aiClient.SuggestTags(ctx, aiRequest)
if err != nil {
// Log error - no fallback since we removed simple algorithm
slog.Warn("AI tag suggestion failed, returning empty list", "error", err)
} else {
// Filter out existing tags and convert to proto format
filteredTags := make([]*v1pb.TagSuggestion, 0)
for _, tagSuggestion := range aiResponse.Tags {
if !existingTagSet[tagSuggestion.Tag] && len(filteredTags) < 5 {
filteredTags = append(filteredTags, &v1pb.TagSuggestion{
Tag: tagSuggestion.Tag,
Reason: tagSuggestion.Reason,
})
}
}
if len(filteredTags) > 0 {
return &v1pb.SuggestMemoTagsResponse{
SuggestedTags: filteredTags,
}, nil
}
}
}
// No AI configured - return empty suggestions
return &v1pb.SuggestMemoTagsResponse{
SuggestedTags: []*v1pb.TagSuggestion{},
}, nil
}

View File

@ -8,6 +8,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/plugin/ai"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
@ -308,6 +309,14 @@ func convertWorkspaceAISettingFromStore(setting *storepb.WorkspaceAISetting) *v1
TimeoutSeconds: setting.TimeoutSeconds,
}
if setting.TagRecommendation != nil {
result.TagRecommendation = &v1pb.WorkspaceSetting_TagRecommendationConfig{
Enabled: setting.TagRecommendation.Enabled,
SystemPrompt: setting.TagRecommendation.SystemPrompt,
RequestsPerMinute: setting.TagRecommendation.RequestsPerMinute,
}
}
return result
}
@ -330,6 +339,14 @@ func convertWorkspaceAISettingToStore(setting *v1pb.WorkspaceSetting_AiSetting)
TimeoutSeconds: setting.TimeoutSeconds,
}
if setting.TagRecommendation != nil {
result.TagRecommendation = &storepb.TagRecommendationConfig{
Enabled: setting.TagRecommendation.Enabled,
SystemPrompt: setting.TagRecommendation.SystemPrompt,
RequestsPerMinute: setting.TagRecommendation.RequestsPerMinute,
}
}
return result
}
@ -354,3 +371,10 @@ func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error)
ownerCache = convertUserFromStore(user)
return ownerCache, nil
}
// GetDefaultTagRecommendationPrompt returns the default system prompt for AI tag recommendations.
func (s *APIV1Service) GetDefaultTagRecommendationPrompt(ctx context.Context, _ *v1pb.GetDefaultTagRecommendationPromptRequest) (*v1pb.GetDefaultTagRecommendationPromptResponse, error) {
return &v1pb.GetDefaultTagRecommendationPromptResponse{
SystemPrompt: ai.GetDefaultSystemPrompt(),
}, nil
}

View File

@ -109,6 +109,28 @@ func RebuildMemoPayload(memo *store.Memo) error {
return nil
}
// ExtractTagsFromContent extracts tags from content string using the same logic as RebuildMemoPayload
// This function is exported for use in other packages (e.g., for tag recommendations)
func ExtractTagsFromContent(content string) []string {
nodes, err := parser.Parse(tokenizer.Tokenize(content))
if err != nil {
return []string{}
}
tags := []string{}
TraverseASTNodes(nodes, func(node ast.Node) {
switch n := node.(type) {
case *ast.Tag:
tag := n.Content
if !slices.Contains(tags, tag) {
tags = append(tags, tag)
}
}
})
return tags
}
func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
for _, node := range nodes {
fn(node)

View File

@ -238,6 +238,15 @@ func (s *Store) GetWorkspaceAISetting(ctx context.Context) (*storepb.WorkspaceAI
workspaceAISetting.TimeoutSeconds = defaultAITimeoutSeconds
}
// Set default tag recommendation config if not configured
if workspaceAISetting.TagRecommendation == nil {
workspaceAISetting.TagRecommendation = &storepb.TagRecommendationConfig{
Enabled: workspaceAISetting.EnableAi,
SystemPrompt: "",
RequestsPerMinute: 10,
}
}
s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_AI.String(), &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_AI,
Value: &storepb.WorkspaceSetting_AiSetting{AiSetting: workspaceAISetting},
@ -267,6 +276,11 @@ func loadAISettingFromEnv() *storepb.WorkspaceAISetting {
ApiKey: apiKey,
Model: model,
TimeoutSeconds: timeoutSeconds,
TagRecommendation: &storepb.TagRecommendationConfig{
Enabled: enableAI,
SystemPrompt: "",
RequestsPerMinute: 10,
},
}
}