From 41321611fdbc321b8175922d565ca5df5d15d65a Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 20 Aug 2025 11:05:09 +0900 Subject: [PATCH] feature: add API server and client with Google protocol --- API_SERVER_README.md | 250 +++++++++++++++++++++ CMakeLists.txt | 31 +++ gemma/api_client.cc | 359 ++++++++++++++++++++++++++++++ gemma/api_server.cc | 514 +++++++++++++++++++++++++++++++++++++++++++ gemma/gemma_args.h | 34 +++ 5 files changed, 1188 insertions(+) create mode 100644 API_SERVER_README.md create mode 100644 gemma/api_client.cc create mode 100644 gemma/api_server.cc diff --git a/API_SERVER_README.md b/API_SERVER_README.md new file mode 100644 index 0000000..f7af504 --- /dev/null +++ b/API_SERVER_README.md @@ -0,0 +1,250 @@ +# Gemma.cpp API Server + +This is an HTTP API server for gemma.cpp that implements the Google API protocol, allowing you to interact with Gemma models through REST API endpoints compatible with the Google API format. + +## Features + +- **API-compatible**: Implements Google API endpoints +- **Unified client/server**: Single codebase supports both local and public API modes +- **Text generation**: Support for `generateContent` endpoint +- **Streaming support**: Server-Sent Events (SSE) for `streamGenerateContent` +- **Model management**: Support for `/v1beta/models` endpoint +- **Session management**: Maintains conversation context with KV cache +- **JSON responses**: All responses in Google API format +- **Error handling**: Proper HTTP status codes and error messages + +## Building + +The API server is built alongside the main gemma.cpp project: + +```bash +# Configure the build +cmake -B build -DCMAKE_BUILD_TYPE=Release + +# Build the API server and client +cmake --build build --target gemma_api_server gemma_api_client -j 8 +``` + +The binaries will be created at: +- `build/gemma_api_server` - Local API server +- `build/gemma_api_client` - Unified client for both local and public APIs + +## Usage + +### Starting the Local API Server + +```bash +./build/gemma_api_server \ + --tokenizer path/to/tokenizer.spm \ + --weights path/to/model.sbs \ + --port 8080 +``` + +**Required arguments:** +- `--tokenizer`: Path to the tokenizer file (`.spm`) +- `--weights`: Path to the model weights file (`.sbs`) + +**Optional arguments:** +- `--port`: Port to listen on (default: 8080) +- `--model`: Model name for API endpoints (default: gemma3-4b) + +### Using the Unified Client + +#### With Local Server +```bash +# Interactive chat with local server +./build/gemma_api_client --interactive 1 --host localhost --port 8080 + +# Single prompt with local server +./build/gemma_api_client --prompt "Hello, how are you?" +``` + +#### With Public Google API +```bash +# Set API key and use public API +export GOOGLE_API_KEY="your-api-key-here" +./build/gemma_api_client --interactive 1 + +# Or pass API key directly +./build/gemma_api_client --api_key "your-api-key" --interactive 1 +``` + +## API Endpoints + +The server implements Google API endpoints: + +### 1. Generate Content - `POST /v1beta/models/gemma3-4b:generateContent` + +Generate a response for given content (non-streaming). + +**Request:** +```json +{ + "contents": [ + { + "parts": [ + {"text": "Why is the sky blue?"} + ] + } + ], + "generationConfig": { + "temperature": 0.9, + "topK": 1, + "maxOutputTokens": 1024 + } +} +``` + +**Response:** +```json +{ + "candidates": [ + { + "content": { + "parts": [ + {"text": "The sky appears blue because..."} + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "promptFeedback": { + "safetyRatings": [] + }, + "usageMetadata": { + "promptTokenCount": 5, + "candidatesTokenCount": 25, + "totalTokenCount": 30 + } +} +``` + +### 2. Stream Generate Content - `POST /v1beta/models/gemma3-4b:streamGenerateContent` + +Generate a response with Server-Sent Events (SSE) streaming. + +**Request:** Same as above + +**Response:** Stream of SSE events: +``` +data: {"candidates":[{"content":{"parts":[{"text":"The"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}} + +data: {"candidates":[{"content":{"parts":[{"text":" sky"}],"role":"model"},"index":0}],"promptFeedback":{"safetyRatings":[]}} + +data: [DONE] +``` + +### 3. List Models - `GET /v1beta/models` + +List available models. + +**Response:** +```json +{ + "models": [ + { + "name": "models/gemma3-4b", + "displayName": "Gemma3 4B", + "description": "Gemma3 4B model running locally" + } + ] +} +``` + +## Example Usage + +### Using curl with Local Server + +```bash +# Generate content (non-streaming) +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [{"parts": [{"text": "Hello, how are you?"}]}], + "generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024} + }' + +# Stream generate content (SSE) +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:streamGenerateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [{"parts": [{"text": "Tell me a story"}]}], + "generationConfig": {"temperature": 0.9, "topK": 1, "maxOutputTokens": 1024} + }' + +# List models +curl http://localhost:8080/v1beta/models +``` + +### Multi-turn Conversation with curl + +```bash +# First message +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"parts": [{"text": "Hi, my name is Alice"}]} + ] + }' + +# Follow-up message with conversation history +curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + {"parts": [{"text": "Hi, my name is Alice"}]}, + {"parts": [{"text": "Hello Alice! Nice to meet you."}]}, + {"parts": [{"text": "What is my name?"}]} + ] + }' +``` + +### Using Python + +```python +import requests + +# Generate content +response = requests.post('http://localhost:8080/v1beta/models/gemma3-4b:generateContent', + json={ + 'contents': [{'parts': [{'text': 'Explain quantum computing in simple terms'}]}], + 'generationConfig': { + 'temperature': 0.9, + 'topK': 1, + 'maxOutputTokens': 1024 + } + } +) + +result = response.json() +if 'candidates' in result and result['candidates']: + text = result['candidates'][0]['content']['parts'][0]['text'] + print(text) +``` + +## Configuration Options + +The Google API supports various generation configuration options: + +- **temperature**: Controls randomness (0.0 to 2.0, default: 1.0) +- **topK**: Top-K sampling parameter (default: 1) +- **maxOutputTokens**: Maximum number of tokens to generate (default: 8192) + +## Key Features + +- **Unified Implementation**: Same codebase handles both local server and public API +- **Session Management**: Maintains conversation context using KV cache +- **Streaming Support**: Real-time token generation via Server-Sent Events +- **Error Handling**: Comprehensive error responses and HTTP status codes +- **Memory Efficient**: Optimized token processing and caching + +## Compatibility + +This implementation is compatible with: +- Google API format and endpoints +- Standard HTTP clients (curl, browsers, Python requests, etc.) +- Server-Sent Events (SSE) for streaming responses +- JSON request/response format diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e4bfd0..8309840 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,28 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(json) +# Find OpenSSL for HTTPS support +find_package(OpenSSL) +if(OPENSSL_FOUND) + message(STATUS "OpenSSL found, enabling HTTPS support") + set(HTTPLIB_USE_OPENSSL_IF_AVAILABLE ON) +else() + message(STATUS "OpenSSL not found, HTTPS support disabled") + set(HTTPLIB_USE_OPENSSL_IF_AVAILABLE OFF) +endif() + +# HTTP library for API server +FetchContent_Declare(httplib GIT_REPOSITORY https://github.com/yhirose/cpp-httplib.git GIT_TAG v0.18.1 EXCLUDE_FROM_ALL) +FetchContent_MakeAvailable(httplib) + +# Create interface target for httplib (header-only library) +add_library(httplib_interface INTERFACE) +target_include_directories(httplib_interface INTERFACE ${httplib_SOURCE_DIR}) +if(OPENSSL_FOUND) + target_link_libraries(httplib_interface INTERFACE OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(httplib_interface INTERFACE CPPHTTPLIB_OPENSSL_SUPPORT) +endif() + set(BENCHMARK_ENABLE_TESTING OFF) set(BENCHMARK_ENABLE_GTEST_TESTS OFF) @@ -232,3 +254,12 @@ endif() # GEMMA_ENABLE_TESTS add_executable(migrate_weights io/migrate_weights.cc) target_link_libraries(migrate_weights libgemma hwy hwy_contrib) + + +# API server with SSE support +add_executable(gemma_api_server gemma/api_server.cc) +target_link_libraries(gemma_api_server libgemma hwy hwy_contrib nlohmann_json::nlohmann_json httplib_interface) + +# API client for testing +add_executable(gemma_api_client gemma/api_client.cc) +target_link_libraries(gemma_api_client libgemma hwy hwy_contrib nlohmann_json::nlohmann_json httplib_interface) diff --git a/gemma/api_client.cc b/gemma/api_client.cc new file mode 100644 index 0000000..1f64d96 --- /dev/null +++ b/gemma/api_client.cc @@ -0,0 +1,359 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test client for API server + +#include +#include +#include +#include +#include + +#include "httplib.h" +#include "nlohmann/json.hpp" +#include "gemma/gemma_args.h" + +using json = nlohmann::json; + +// ANSI color codes +const std::string RESET = "\033[0m"; +const std::string BOLD = "\033[1m"; +const std::string GREEN = "\033[32m"; +const std::string BLUE = "\033[34m"; +const std::string CYAN = "\033[36m"; +const std::string YELLOW = "\033[33m"; +const std::string RED = "\033[31m"; + +class APIClient { +public: + APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b") + : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) { + if (use_https_) { + ssl_client_ = std::make_unique(host, port); + ssl_client_->set_read_timeout(60, 0); + ssl_client_->set_write_timeout(60, 0); + ssl_client_->enable_server_certificate_verification(false); + } else { + client_ = std::make_unique(host, port); + client_->set_read_timeout(60, 0); + client_->set_write_timeout(60, 0); + } + } + + // Unified request processing for both public and local APIs + json ProcessRequest(const json& request, bool stream = true) { + bool is_public_api = !api_key_.empty(); + + std::string endpoint; + if (is_public_api) { + endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" + : "/v1beta/models/gemini-2.0-flash:generateContent"; + } else { + endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" + : "/v1beta/models/" + model_ + ":generateContent"; + } + + // Only show verbose output in non-interactive mode + if (!interactive_mode_) { + std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; + std::cout << "Request: " << request.dump(2) << std::endl; + } + + if (stream) { + return ProcessStreamingRequest(request, endpoint); + } else { + return ProcessNonStreamingRequest(request, endpoint); + } + } + + void TestGenerateContent(const std::string& prompt, bool stream = true) { + json request = CreateAPIRequest(prompt); + json response = ProcessRequest(request, stream); + + if (response.contains("error")) { + std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + } + } + + void TestListModels() { + std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; + + httplib::Headers headers; + if (!api_key_.empty()) { + headers.emplace("X-goog-api-key", api_key_); + } + auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers); + + if (res && res->status == 200) { + json response = json::parse(res->body); + std::cout << GREEN << "✅ Available models:" << RESET << std::endl; + std::cout << response.dump(2) << std::endl; + } else { + std::cerr << RED << "❌ Request failed" << RESET << std::endl; + } + } + + void InteractiveChat() { + std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl; + std::cout << "Type ':gemma %q' to end.\n" << std::endl; + + interactive_mode_ = true; + json messages; + + while (true) { + std::cout << BOLD << BLUE << "You: " << RESET; + std::string input; + std::getline(std::cin, input); + + if (input == ":gemma %q") { + std::cout << BOLD << YELLOW << "👋 Goodbye!" << RESET << std::endl; + break; + } + + if (input.empty()) continue; + + // Add user message with proper role + json user_message = {{"parts", {{{"text", input}}}}}; + if (!api_key_.empty()) { + user_message["role"] = "user"; + } + messages.push_back(user_message); + + // Create request using unified logic + json request = CreateAPIRequest("", messages); + + std::cout << BOLD << GREEN << "Assistant: " << RESET; + + // Use unified processing - streaming for real-time output + json response = ProcessRequest(request, true); + + if (response.contains("candidates") && !response["candidates"].empty()) { + auto& candidate = response["candidates"][0]; + if (candidate.contains("content") && candidate["content"].contains("parts")) { + for (const auto& part : candidate["content"]["parts"]) { + if (part.contains("text")) { + std::string assistant_response = part["text"].get(); + + // For streaming, the response is already displayed in real-time + // Just add to message history for context + json assistant_message = {{"parts", {{{"text", assistant_response}}}}}; + if (!api_key_.empty()) { + assistant_message["role"] = "model"; + } + messages.push_back(assistant_message); + } + } + } + } else if (response.contains("error")) { + std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + } + + std::cout << std::endl; + } + } + +private: + json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) { + json request = { + {"generationConfig", { + {"temperature", 0.9}, + {"topK", 1}, + {"maxOutputTokens", 1024} + }} + }; + + if (messages.empty()) { + // Single prompt + json user_message = {{"parts", {{{"text", prompt}}}}}; + if (!api_key_.empty()) { + user_message["role"] = "user"; + } + request["contents"] = json::array({user_message}); + } else { + // Use provided message history + request["contents"] = messages; + } + + return request; + } + + json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) { + httplib::Headers headers = {{"Content-Type", "application/json"}}; + if (!api_key_.empty()) { + headers.emplace("X-goog-api-key", api_key_); + } + + auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json") + : client_->Post(endpoint, headers, request.dump(), "application/json"); + + if (res && res->status == 200) { + json response = json::parse(res->body); + if (!interactive_mode_) { + std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl; + std::cout << response.dump(2) << std::endl; + } + return response; + } else { + json error_response = { + {"error", { + {"message", "Request failed"}, + {"status", res ? res->status : -1} + }} + }; + if (res && !res->body.empty()) { + error_response["error"]["details"] = res->body; + } + std::cerr << RED << "❌ Request failed. Status: " << (res ? res->status : -1) << RESET << std::endl; + return error_response; + } + } + + json ProcessStreamingRequest(const json& request, const std::string& endpoint) { + std::string accumulated_response; + + // Use same SSE logic for both public and local APIs + httplib::Request req; + req.method = "POST"; + req.path = endpoint; + req.set_header("Content-Type", "application/json"); + if (!api_key_.empty()) { + req.set_header("X-goog-api-key", api_key_); + } + req.body = request.dump(); + + req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool { + std::string chunk(data, data_length); + std::istringstream stream(chunk); + std::string line; + + while (std::getline(stream, line)) { + if (line.substr(0, 6) == "data: ") { + std::string event_data = line.substr(6); + + if (event_data == "[DONE]") { + if (!interactive_mode_) { + std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl; + } + } else { + try { + json event = json::parse(event_data); + if (event.contains("candidates") && !event["candidates"].empty()) { + auto& candidate = event["candidates"][0]; + if (candidate.contains("content") && candidate["content"].contains("parts")) { + for (const auto& part : candidate["content"]["parts"]) { + if (part.contains("text")) { + std::string text = part["text"].get(); + std::cout << text << std::flush; + accumulated_response += text; + } + } + } + } + } catch (const json::exception& e) { + // Skip parse errors + } + } + } + } + return true; + }; + + httplib::Response res; + httplib::Error error; + bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error); + + if (res.status == 200 && !accumulated_response.empty()) { + return json{ + {"candidates", {{ + {"content", { + {"parts", {{{"text", accumulated_response}}}} + }} + }}} + }; + } else { + json error_response = { + {"error", { + {"message", "Streaming request failed"}, + {"status", res.status} + }} + }; + if (!res.body.empty()) { + error_response["error"]["details"] = res.body; + } + std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl; + return error_response; + } + } + +private: + std::unique_ptr client_; + std::unique_ptr ssl_client_; + std::string host_; + int port_; + std::string api_key_; + std::string model_; + bool use_https_; + bool interactive_mode_; +}; + +int main(int argc, char* argv[]) { + gcpp::ClientArgs client_args(argc, argv); + + if (gcpp::HasHelp(argc, argv)) { + std::cout << "\nAPI Client for gemma.cpp\n"; + std::cout << "========================\n\n"; + client_args.Help(); + std::cout << std::endl; + std::cout << "Environment Variables:" << std::endl; + std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl; + return 0; + } + + // Check for GOOGLE_API_KEY environment variable + const char* env_api_key = std::getenv("GOOGLE_API_KEY"); + if (env_api_key != nullptr && strlen(env_api_key) > 0) { + client_args.api_key = env_api_key; + client_args.host = "generativelanguage.googleapis.com"; + client_args.port = 443; + } + + // Handle API key override + if (!client_args.api_key.empty()) { + client_args.host = "generativelanguage.googleapis.com"; + client_args.port = 443; + } + + std::cout << BOLD << YELLOW << "🚀 Testing API Server at " + << client_args.host << ":" << client_args.port << RESET << std::endl; + + try { + APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model); + + if (client_args.interactive) { + client.InteractiveChat(); + } else { + client.TestListModels(); + client.TestGenerateContent(client_args.prompt, true); + } + + } catch (const std::exception& e) { + std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl; + std::cerr << "Make sure the API server is running:" << std::endl; + std::cerr << " ./build/gemma_api_server --tokenizer --weights " << std::endl; + return 1; + } + + return 0; +} diff --git a/gemma/api_server.cc b/gemma/api_server.cc new file mode 100644 index 0000000..70b3115 --- /dev/null +++ b/gemma/api_server.cc @@ -0,0 +1,514 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// HTTP API server for gemma.cpp with SSE support + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// HTTP server library +#undef CPPHTTPLIB_OPENSSL_SUPPORT +#undef CPPHTTPLIB_ZLIB_SUPPORT +#include "httplib.h" + +// JSON library +#include "nlohmann/json.hpp" + +#include "compression/types.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/tokenizer.h" +#include "ops/matmul.h" +#include "util/args.h" +#include "hwy/base.h" +#include "hwy/profiler.h" + +using json = nlohmann::json; + +namespace gcpp { + +static std::atomic server_running{true}; + +// Server state holding model and KV caches +struct ServerState { + std::unique_ptr gemma; + MatMulEnv* env; + ThreadingContext* ctx; + + // Session-based KV cache storage + struct Session { + std::unique_ptr kv_cache; + size_t abs_pos = 0; + std::chrono::steady_clock::time_point last_access; + }; + + std::unordered_map sessions; + std::mutex sessions_mutex; + std::mutex inference_mutex; + + // Cleanup old sessions after 30 minutes of inactivity + void CleanupOldSessions() { + std::lock_guard lock(sessions_mutex); + auto now = std::chrono::steady_clock::now(); + for (auto it = sessions.begin(); it != sessions.end();) { + if (now - it->second.last_access > std::chrono::minutes(30)) { + it = sessions.erase(it); + } else { + ++it; + } + } + } + + // Get or create session with KV cache + Session& GetOrCreateSession(const std::string& session_id) { + std::lock_guard lock(sessions_mutex); + auto& session = sessions[session_id]; + if (!session.kv_cache) { + session.kv_cache = std::make_unique(gemma->Config(), InferenceArgs(), env->ctx.allocator); + } + session.last_access = std::chrono::steady_clock::now(); + return session; + } +}; + +// Generate a unique session ID +std::string GenerateSessionId() { + static std::atomic counter{0}; + std::stringstream ss; + ss << "session_" << std::hex << std::chrono::steady_clock::now().time_since_epoch().count() + << "_" << counter.fetch_add(1); + return ss.str(); +} + +// Wraps messages with start_of_turn markers - handles both with and without roles +std::string WrapMessagesWithTurnMarkers(const json& contents) { + std::string prompt; + + for (const auto& content : contents) { + if (content.contains("parts")) { + // Check if role is specified (public API format) or not (local format) + std::string role = content.value("role", ""); + + for (const auto& part : content["parts"]) { + if (part.contains("text")) { + std::string text = part["text"]; + + if (role == "user") { + prompt += "user\n" + text + "\nmodel\n"; + } else if (role == "model") { + prompt += text + "\n"; + } else if (role.empty()) { + // Local format without roles - for now, treat as user input + prompt += "user\n" + text + "\nmodel\n"; + } + } + } + } + } + + return prompt; +} + +// Parse generation config +RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) { + RuntimeConfig config; + config.gen = &gen; + config.verbosity = 0; + + // Set defaults matching public API + config.temperature = 1.0f; + config.top_k = 1; + config.max_generated_tokens = 8192; + + if (request.contains("generationConfig")) { + auto& gen_config = request["generationConfig"]; + + if (gen_config.contains("temperature")) { + config.temperature = gen_config["temperature"].get(); + } + if (gen_config.contains("topK")) { + config.top_k = gen_config["topK"].get(); + } + if (gen_config.contains("maxOutputTokens")) { + config.max_generated_tokens = gen_config["maxOutputTokens"].get(); + } + } + + return config; +} + +// Unified response formatter - creates consistent format regardless of request type +json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { + json response = { + {"candidates", {{ + {"content", { + {"parts", {{{"text", text}}}}, + {"role", "model"} + }}, + {"index", 0} + }}}, + {"promptFeedback", {{"safetyRatings", json::array()}}} + }; + + // Only add finishReason for non-streaming chunks + if (!is_streaming_chunk) { + response["candidates"][0]["finishReason"] = "STOP"; + } + + return response; +} + +// Handle generateContent endpoint (non-streaming) +void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { + try { + json request = json::parse(req.body); + + // Get or create session + std::string session_id = request.value("sessionId", GenerateSessionId()); + auto& session = state.GetOrCreateSession(session_id); + + // Extract prompt from API format + std::string prompt; + if (request.contains("contents")) { + prompt = WrapMessagesWithTurnMarkers(request["contents"]); + } else { + res.status = 400; + res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + return; + } + + // Lock for inference + std::lock_guard lock(state.inference_mutex); + + // Set up runtime config + std::mt19937 gen; + RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); + + // Collect full response + std::string full_response; + runtime_config.stream_token = [&full_response](int token, float) { + // Skip EOS token + return true; + }; + + // Tokenize prompt + std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), + state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, + session.abs_pos, + prompt); + + // Run inference with KV cache + TimingInfo timing_info = {.verbosity = 0}; + size_t prefix_end = 0; + + // Temporarily redirect output to capture response + std::stringstream output; + runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { + // Skip prompt tokens + if (session.abs_pos < tokens.size()) { + session.abs_pos++; + return true; + } + + session.abs_pos++; + + // Check for EOS + if (state.gemma->Config().IsEOS(token)) { + return true; + } + + // Decode token + std::string token_text; + state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); + output << token_text; + + return true; + }; + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, + *session.kv_cache, *state.env, timing_info); + + // Create response + json response = CreateAPIResponse(output.str(), false); + response["usageMetadata"] = { + {"promptTokenCount", tokens.size()}, + {"candidatesTokenCount", session.abs_pos - tokens.size()}, + {"totalTokenCount", session.abs_pos} + }; + + res.set_content(response.dump(), "application/json"); + + } catch (const json::exception& e) { + res.status = 400; + res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), + "application/json"); + } catch (const std::exception& e) { + res.status = 500; + res.set_content(json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}.dump(), + "application/json"); + } +} + +// Handle streamGenerateContent endpoint with SSE) +void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { + try { + json request = json::parse(req.body); + + // Get or create session + std::string session_id = request.value("sessionId", GenerateSessionId()); + auto& session = state.GetOrCreateSession(session_id); + + // Extract prompt from API format + std::string prompt; + if (request.contains("contents")) { + prompt = WrapMessagesWithTurnMarkers(request["contents"]); + } else { + res.status = 400; + res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + return; + } + + // Set up SSE headers + res.set_header("Content-Type", "text/event-stream"); + res.set_header("Cache-Control", "no-cache"); + res.set_header("Connection", "keep-alive"); + res.set_header("X-Session-Id", session_id); + + // Set up chunked content provider for SSE + res.set_chunked_content_provider( + "text/event-stream", + [&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) { + try { + // Lock for inference + std::lock_guard lock(state.inference_mutex); + auto& session = state.GetOrCreateSession(session_id); + + // Set up runtime config + std::mt19937 gen; + RuntimeConfig runtime_config = ParseGenerationConfig(request, gen); + + // Tokenize prompt + std::vector tokens = WrapAndTokenize(state.gemma->Tokenizer(), + state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, + session.abs_pos, + prompt); + + // Stream token callback + std::string accumulated_text; + auto stream_token = [&](int token, float) { + // Skip prompt tokens + if (session.abs_pos < tokens.size()) { + session.abs_pos++; + return true; + } + + session.abs_pos++; + + // Check for EOS + if (state.gemma->Config().IsEOS(token)) { + return true; + } + + // Decode token + std::string token_text; + state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); + accumulated_text += token_text; + + // Send SSE event using unified formatter + json event = CreateAPIResponse(token_text, true); + + std::string sse_data = "data: " + event.dump() + "\n\n"; + sink.write(sse_data.data(), sse_data.size()); + + return true; + }; + + runtime_config.stream_token = stream_token; + + // Run inference with KV cache + TimingInfo timing_info = {.verbosity = 0}; + size_t prefix_end = 0; + + state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, + *session.kv_cache, *state.env, timing_info); + + // Send final event using unified formatter + json final_event = CreateAPIResponse("", false); + final_event["usageMetadata"] = { + {"promptTokenCount", tokens.size()}, + {"candidatesTokenCount", session.abs_pos - tokens.size()}, + {"totalTokenCount", session.abs_pos} + }; + + std::string final_sse = "data: " + final_event.dump() + "\n\n"; + sink.write(final_sse.data(), final_sse.size()); + + // Send done event + sink.write("data: [DONE]\n\n", 15); + + // Ensure all data is sent + sink.done(); + + return false; // End streaming + + } catch (const std::exception& e) { + json error_event = {{"error", {{"message", e.what()}}}}; + std::string error_sse = "data: " + error_event.dump() + "\n\n"; + sink.write(error_sse.data(), error_sse.size()); + return false; + } + } + ); + + } catch (const json::exception& e) { + res.status = 400; + res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(), + "application/json"); + } +} + +// Handle models list endpoint +void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { + json response = { + {"models", {{ + {"name", "models/" + inference.model}, + {"version", "001"}, + {"displayName", inference.model}, + {"description", inference.model + " model running locally"}, + {"inputTokenLimit", 8192}, + {"outputTokenLimit", 8192}, + {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, + {"temperature", 1.0}, + {"topK", 1} + }}} + }; + + res.set_content(response.dump(), "application/json"); +} + +// void HandleShutdown(int signal) { +// std::cerr << "\nShutting down server..." << std::endl; +// server_running = false; +// } + +void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference) { + std::cerr << "Loading model..." << std::endl; + + // Initialize model + ThreadingContext ctx(threading); + MatMulEnv env(ctx); + + ServerState state; + state.gemma = std::make_unique(loader, inference, ctx); + state.env = &env; + state.ctx = &ctx; + + httplib::Server server; + + // Set up routes + server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { + res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); + }); + + // API endpoints + server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { + HandleListModels(state, inference, req, res); + }); + + std::string model_endpoint = "/v1beta/models/" + inference.model; + server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentNonStreaming(state, req, res); + }); + + server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentStreaming(state, req, res); + }); + + // Periodic cleanup of old sessions + std::thread cleanup_thread([&state]() { + while (server_running) { + std::this_thread::sleep_for(std::chrono::minutes(5)); + state.CleanupOldSessions(); + } + }); + + std::cerr << "Starting API server on port " << inference.port << std::endl; + std::cerr << "Model loaded successfully" << std::endl; + std::cerr << "Endpoints:" << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; + std::cerr << " GET /v1beta/models" << std::endl; + + if (!server.listen("0.0.0.0", inference.port)) { + std::cerr << "Failed to start server on port " << inference.port << std::endl; + } + + cleanup_thread.join(); +} + +} // namespace gcpp + +int main(int argc, char** argv) { + gcpp::InternalInit(); + + gcpp::LoaderArgs loader(argc, argv); + gcpp::ThreadingArgs threading(argc, argv); + gcpp::InferenceArgs inference(argc, argv); + + if (gcpp::HasHelp(argc, argv)) { + std::cerr << "\n\nAPI server for gemma.cpp\n"; + std::cout << "========================\n\n"; + std::cerr << "Usage: " << argv[0] << " --weights --tokenizer [options]\n"; + std::cerr << "\nOptions:\n"; + std::cerr << " --port PORT Server port (default: 8080)\n"; + std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n"; + std::cerr << "\n"; + std::cerr << "\n*Model Loading Arguments*\n\n"; + loader.Help(); + std::cerr << "\n*Threading Arguments*\n\n"; + threading.Help(); + std::cerr << "\n*Inference Arguments*\n\n"; + inference.Help(); + std::cerr << "\n"; + return 0; + } + + // Arguments are now handled by InferenceArgs + + // // Set up signal handler + // signal(SIGINT, gcpp::HandleShutdown); + // signal(SIGTERM, gcpp::HandleShutdown); + + gcpp::RunServer(loader, threading, inference); + + return 0; +} diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 161e9a5..469ba2a 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -183,6 +183,8 @@ struct InferenceArgs : public ArgsBase { bool multiturn; Path image_file; + int port; // Server port + std::string model; // Model name for API endpoints std::string prompt; // Bypasses std::getline // For prompts longer than the Linux terminal's 4K line edit buffer. Path prompt_file; @@ -218,6 +220,10 @@ struct InferenceArgs : public ArgsBase { "resets every turn)"); visitor(image_file, "image_file", Path(), "Image file to load."); + // Since it is not used in the CLI version, the print_verbosity is set higher than others. + visitor(port, "port", 8080, "Server port (default: 8080)", 3); + visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3); + visitor(prompt, "prompt", std::string(""), "Initial prompt for non-interactive mode. When specified, " "generates a response and exits.", @@ -258,6 +264,34 @@ struct InferenceArgs : public ArgsBase { } }; +struct ClientArgs : public ArgsBase { + ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + ClientArgs() { Init(); }; + + std::string host; + int port; + std::string api_key; + std::string model; + std::string prompt; + bool interactive; + + template + void ForEach(const Visitor& visitor) { + visitor(host, "host", std::string("localhost"), + "Server host (default: localhost)"); + visitor(port, "port", 8080, + "Server port (default: 8080)"); + visitor(api_key, "api_key", std::string(""), + "Use public API with key (changes host to generativelanguage.googleapis.com:443)"); + visitor(model, "model", std::string("gemma3-4b"), + "Model name to use (default: gemma3-4b)"); + visitor(prompt, "prompt", std::string("Hello! How are you?"), + "Prompt for generation (default: 'Hello! How are you?')"); + visitor(interactive, "interactive", false, + "Start interactive chat mode (0 = no, 1 = yes)"); + } +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_