diff --git a/plugin/ai/client.go b/plugin/ai/client.go index 236a8203d..b884adb57 100644 --- a/plugin/ai/client.go +++ b/plugin/ai/client.go @@ -2,8 +2,6 @@ package ai import ( "context" - "errors" - "fmt" "os" "strconv" "strings" @@ -11,11 +9,12 @@ import ( "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" + "github.com/pkg/errors" storepb "github.com/usememos/memos/proto/gen/store" ) -// Common AI errors +// Common AI errors. var ( ErrConfigIncomplete = errors.New("AI configuration incomplete - missing BaseURL, APIKey, or Model") ErrEmptyRequest = errors.New("chat request cannot be empty") @@ -125,7 +124,7 @@ type Client struct { // NewClient creates a new AI client func NewClient(config *Config) (*Client, error) { if config == nil { - return nil, fmt.Errorf("config cannot be nil") + return nil, errors.New("config cannot be nil") } if !config.IsConfigured() { @@ -182,10 +181,10 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, err // Validate messages for i, msg := range req.Messages { if msg.Role != "system" && msg.Role != "user" && msg.Role != "assistant" { - return nil, fmt.Errorf("message %d: %w", i, ErrInvalidMessage) + return nil, errors.Wrapf(ErrInvalidMessage, "message %d", i) } if strings.TrimSpace(msg.Content) == "" { - return nil, fmt.Errorf("message %d: %w", i, ErrEmptyContent) + return nil, errors.Wrapf(ErrEmptyContent, "message %d", i) } } @@ -235,7 +234,7 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, err Temperature: openai.Float(req.Temperature), }) if err != nil { - return nil, fmt.Errorf("%w: %v", ErrAPICallFailed, err) + return nil, errors.Wrapf(ErrAPICallFailed, "%v", err) } if len(completion.Choices) == 0 { diff --git a/plugin/ai/client_test.go b/plugin/ai/client_test.go index 3a02815a9..676215e74 100644 --- a/plugin/ai/client_test.go +++ b/plugin/ai/client_test.go @@ -245,6 +245,11 @@ func TestClient_Chat_RequestDefaults(t *testing.T) { assert.Equal(t, 0, req.MaxTokens) // Should become 8192 assert.Equal(t, float64(0), req.Temperature) // Should become 0.3 assert.Equal(t, time.Duration(0), req.Timeout) // Should become 10s + + // Verify the Messages field is properly structured + assert.Len(t, req.Messages, 1) + assert.Equal(t, "user", req.Messages[0].Role) + assert.Equal(t, "Hello", req.Messages[0].Content) } func TestMessage_Roles(t *testing.T) { @@ -275,6 +280,7 @@ func TestMessage_Roles(t *testing.T) { } assert.Equal(t, tt.valid, validRoles[msg.Role]) + assert.Equal(t, "test content", msg.Content) }) } } diff --git a/plugin/ai/tag_suggestion.go b/plugin/ai/tag_suggestion.go index c8a12f90b..d533773a0 100644 --- a/plugin/ai/tag_suggestion.go +++ b/plugin/ai/tag_suggestion.go @@ -8,6 +8,8 @@ import ( "strings" "text/template" "time" + + "github.com/pkg/errors" ) // defaultSystemPrompt contains the core instructions for tag recommendation. @@ -38,10 +40,10 @@ const userMessageTemplate = `{{if .ExistingTags}}Existing Tags: {{.ExistingTags} // TagSuggestionRequest represents a tag suggestion request type TagSuggestionRequest struct { - Content string // The memo content to analyze - UserTags []string // User's frequently used tags (optional) - ExistingTags []string // Tags already in the memo (optional) - SystemPrompt string // Custom system prompt (optional, uses default if empty) + Content string // The memo content to analyze + UserTags []string // User's frequently used tags (optional) + ExistingTags []string // Tags already in the memo (optional) + SystemPrompt string // Custom system prompt (optional, uses default if empty) } // TagSuggestion represents a single tag suggestion with reason @@ -64,11 +66,11 @@ func GetDefaultSystemPrompt() string { func (c *Client) SuggestTags(ctx context.Context, req *TagSuggestionRequest) (*TagSuggestionResponse, error) { // Validate request if req == nil { - return nil, fmt.Errorf("request cannot be nil") + return nil, errors.New("request cannot be nil") } if strings.TrimSpace(req.Content) == "" { - return nil, fmt.Errorf("content cannot be empty") + return nil, errors.New("content cannot be empty") } // Prepare user tags context @@ -131,7 +133,7 @@ func (c *Client) SuggestTags(ctx context.Context, req *TagSuggestionRequest) (*T } // parseTagResponse parses AI response for [tag](reason) patterns -func (c *Client) parseTagResponse(responseText string) []TagSuggestion { +func (_ *Client) parseTagResponse(responseText string) []TagSuggestion { tags := make([]TagSuggestion, 0) // Match [tag](reason) format using regex across response diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 0fc6e5248..460990c8b 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -1010,7 +1010,7 @@ func (s *APIV1Service) SuggestMemoTags(ctx context.Context, request *v1pb.Sugges if err != nil { return nil, status.Errorf(codes.Internal, "failed to get AI setting: %v", err) } - + // Check if tag recommendation is enabled if !aiSetting.EnableAi || aiSetting.TagRecommendation == nil || !aiSetting.TagRecommendation.Enabled { return &v1pb.SuggestMemoTagsResponse{ @@ -1032,13 +1032,13 @@ func (s *APIV1Service) SuggestMemoTags(ctx context.Context, request *v1pb.Sugges if !tagRateLimit.checkRateLimit(currentUser.ID, maxRequestsPerMinute) { return nil, status.Errorf(codes.ResourceExhausted, "标签推荐请求频率超限,每分钟最多 %d 次", maxRequestsPerMinute) } - + aiConfig := ai.LoadConfigFromDatabase(aiSetting) if !aiConfig.IsConfigured() { // Fallback to environment variables if database config is not complete aiConfig = aiConfig.MergeWithEnv() } - + if aiConfig.IsConfigured() { aiClient, err := ai.NewClient(aiConfig) if err != nil { diff --git a/server/router/api/v1/workspace_service.go b/server/router/api/v1/workspace_service.go index ee0bf3a28..34dbd28a2 100644 --- a/server/router/api/v1/workspace_service.go +++ b/server/router/api/v1/workspace_service.go @@ -373,7 +373,7 @@ func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) } // GetDefaultTagRecommendationPrompt returns the default system prompt for AI tag recommendations. -func (s *APIV1Service) GetDefaultTagRecommendationPrompt(ctx context.Context, _ *v1pb.GetDefaultTagRecommendationPromptRequest) (*v1pb.GetDefaultTagRecommendationPromptResponse, error) { +func (_ *APIV1Service) GetDefaultTagRecommendationPrompt(ctx context.Context, _ *v1pb.GetDefaultTagRecommendationPromptRequest) (*v1pb.GetDefaultTagRecommendationPromptResponse, error) { return &v1pb.GetDefaultTagRecommendationPromptResponse{ SystemPrompt: ai.GetDefaultSystemPrompt(), }, nil diff --git a/server/runner/memopayload/runner.go b/server/runner/memopayload/runner.go index 7f04d05df..75a5d7817 100644 --- a/server/runner/memopayload/runner.go +++ b/server/runner/memopayload/runner.go @@ -110,7 +110,7 @@ func RebuildMemoPayload(memo *store.Memo) error { } // ExtractTagsFromContent extracts tags from content string using the same logic as RebuildMemoPayload -// This function is exported for use in other packages (e.g., for tag recommendations) +// This function is exported for use in other packages (e.g., for tag recommendations). func ExtractTagsFromContent(content string) []string { nodes, err := parser.Parse(tokenizer.Tokenize(content)) if err != nil { @@ -119,8 +119,7 @@ func ExtractTagsFromContent(content string) []string { tags := []string{} TraverseASTNodes(nodes, func(node ast.Node) { - switch n := node.(type) { - case *ast.Tag: + if n, ok := node.(*ast.Tag); ok { tag := n.Content if !slices.Contains(tags, tag) { tags = append(tags, tag) diff --git a/store/workspace_setting.go b/store/workspace_setting.go index b48c3a981..c5360898d 100644 --- a/store/workspace_setting.go +++ b/store/workspace_setting.go @@ -213,10 +213,10 @@ func (s *Store) GetWorkspaceStorageSetting(ctx context.Context) (*storepb.Worksp } const ( - defaultAITimeoutSeconds = int32(15) + defaultAITimeoutSeconds = int32(15) defaultAITagRecommandationEnabled = false - defaultAITagRecommandationPrompt = "" - defaultAITagRecommandationRPM = int32(10) + defaultAITagRecommandationPrompt = "" + defaultAITagRecommandationRPM = int32(10) ) func (s *Store) GetWorkspaceAISetting(ctx context.Context) (*storepb.WorkspaceAISetting, error) { @@ -256,7 +256,7 @@ func (s *Store) GetWorkspaceAISetting(ctx context.Context) (*storepb.WorkspaceAI return workspaceAISetting, nil } -// loadAISettingFromEnv loads AI configuration from environment variables +// loadAISettingFromEnv loads AI configuration from environment variables. func loadAISettingFromEnv() *storepb.WorkspaceAISetting { timeoutSeconds := defaultAITimeoutSeconds if timeoutStr := os.Getenv("AI_TIMEOUT_SECONDS"); timeoutStr != "" {