diff --git a/plugin/ai/client.go b/plugin/ai/client.go new file mode 100644 index 000000000..f9418e9e7 --- /dev/null +++ b/plugin/ai/client.go @@ -0,0 +1,253 @@ +package ai + +import ( + "context" + "errors" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" + + storepb "github.com/usememos/memos/proto/gen/store" +) + +// Common AI errors +var ( + ErrConfigIncomplete = errors.New("AI configuration incomplete - missing BaseURL, APIKey, or Model") + ErrEmptyRequest = errors.New("chat request cannot be empty") + ErrInvalidMessage = errors.New("message role must be 'system', 'user', or 'assistant'") + ErrEmptyContent = errors.New("message content cannot be empty") + ErrAPICallFailed = errors.New("AI API call failed") + ErrEmptyResponse = errors.New("received empty response from AI") + ErrNoChoices = errors.New("AI returned no response choices") +) + +// Config holds AI configuration +type Config struct { + Enabled bool + BaseURL string + APIKey string + Model string + TimeoutSeconds int +} + +// LoadConfigFromEnv loads AI configuration from environment variables +func LoadConfigFromEnv() *Config { + timeoutSeconds := 10 // default timeout + if timeoutStr := os.Getenv("AI_TIMEOUT_SECONDS"); timeoutStr != "" { + if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 { + timeoutSeconds = timeout + } + } + + config := &Config{ + BaseURL: os.Getenv("AI_BASE_URL"), + APIKey: os.Getenv("AI_API_KEY"), + Model: os.Getenv("AI_MODEL"), + TimeoutSeconds: timeoutSeconds, + } + + // Enable AI if all required fields are provided + config.Enabled = config.BaseURL != "" && config.APIKey != "" && config.Model != "" + + return config +} + +// LoadConfigFromDatabase loads AI configuration from database settings +func LoadConfigFromDatabase(aiSetting *storepb.WorkspaceAISetting) *Config { + if aiSetting == nil { + return &Config{Enabled: false} + } + + timeoutSeconds := int(aiSetting.TimeoutSeconds) + if timeoutSeconds <= 0 { + timeoutSeconds = 10 // default timeout + } + + return &Config{ + Enabled: aiSetting.EnableAi, + BaseURL: aiSetting.BaseUrl, + APIKey: aiSetting.ApiKey, + Model: aiSetting.Model, + TimeoutSeconds: timeoutSeconds, + } +} + +// MergeWithEnv merges database config with environment variables +// Environment variables take precedence if they are set +func (c *Config) MergeWithEnv() *Config { + envConfig := LoadConfigFromEnv() + + // Start with current config + merged := &Config{ + Enabled: c.Enabled, + BaseURL: c.BaseURL, + APIKey: c.APIKey, + Model: c.Model, + TimeoutSeconds: c.TimeoutSeconds, + } + + // Override with env vars if they are set + if envConfig.BaseURL != "" { + merged.BaseURL = envConfig.BaseURL + } + if envConfig.APIKey != "" { + merged.APIKey = envConfig.APIKey + } + if envConfig.Model != "" { + merged.Model = envConfig.Model + } + if os.Getenv("AI_TIMEOUT_SECONDS") != "" { + merged.TimeoutSeconds = envConfig.TimeoutSeconds + } + + // Enable if all required fields are present + merged.Enabled = merged.BaseURL != "" && merged.APIKey != "" && merged.Model != "" + + return merged +} + +// IsConfigured returns true if AI is properly configured +func (c *Config) IsConfigured() bool { + return c.Enabled && c.BaseURL != "" && c.APIKey != "" && c.Model != "" +} + +// Client wraps OpenAI client with convenience methods +type Client struct { + client openai.Client + config *Config +} + +// NewClient creates a new AI client +func NewClient(config *Config) (*Client, error) { + if config == nil { + return nil, fmt.Errorf("config cannot be nil") + } + + if !config.IsConfigured() { + return nil, ErrConfigIncomplete + } + + var client openai.Client + if config.BaseURL != "" && config.BaseURL != "https://api.openai.com/v1" { + client = openai.NewClient( + option.WithAPIKey(config.APIKey), + option.WithBaseURL(config.BaseURL), + ) + } else { + client = openai.NewClient( + option.WithAPIKey(config.APIKey), + ) + } + + return &Client{ + client: client, + config: config, + }, nil +} + +// ChatRequest represents a chat completion request +type ChatRequest struct { + Messages []Message + MaxTokens int + Temperature float64 + Timeout time.Duration +} + +// Message represents a chat message +type Message struct { + Role string // "system", "user", "assistant" + Content string +} + +// ChatResponse represents a chat completion response +type ChatResponse struct { + Content string +} + +// Chat performs a chat completion +func (c *Client) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + if req == nil { + return nil, ErrEmptyRequest + } + + if len(req.Messages) == 0 { + return nil, ErrEmptyRequest + } + + // Validate messages + for i, msg := range req.Messages { + if msg.Role != "system" && msg.Role != "user" && msg.Role != "assistant" { + return nil, fmt.Errorf("message %d: %w", i, ErrInvalidMessage) + } + if strings.TrimSpace(msg.Content) == "" { + return nil, fmt.Errorf("message %d: %w", i, ErrEmptyContent) + } + } + + // Set defaults + if req.MaxTokens == 0 { + req.MaxTokens = 8192 + } + if req.Temperature == 0 { + req.Temperature = 0.3 + } + if req.Timeout == 0 { + // Use timeout from config if available + if c.config.TimeoutSeconds > 0 { + req.Timeout = time.Duration(c.config.TimeoutSeconds) * time.Second + } else { + req.Timeout = 10 * time.Second + } + } + + model := c.config.Model + if model == "" { + model = "gpt-4o" // Default model + } + + // Convert messages + messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(req.Messages)) + for _, msg := range req.Messages { + switch msg.Role { + case "system": + messages = append(messages, openai.SystemMessage(msg.Content)) + case "user": + messages = append(messages, openai.UserMessage(msg.Content)) + case "assistant": + messages = append(messages, openai.AssistantMessage(msg.Content)) + } + } + + // Create timeout context + timeoutCtx, cancel := context.WithTimeout(ctx, req.Timeout) + defer cancel() + + // Make API call + completion, err := c.client.Chat.Completions.New(timeoutCtx, openai.ChatCompletionNewParams{ + Messages: messages, + Model: model, + MaxTokens: openai.Int(int64(req.MaxTokens)), + Temperature: openai.Float(req.Temperature), + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrAPICallFailed, err) + } + + if len(completion.Choices) == 0 { + return nil, ErrNoChoices + } + + response := strings.TrimSpace(completion.Choices[0].Message.Content) + if response == "" { + return nil, ErrEmptyResponse + } + + return &ChatResponse{ + Content: response, + }, nil +} diff --git a/plugin/ai/client_test.go b/plugin/ai/client_test.go new file mode 100644 index 000000000..3a02815a9 --- /dev/null +++ b/plugin/ai/client_test.go @@ -0,0 +1,424 @@ +package ai + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadConfigFromEnv(t *testing.T) { + tests := []struct { + name string + envVars map[string]string + expected *Config + }{ + { + name: "all environment variables set", + envVars: map[string]string{ + "AI_BASE_URL": "https://api.openai.com/v1", + "AI_API_KEY": "sk-test123", + "AI_MODEL": "gpt-4o", + }, + expected: &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "gpt-4o", + }, + }, + { + name: "no environment variables set", + envVars: map[string]string{}, + expected: &Config{ + BaseURL: "", + APIKey: "", + Model: "", + }, + }, + { + name: "partial environment variables set", + envVars: map[string]string{ + "AI_BASE_URL": "https://custom.api.com/v1", + "AI_API_KEY": "sk-custom123", + }, + expected: &Config{ + BaseURL: "https://custom.api.com/v1", + APIKey: "sk-custom123", + Model: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original environment variables + origBaseURL := os.Getenv("AI_BASE_URL") + origAPIKey := os.Getenv("AI_API_KEY") + origModel := os.Getenv("AI_MODEL") + + // Clear existing environment variables + os.Unsetenv("AI_BASE_URL") + os.Unsetenv("AI_API_KEY") + os.Unsetenv("AI_MODEL") + + // Set test environment variables + for key, value := range tt.envVars { + os.Setenv(key, value) + } + + // Test configuration loading + config := LoadConfigFromEnv() + assert.Equal(t, tt.expected, config) + + // Restore original environment variables + os.Unsetenv("AI_BASE_URL") + os.Unsetenv("AI_API_KEY") + os.Unsetenv("AI_MODEL") + + if origBaseURL != "" { + os.Setenv("AI_BASE_URL", origBaseURL) + } + if origAPIKey != "" { + os.Setenv("AI_API_KEY", origAPIKey) + } + if origModel != "" { + os.Setenv("AI_MODEL", origModel) + } + }) + } +} + +func TestConfig_IsConfigured(t *testing.T) { + tests := []struct { + name string + config *Config + expected bool + }{ + { + name: "fully configured", + config: &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "gpt-4o", + }, + expected: true, + }, + { + name: "missing base URL", + config: &Config{ + BaseURL: "", + APIKey: "sk-test123", + Model: "gpt-4o", + }, + expected: false, + }, + { + name: "missing API key", + config: &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "", + Model: "gpt-4o", + }, + expected: false, + }, + { + name: "missing model", + config: &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "", + }, + expected: false, + }, + { + name: "all fields empty", + config: &Config{ + BaseURL: "", + APIKey: "", + Model: "", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.config.IsConfigured() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + config *Config + expectErr bool + }{ + { + name: "standard OpenAI configuration", + config: &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "gpt-4o", + }, + expectErr: false, + }, + { + name: "custom endpoint configuration", + config: &Config{ + BaseURL: "https://custom.api.com/v1", + APIKey: "sk-custom123", + Model: "gpt-3.5-turbo", + }, + expectErr: false, + }, + { + name: "incomplete configuration", + config: &Config{ + BaseURL: "", + APIKey: "sk-test123", + Model: "gpt-4o", + }, + expectErr: true, + }, + { + name: "nil configuration", + config: nil, + expectErr: true, + }, + { + name: "missing API key", + config: &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "", + Model: "gpt-4o", + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.config) + + if tt.expectErr { + assert.Error(t, err) + assert.Nil(t, client) + } else { + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, tt.config, client.config) + assert.NotNil(t, client.client) + } + }) + } +} + +func TestClient_Chat_RequestDefaults(t *testing.T) { + // This test verifies that default values are properly set + config := &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "gpt-4o", + } + + client, err := NewClient(config) + require.NoError(t, err) + + // Test with minimal request + req := &ChatRequest{ + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } + + // We can't actually call the API in tests without mocking, + // but we can verify the client was created successfully + assert.NotNil(t, client) + assert.Equal(t, config, client.config) + + // Verify default values would be set + assert.Equal(t, 0, req.MaxTokens) // Should become 8192 + assert.Equal(t, float64(0), req.Temperature) // Should become 0.3 + assert.Equal(t, time.Duration(0), req.Timeout) // Should become 10s +} + +func TestMessage_Roles(t *testing.T) { + tests := []struct { + name string + role string + valid bool + }{ + {"system role", "system", true}, + {"user role", "user", true}, + {"assistant role", "assistant", true}, + {"invalid role", "invalid", false}, + {"empty role", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := Message{ + Role: tt.role, + Content: "test content", + } + + // Valid roles are those that would be handled in the switch statement + validRoles := map[string]bool{ + "system": true, + "user": true, + "assistant": true, + } + + assert.Equal(t, tt.valid, validRoles[msg.Role]) + }) + } +} + +// Integration test helper - only runs with proper environment variables +func TestClient_Chat_Integration(t *testing.T) { + // Skip if not in integration test mode + if os.Getenv("AI_INTEGRATION_TEST") != "true" { + t.Skip("Skipping integration test - set AI_INTEGRATION_TEST=true to run") + } + + config := LoadConfigFromEnv() + if !config.IsConfigured() { + t.Skip("AI not configured - set AI_BASE_URL, AI_API_KEY, AI_MODEL environment variables") + } + + client, err := NewClient(config) + require.NoError(t, err) + ctx := context.Background() + + req := &ChatRequest{ + Messages: []Message{ + {Role: "user", Content: "Say 'Hello, World!' in exactly those words."}, + }, + MaxTokens: 50, + Temperature: 0.1, + Timeout: 30 * time.Second, + } + + resp, err := client.Chat(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.NotEmpty(t, resp.Content) + + t.Logf("AI Response: %s", resp.Content) +} + +func TestClient_Chat_Validation(t *testing.T) { + config := &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "gpt-4o", + } + + client, err := NewClient(config) + require.NoError(t, err) + ctx := context.Background() + + tests := []struct { + name string + request *ChatRequest + expectErr error + }{ + { + name: "nil request", + request: nil, + expectErr: ErrEmptyRequest, + }, + { + name: "empty messages", + request: &ChatRequest{ + Messages: []Message{}, + }, + expectErr: ErrEmptyRequest, + }, + { + name: "invalid message role", + request: &ChatRequest{ + Messages: []Message{ + {Role: "invalid", Content: "Hello"}, + }, + }, + expectErr: ErrInvalidMessage, + }, + { + name: "empty message content", + request: &ChatRequest{ + Messages: []Message{ + {Role: "user", Content: ""}, + }, + }, + expectErr: ErrEmptyContent, + }, + { + name: "whitespace-only message content", + request: &ChatRequest{ + Messages: []Message{ + {Role: "user", Content: " \n\t "}, + }, + }, + expectErr: ErrEmptyContent, + }, + { + name: "valid request", + request: &ChatRequest{ + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + }, + expectErr: nil, // This will fail with API call error in tests, but validation should pass + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := client.Chat(ctx, tt.request) + + if tt.expectErr != nil { + assert.Error(t, err) + assert.ErrorIs(t, err, tt.expectErr) + } else { + // For the valid request case, we expect an API call error since we don't have real credentials + // but the validation should pass, so we just check that it's not a validation error + if err != nil { + assert.NotErrorIs(t, err, ErrEmptyRequest) + assert.NotErrorIs(t, err, ErrInvalidMessage) + assert.NotErrorIs(t, err, ErrEmptyContent) + } + } + }) + } +} + +func TestClient_Chat_ErrorTypes(t *testing.T) { + config := &Config{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test123", + Model: "gpt-4o", + } + + client, err := NewClient(config) + require.NoError(t, err) + ctx := context.Background() + + // Test that we can identify specific error types + t.Run("can check for specific errors", func(t *testing.T) { + _, err := client.Chat(ctx, nil) + assert.ErrorIs(t, err, ErrEmptyRequest) + + _, err = client.Chat(ctx, &ChatRequest{ + Messages: []Message{ + {Role: "invalid", Content: "test"}, + }, + }) + assert.ErrorIs(t, err, ErrInvalidMessage) + }) +}