// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // HTTP API server for gemma.cpp with SSE support #include #include #include #include #include #include #include #include #include #include // NOLINT #include #include // HTTP server library #undef CPPHTTPLIB_OPENSSL_SUPPORT #undef CPPHTTPLIB_ZLIB_SUPPORT #include "httplib.h" // JSON library #include "gemma/gemma.h" #include "gemma/gemma_args.h" #include "ops/matmul.h" #include "util/args.h" #include "hwy/base.h" #include "nlohmann/json.hpp" using json = nlohmann::json; namespace gcpp { static std::atomic server_running{true}; // Server state holding model and KV caches struct ServerState { std::unique_ptr gemma; MatMulEnv* env; ThreadingContext* ctx; // Session-based KV cache storage struct Session { std::unique_ptr kv_cache; size_t abs_pos = 0; std::chrono::steady_clock::time_point last_access; }; std::unordered_map sessions; std::mutex sessions_mutex; std::mutex inference_mutex; // Cleanup old sessions after 30 minutes of inactivity void CleanupOldSessions() { std::lock_guard lock(sessions_mutex); auto now = std::chrono::steady_clock::now(); for (auto it = sessions.begin(); it != sessions.end();) { if (now - it->second.last_access > std::chrono::minutes(30)) { it = sessions.erase(it); } else { ++it; } } } // Get or create session with KV cache Session& GetOrCreateSession(const std::string& session_id) { std::lock_guard lock(sessions_mutex); auto& session = sessions[session_id]; if (!session.kv_cache) { session.kv_cache = std::make_unique( gemma->Config(), InferenceArgs(), env->ctx.allocator); } session.last_access = std::chrono::steady_clock::now(); return session; } }; // Generate a unique session ID std::string GenerateSessionId() { static std::atomic counter{0}; std::stringstream ss; ss << "session_" << std::hex << std::chrono::steady_clock::now().time_since_epoch().count() << "_" << counter.fetch_add(1); return ss.str(); } // Wraps messages with start_of_turn markers - handles both with and without // roles std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string prompt; for (const auto& content : contents) { if (content.contains("parts")) { // Check if role is specified (public API format) or not (local format) std::string role = content.value("role", ""); for (const auto& part : content["parts"]) { if (part.contains("text")) { std::string text = part["text"]; if (role == "user") { prompt += "user\n" + text + "\nmodel\n"; } else if (role == "model") { prompt += text + "\n"; } else if (role.empty()) { // Local format without roles - for now, treat as user input prompt += "user\n" + text + "\nmodel\n"; } } } } } return prompt; } // Parse generation config RuntimeConfig ParseGenerationConfig(const json& request) { RuntimeConfig config; config.verbosity = 0; // Set defaults matching public API config.temperature = 1.0f; config.top_k = 1; config.max_generated_tokens = 8192; if (request.contains("generationConfig")) { auto& gen_config = request["generationConfig"]; if (gen_config.contains("temperature")) { config.temperature = gen_config["temperature"].get(); } if (gen_config.contains("topK")) { config.top_k = gen_config["topK"].get(); } if (gen_config.contains("maxOutputTokens")) { config.max_generated_tokens = gen_config["maxOutputTokens"].get(); } } return config; } // Unified response formatter - creates consistent format regardless of request // type json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { json response = { {"candidates", {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}}, {"index", 0}}}}, {"promptFeedback", {{"safetyRatings", json::array()}}}}; // Only add finishReason for non-streaming chunks if (!is_streaming_chunk) { response["candidates"][0]["finishReason"] = "STOP"; } return response; } // Handle generateContent endpoint (non-streaming) void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { try { json request = json::parse(req.body); // Get or create session std::string session_id = request.value("sessionId", GenerateSessionId()); auto& session = state.GetOrCreateSession(session_id); // Extract prompt from API format std::string prompt; if (request.contains("contents")) { prompt = WrapMessagesWithTurnMarkers(request["contents"]); } else { res.status = 400; res.set_content( json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); return; } // Lock for inference std::lock_guard lock(state.inference_mutex); // Set up runtime config RuntimeConfig runtime_config = ParseGenerationConfig(request); runtime_config.stream_token = [](int token, float) { return true; }; // Tokenize prompt std::vector tokens = WrapAndTokenize( state.gemma->Tokenizer(), state.gemma->ChatTemplate(), state.gemma->Config().wrapping, session.abs_pos, prompt); // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; // Temporarily redirect output to capture response std::stringstream output; runtime_config.stream_token = [&output, &state, &session, &tokens]( int token, float) { // Skip prompt tokens if (session.abs_pos < tokens.size()) { session.abs_pos++; return true; } session.abs_pos++; // Check for EOS if (state.gemma->Config().IsEOS(token)) { return true; } // Decode token std::string token_text; state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); output << token_text; return true; }; state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, *session.kv_cache, *state.env, timing_info); // Create response json response = CreateAPIResponse(output.str(), false); response["usageMetadata"] = { {"promptTokenCount", tokens.size()}, {"candidatesTokenCount", session.abs_pos - tokens.size()}, {"totalTokenCount", session.abs_pos} }; res.set_content(response.dump(), "application/json"); } catch (const json::exception& e) { res.status = 400; res.set_content( json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}} .dump(), "application/json"); } catch (const std::exception& e) { res.status = 500; res.set_content( json{{"error", {{"message", std::string("Server error: ") + e.what()}}}} .dump(), "application/json"); } } // Handle streamGenerateContent endpoint with SSE) void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { try { json request = json::parse(req.body); // Get or create session std::string session_id = request.value("sessionId", GenerateSessionId()); auto& session = state.GetOrCreateSession(session_id); // Extract prompt from API format std::string prompt; if (request.contains("contents")) { prompt = WrapMessagesWithTurnMarkers(request["contents"]); } else { res.status = 400; res.set_content( json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); return; } // Set up SSE headers res.set_header("Content-Type", "text/event-stream"); res.set_header("Cache-Control", "no-cache"); res.set_header("Connection", "keep-alive"); res.set_header("X-Session-Id", session_id); // Set up chunked content provider for SSE res.set_chunked_content_provider( "text/event-stream", [&state, request, prompt, session_id]( size_t offset, httplib::DataSink& sink) { try { // Lock for inference std::lock_guard lock(state.inference_mutex); auto& session = state.GetOrCreateSession(session_id); // Set up runtime config RuntimeConfig runtime_config = ParseGenerationConfig(request); // Tokenize prompt std::vector tokens = WrapAndTokenize( state.gemma->Tokenizer(), state.gemma->ChatTemplate(), state.gemma->Config().wrapping, session.abs_pos, prompt); // Stream token callback std::string accumulated_text; auto stream_token = [&](int token, float) { // Skip prompt tokens if (session.abs_pos < tokens.size()) { session.abs_pos++; return true; } session.abs_pos++; // Check for EOS if (state.gemma->Config().IsEOS(token)) { return true; } // Decode token std::string token_text; state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); accumulated_text += token_text; // Send SSE event using unified formatter json event = CreateAPIResponse(token_text, true); std::string sse_data = "data: " + event.dump() + "\n\n"; sink.write(sse_data.data(), sse_data.size()); return true; }; runtime_config.stream_token = stream_token; // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, *session.kv_cache, *state.env, timing_info); // Send final event using unified formatter json final_event = CreateAPIResponse("", false); final_event["usageMetadata"] = { {"promptTokenCount", tokens.size()}, {"candidatesTokenCount", session.abs_pos - tokens.size()}, {"totalTokenCount", session.abs_pos}}; std::string final_sse = "data: " + final_event.dump() + "\n\n"; sink.write(final_sse.data(), final_sse.size()); // Send done event sink.write("data: [DONE]\n\n", 15); // Ensure all data is sent sink.done(); return false; // End streaming } catch (const std::exception& e) { json error_event = {{"error", {{"message", e.what()}}}}; std::string error_sse = "data: " + error_event.dump() + "\n\n"; sink.write(error_sse.data(), error_sse.size()); return false; } }); } catch (const json::exception& e) { res.status = 400; res.set_content( json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}} .dump(), "application/json"); } } // Handle models list endpoint void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { json response = { {"models", {{{"name", "models/" + inference.model}, {"version", "001"}, {"displayName", inference.model}, {"description", inference.model + " model running locally"}, {"inputTokenLimit", 8192}, {"outputTokenLimit", 8192}, {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, {"temperature", 1.0}, {"topK", 1}}}}}; res.set_content(response.dump(), "application/json"); } // void HandleShutdown(int signal) { // std::cerr << "\nShutting down server..." << std::endl; // server_running = false; // } void RunServer(const GemmaArgs& args) { std::cerr << "Loading model..." << std::endl; // Initialize model ThreadingContext ctx(args.threading); MatMulEnv env(ctx); ServerState state; state.gemma = std::make_unique(args, ctx); state.env = &env; state.ctx = &ctx; const InferenceArgs& inference = args.inference; httplib::Server server; // Set up routes server.Get( "/", [&inference](const httplib::Request&, httplib::Response& res) { res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); }); // API endpoints server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { HandleListModels(state, inference, req, res); }); std::string model_endpoint = "/v1beta/models/" + inference.model; server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { HandleGenerateContentNonStreaming(state, req, res); }); server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { HandleGenerateContentStreaming(state, req, res); }); // Periodic cleanup of old sessions std::thread cleanup_thread([&state]() { while (server_running) { std::this_thread::sleep_for(std::chrono::minutes(5)); state.CleanupOldSessions(); } }); std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Endpoints:" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; std::cerr << " GET /v1beta/models" << std::endl; if (!server.listen("0.0.0.0", inference.port)) { std::cerr << "Failed to start server on port " << inference.port << std::endl; } cleanup_thread.join(); } } // namespace gcpp int main(int argc, char** argv) { gcpp::InternalInit(); gcpp::ConsumedArgs consumed(argc, argv); gcpp::GemmaArgs args(argc, argv, consumed); if (gcpp::HasHelp(argc, argv)) { fprintf( stderr, "\n\nAPI server for gemma.cpp\n" "========================\n\n" " --port PORT Server port (default: 8080)\n" " --model MODEL Model name for endpoints (default: gemma3-4b)\n"); args.Help(); return 0; } consumed.AbortIfUnconsumed(); // // Set up signal handler // signal(SIGINT, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown); gcpp::RunServer(args); return 0; }