// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// HTTP API server for gemma.cpp with SSE support
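//
// Illustrative usage (the binary name is a placeholder; the port and model
// name below are the defaults described in the --help text):
//
//   ./gemma_api_server --weights <path> --tokenizer <path> --port 8080
//   curl -X POST http://localhost:8080/v1beta/models/gemma3-4b:generateContent \
//     -H "Content-Type: application/json" \
//     -d '{"contents": [{"parts": [{"text": "Hello"}]}]}'
//
// Passing "sessionId" in the request body reuses that session's KV cache
// across turns; otherwise each request gets a fresh session.
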
#include <stdio.h>
#include <signal.h>

#include <atomic>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>
#include <unordered_map>
#include <vector>

// HTTP server library
#undef CPPHTTPLIB_OPENSSL_SUPPORT
#undef CPPHTTPLIB_ZLIB_SUPPORT
#include "httplib.h"

// JSON library
#include "nlohmann/json.hpp"

#include "compression/types.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h"
#include "ops/matmul.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/profiler.h"

using json = nlohmann::json;

namespace gcpp {

static std::atomic<bool> server_running{true};

// Server state holding model and KV caches
struct ServerState {
  std::unique_ptr<Gemma> gemma;
  MatMulEnv* env;
  ThreadingContext* ctx;

  // Session-based KV cache storage
  struct Session {
    std::unique_ptr<KVCache> kv_cache;
    size_t abs_pos = 0;
    std::chrono::steady_clock::time_point last_access;
  };

  std::unordered_map<std::string, Session> sessions;
  std::mutex sessions_mutex;
  std::mutex inference_mutex;

  // Clean up sessions after 30 minutes of inactivity.
  void CleanupOldSessions() {
    std::lock_guard<std::mutex> lock(sessions_mutex);
    auto now = std::chrono::steady_clock::now();
    for (auto it = sessions.begin(); it != sessions.end();) {
      if (now - it->second.last_access > std::chrono::minutes(30)) {
        it = sessions.erase(it);
      } else {
        ++it;
      }
    }
  }

  // Get or create a session with its KV cache.
  Session& GetOrCreateSession(const std::string& session_id) {
    std::lock_guard<std::mutex> lock(sessions_mutex);
    auto& session = sessions[session_id];
    if (!session.kv_cache) {
      session.kv_cache = std::make_unique<KVCache>(gemma->Config(), InferenceArgs(),
                                                   env->ctx.allocator);
    }
    session.last_access = std::chrono::steady_clock::now();
    return session;
  }
};

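// Note: sessions_mutex only guards the session map itself, while
// inference_mutex serializes the calls into Gemma generation below, so at
// most one request runs inference at a time.
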
// Generate a unique session ID
std::string GenerateSessionId() {
  static std::atomic<uint64_t> counter{0};
  std::stringstream ss;
  ss << "session_" << std::hex
     << std::chrono::steady_clock::now().time_since_epoch().count() << "_"
     << counter.fetch_add(1);
  return ss.str();
}

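// Generated IDs look like "session_<ticks>_<n>", with both fields printed in
// hex (e.g. "session_16f2a3b4c5d6_0", illustrative); uniqueness comes from
// the timestamp plus the process-wide counter.
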
// Wraps messages with start_of_turn markers - handles both with and without roles
std::string WrapMessagesWithTurnMarkers(const json& contents) {
  std::string prompt;

  for (const auto& content : contents) {
    if (content.contains("parts")) {
      // Check if role is specified (public API format) or not (local format)
      std::string role = content.value("role", "");

      for (const auto& part : content["parts"]) {
        if (part.contains("text")) {
          std::string text = part["text"];

          if (role == "user") {
            prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
          } else if (role == "model") {
            prompt += text + "\n";
          } else if (role.empty()) {
            // Local format without roles - for now, treat as user input
            prompt += "<start_of_turn>user\n" + text + "\n<start_of_turn>model\n";
          }
        }
      }
    }
  }

  return prompt;
}

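// Example (illustrative): a request body such as
//   {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]}
// is flattened by WrapMessagesWithTurnMarkers into
//   "<start_of_turn>user\nHi\n<start_of_turn>model\n"
// before tokenization.
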
// Parse generation config
RuntimeConfig ParseGenerationConfig(const json& request, std::mt19937& gen) {
  RuntimeConfig config;
  config.gen = &gen;
  config.verbosity = 0;

  // Set defaults matching public API
  config.temperature = 1.0f;
  config.top_k = 1;
  config.max_generated_tokens = 8192;

  if (request.contains("generationConfig")) {
    auto& gen_config = request["generationConfig"];

    if (gen_config.contains("temperature")) {
      config.temperature = gen_config["temperature"].get<float>();
    }
    if (gen_config.contains("topK")) {
      config.top_k = gen_config["topK"].get<int>();
    }
    if (gen_config.contains("maxOutputTokens")) {
      config.max_generated_tokens = gen_config["maxOutputTokens"].get<size_t>();
    }
  }

  return config;
}

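// Example (illustrative): a request carrying
//   {"generationConfig": {"temperature": 0.7, "topK": 40, "maxOutputTokens": 256}}
// overrides the defaults above; fields other than temperature, topK and
// maxOutputTokens are ignored.
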
// Unified response formatter - creates a consistent format regardless of request type
json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) {
  json response = {
      {"candidates", {{
          {"content", {
              {"parts", {{{"text", text}}}},
              {"role", "model"}
          }},
          {"index", 0}
      }}},
      {"promptFeedback", {{"safetyRatings", json::array()}}}
  };

  // Only add finishReason for final responses (not intermediate streaming chunks).
  if (!is_streaming_chunk) {
    response["candidates"][0]["finishReason"] = "STOP";
  }

  return response;
}

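// The resulting JSON has this shape (illustrative):
//   {"candidates": [{"content": {"parts": [{"text": "..."}], "role": "model"},
//                    "index": 0, "finishReason": "STOP"}],
//    "promptFeedback": {"safetyRatings": []}}
// finishReason is omitted for intermediate streaming chunks.
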
// Handle generateContent endpoint (non-streaming)
void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req,
                                       httplib::Response& res) {
  try {
    json request = json::parse(req.body);

    // Get or create session
    std::string session_id = request.value("sessionId", GenerateSessionId());
    auto& session = state.GetOrCreateSession(session_id);

    // Extract prompt from API format
    std::string prompt;
    if (request.contains("contents")) {
      prompt = WrapMessagesWithTurnMarkers(request["contents"]);
    } else {
      res.status = 400;
      res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
                      "application/json");
      return;
    }

    // Lock for inference
    std::lock_guard<std::mutex> lock(state.inference_mutex);

    // Set up runtime config
    std::mt19937 gen;
    RuntimeConfig runtime_config = ParseGenerationConfig(request, gen);

    // Tokenize prompt
    std::vector<int> tokens = WrapAndTokenize(state.gemma->Tokenizer(),
                                              state.gemma->ChatTemplate(),
                                              state.gemma->Config().wrapping,
                                              session.abs_pos, prompt);

    // Run inference with KV cache
    TimingInfo timing_info = {.verbosity = 0};
    size_t prefix_end = 0;

    // Capture generated tokens into a string via the stream callback.
    std::stringstream output;
    runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) {
      // Skip prompt tokens
      if (session.abs_pos < tokens.size()) {
        session.abs_pos++;
        return true;
      }

      session.abs_pos++;

      // Check for EOS
      if (state.gemma->Config().IsEOS(token)) {
        return true;
      }

      // Decode token
      std::string token_text;
      state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
      output << token_text;

      return true;
    };

    state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
                          *session.kv_cache, *state.env, timing_info);

    // Create response
    json response = CreateAPIResponse(output.str(), false);
    response["usageMetadata"] = {
        {"promptTokenCount", tokens.size()},
        {"candidatesTokenCount", session.abs_pos - tokens.size()},
        {"totalTokenCount", session.abs_pos}
    };

    res.set_content(response.dump(), "application/json");

  } catch (const json::exception& e) {
    res.status = 400;
    res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(),
                    "application/json");
  } catch (const std::exception& e) {
    res.status = 500;
    res.set_content(json{{"error", {{"message", std::string("Server error: ") + e.what()}}}}.dump(),
                    "application/json");
  }
}

// Handle streamGenerateContent endpoint with SSE
void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req,
                                    httplib::Response& res) {
  try {
    json request = json::parse(req.body);

    // Get or create session
    std::string session_id = request.value("sessionId", GenerateSessionId());
    auto& session = state.GetOrCreateSession(session_id);

    // Extract prompt from API format
    std::string prompt;
    if (request.contains("contents")) {
      prompt = WrapMessagesWithTurnMarkers(request["contents"]);
    } else {
      res.status = 400;
      res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(),
                      "application/json");
      return;
    }

    // Set up SSE headers
    res.set_header("Content-Type", "text/event-stream");
    res.set_header("Cache-Control", "no-cache");
    res.set_header("Connection", "keep-alive");
    res.set_header("X-Session-Id", session_id);

    // Set up chunked content provider for SSE
    res.set_chunked_content_provider(
        "text/event-stream",
        [&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) {
          try {
            // Lock for inference
            std::lock_guard<std::mutex> lock(state.inference_mutex);
            auto& session = state.GetOrCreateSession(session_id);

            // Set up runtime config
            std::mt19937 gen;
            RuntimeConfig runtime_config = ParseGenerationConfig(request, gen);

            // Tokenize prompt
            std::vector<int> tokens = WrapAndTokenize(state.gemma->Tokenizer(),
                                                      state.gemma->ChatTemplate(),
                                                      state.gemma->Config().wrapping,
                                                      session.abs_pos, prompt);

            // Stream token callback
            std::string accumulated_text;
            auto stream_token = [&](int token, float) {
              // Skip prompt tokens
              if (session.abs_pos < tokens.size()) {
                session.abs_pos++;
                return true;
              }

              session.abs_pos++;

              // Check for EOS
              if (state.gemma->Config().IsEOS(token)) {
                return true;
              }

              // Decode token
              std::string token_text;
              state.gemma->Tokenizer().Decode(std::vector<int>{token}, &token_text);
              accumulated_text += token_text;

              // Send SSE event using unified formatter
              json event = CreateAPIResponse(token_text, true);

              std::string sse_data = "data: " + event.dump() + "\n\n";
              sink.write(sse_data.data(), sse_data.size());

              return true;
            };

            runtime_config.stream_token = stream_token;

            // Run inference with KV cache
            TimingInfo timing_info = {.verbosity = 0};
            size_t prefix_end = 0;

            state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
                                  *session.kv_cache, *state.env, timing_info);

            // Send final event using unified formatter
            json final_event = CreateAPIResponse("", false);
            final_event["usageMetadata"] = {
                {"promptTokenCount", tokens.size()},
                {"candidatesTokenCount", session.abs_pos - tokens.size()},
                {"totalTokenCount", session.abs_pos}
            };

            std::string final_sse = "data: " + final_event.dump() + "\n\n";
            sink.write(final_sse.data(), final_sse.size());

            // Send done event
            const std::string done_sse = "data: [DONE]\n\n";
            sink.write(done_sse.data(), done_sse.size());

            // Ensure all data is sent
            sink.done();

            return false;  // End streaming

          } catch (const std::exception& e) {
            json error_event = {{"error", {{"message", e.what()}}}};
            std::string error_sse = "data: " + error_event.dump() + "\n\n";
            sink.write(error_sse.data(), error_sse.size());
            return false;
          }
        });

  } catch (const json::exception& e) {
    res.status = 400;
    res.set_content(json{{"error", {{"message", std::string("JSON parsing error: ") + e.what()}}}}.dump(),
                    "application/json");
  }
}

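// On the wire (illustrative), the stream is one SSE event per decoded token,
// a final event carrying usageMetadata, and a terminator:
//   data: {"candidates": [{"content": {"parts": [{"text": "Hel"}], ...}, ...}], ...}
//
//   data: {"candidates": [...], "usageMetadata": {...}}
//
//   data: [DONE]
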
// Handle models list endpoint
void HandleListModels(ServerState& state, const InferenceArgs& inference,
                      const httplib::Request& req, httplib::Response& res) {
  json response = {
      {"models", {{
          {"name", "models/" + inference.model},
          {"version", "001"},
          {"displayName", inference.model},
          {"description", inference.model + " model running locally"},
          {"inputTokenLimit", 8192},
          {"outputTokenLimit", 8192},
          {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})},
          {"temperature", 1.0},
          {"topK", 1}
      }}}
  };

  res.set_content(response.dump(), "application/json");
}

// void HandleShutdown(int signal) {
//   std::cerr << "\nShutting down server..." << std::endl;
//   server_running = false;
// }

void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading,
               const InferenceArgs& inference) {
  std::cerr << "Loading model..." << std::endl;

  // Initialize model
  ThreadingContext ctx(threading);
  MatMulEnv env(ctx);

  ServerState state;
  state.gemma = std::make_unique<Gemma>(loader, inference, ctx);
  state.env = &env;
  state.ctx = &ctx;

  httplib::Server server;

  // Set up routes
  server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) {
    res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model +
                        ":generateContent",
                    "text/plain");
  });

  // API endpoints
  server.Get("/v1beta/models",
             [&state, &inference](const httplib::Request& req, httplib::Response& res) {
               HandleListModels(state, inference, req, res);
             });

  std::string model_endpoint = "/v1beta/models/" + inference.model;
  server.Post(model_endpoint + ":generateContent",
              [&state](const httplib::Request& req, httplib::Response& res) {
                HandleGenerateContentNonStreaming(state, req, res);
              });

  server.Post(model_endpoint + ":streamGenerateContent",
              [&state](const httplib::Request& req, httplib::Response& res) {
                HandleGenerateContentStreaming(state, req, res);
              });

  // Periodic cleanup of old sessions
  std::thread cleanup_thread([&state]() {
    while (server_running) {
      std::this_thread::sleep_for(std::chrono::minutes(5));
      state.CleanupOldSessions();
    }
  });

  std::cerr << "Starting API server on port " << inference.port << std::endl;
  std::cerr << "Model loaded successfully" << std::endl;
  std::cerr << "Endpoints:" << std::endl;
  std::cerr << "  POST /v1beta/models/" << inference.model << ":generateContent" << std::endl;
  std::cerr << "  POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)"
            << std::endl;
  std::cerr << "  GET  /v1beta/models" << std::endl;

  if (!server.listen("0.0.0.0", inference.port)) {
    std::cerr << "Failed to start server on port " << inference.port << std::endl;
  }

  // listen() has returned; stop the cleanup loop before joining.
  server_running = false;
  cleanup_thread.join();
}

}  // namespace gcpp

int main(int argc, char** argv) {
  gcpp::InternalInit();

  gcpp::LoaderArgs loader(argc, argv);
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);

  if (gcpp::HasHelp(argc, argv)) {
    std::cerr << "\n\nAPI server for gemma.cpp\n";
    std::cerr << "========================\n\n";
    std::cerr << "Usage: " << argv[0] << " --weights <path> --tokenizer <path> [options]\n";
    std::cerr << "\nOptions:\n";
    std::cerr << "  --port PORT     Server port (default: 8080)\n";
    std::cerr << "  --model MODEL   Model name for endpoints (default: gemma3-4b)\n";
    std::cerr << "\n";
    std::cerr << "\n*Model Loading Arguments*\n\n";
    loader.Help();
    std::cerr << "\n*Threading Arguments*\n\n";
    threading.Help();
    std::cerr << "\n*Inference Arguments*\n\n";
    inference.Help();
    std::cerr << "\n";
    return 0;
  }

  // Arguments are handled by InferenceArgs.

  // // Set up signal handler
  // signal(SIGINT, gcpp::HandleShutdown);
  // signal(SIGTERM, gcpp::HandleShutdown);

  gcpp::RunServer(loader, threading, inference);

  return 0;
}