- GemmaContext class that exposes Gemma functionality
- C API that uses GemmaContext
- C# interop class in GemmaInterop.cs
- New END_OF_TURN_ID in tokenizer.h, useful when dealing with instruction-tuned prompts

PiperOrigin-RevId: 730754638
This commit is contained in:
The gemma.cpp Authors 2025-02-24 23:59:12 -08:00 committed by Copybara-Service
parent b3b4b9f92f
commit 1f916b686b
8 changed files with 592 additions and 0 deletions

View File

@ -356,6 +356,44 @@ cc_library(
], ],
) )
# Library target that bundles the C API (gemma/c_api.cc) and the GemmaContext
# wrapper (gemma/context.cc) it is built on.
cc_library(
name = "gemma_shared_lib",
srcs = [
"gemma/c_api.cc",
"gemma/context.cc",
],
# NOTE(review): activations.h and gemma.h are listed here in addition to the
# C API headers; confirm they are not already exported by another target.
hdrs = [
"gemma/activations.h",
"gemma/c_api.h",
"gemma/context.h",
"gemma/gemma.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":allocator",
":app",
":basics",
":common",
":kv_cache",
":ops",
":threading",
":tokenizer",
":weights",
"//compression:compress",
"//compression:io",
"//compression:sfp",
"//paligemma:image",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
],
)
cc_library( cc_library(
name = "cross_entropy", name = "cross_entropy",
srcs = ["evals/cross_entropy.cc"], srcs = ["evals/cross_entropy.cc"],

View File

@ -110,6 +110,17 @@ set(SOURCES
util/threading.h util/threading.h
) )
# The GemmaContext wrapper and the C API built on top of it are only compiled
# when the shared library (BUILD_GEMMA_DLL) has been requested.
if(BUILD_GEMMA_DLL)
  list(APPEND SOURCES
    gemma/context.h gemma/context.cc
    gemma/c_api.h gemma/c_api.cc)
  message(STATUS "Including C API files for DLL build")
endif()
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release") set(CMAKE_BUILD_TYPE "Release")
endif() endif()
@ -129,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>) target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS libgemma DESTINATION lib) install(TARGETS libgemma DESTINATION lib)
# Shared library target for C# interop.
# Produces a library named "gemma" (no "lib" prefix) that exports the C API
# declared in gemma/c_api.h, and installs the header together with the C#
# binding file.
if(BUILD_GEMMA_DLL)
  add_library(gemma_shared SHARED ${SOURCES})
  set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
  set_target_properties(gemma_shared PROPERTIES
    PREFIX ""
    OUTPUT_NAME "gemma"
  )
  set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
  target_include_directories(gemma_shared PUBLIC ./)
  # Link the static dependencies as whole archives so all of their objects end
  # up inside the shared library. The LINK_LIBRARY generator expression
  # requires CMake 3.24 or newer.
  target_link_libraries(gemma_shared PRIVATE
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
  )
  target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
  # GEMMA_EXPORTS switches GEMMA_API in c_api.h to dllexport on Windows.
  target_compile_definitions(gemma_shared
    PRIVATE
      GEMMA_EXPORTS
      $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
  )
  target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
  install(TARGETS gemma_shared DESTINATION lib)
  install(FILES gemma/c_api.h DESTINATION include/gemma)
  # GemmaInterop.cs lives next to this CMakeLists.txt, not under gemma/.
  install(FILES GemmaInterop.cs DESTINATION include/gemma)
endif()
# Executable Target # Executable Target
add_executable(gemma gemma/run.cc) add_executable(gemma gemma/run.cc)

175
GemmaInterop.cs Normal file
View File

@ -0,0 +1,175 @@
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
namespace GemmaCpp
{
    /// <summary>
    /// Thrown when the native Gemma library reports a failure.
    /// </summary>
    public class GemmaException : Exception
    {
        /// <summary>Creates an exception with the given message.</summary>
        public GemmaException(string message) : base(message) { }

        /// <summary>Creates an exception that wraps an underlying cause.</summary>
        public GemmaException(string message, Exception innerException)
            : base(message, innerException) { }
    }

    /// <summary>
    /// Managed wrapper around the native Gemma C API (GemmaCreate, GemmaGenerate, ...).
    /// Owns one native context and releases it on <see cref="Dispose"/>.
    /// </summary>
    public class Gemma : IDisposable
    {
        private static readonly object s_loadLock = new object();
        private static bool s_libraryLoaded;

        private IntPtr _context;
        private bool _disposed;

        // Only allocated when the optional log callback below is enabled.
        private GCHandle _logCallbackHandle;

        /// <summary>
        /// Path of the native library. The library is loaded lazily when the first
        /// <see cref="Gemma"/> instance is constructed, so this property can be set
        /// beforehand. (A static constructor would run on the first access to this
        /// very property and load the default path before the setter executes.)
        /// </summary>
        public static string DllPath { get; set; } = "gemma.dll";

        [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern IntPtr LoadLibrary(string lpFileName);

        /// <summary>
        /// Loads the native library from <see cref="DllPath"/> once per process.
        /// A failed attempt is not cached, so callers may correct the path and retry.
        /// </summary>
        /// <exception cref="DllNotFoundException">The library could not be loaded.</exception>
        private static void EnsureLibraryLoaded()
        {
            lock (s_loadLock)
            {
                if (s_libraryLoaded)
                {
                    return;
                }

                string path = DllPath;
                if (LoadLibrary(path) == IntPtr.Zero)
                {
                    // Read the error code immediately, before any other call can overwrite it.
                    int error = Marshal.GetLastWin32Error();
                    throw new DllNotFoundException($"Failed to load {path}. Error: {error}");
                }

                s_libraryLoaded = true;
            }
        }

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr GemmaCreate(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
            int maxLength);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern void GemmaDestroy(IntPtr context);

        /// <summary>
        /// Receives each generated token as text. Return <c>false</c> to stop generation.
        /// </summary>
        public delegate bool TokenCallback(string token);

        // The native signature returns a 1-byte C/C++ bool, not a 4-byte Win32 BOOL,
        // so the return value is marshalled explicitly as I1.
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        [return: MarshalAs(UnmanagedType.I1)]
        private delegate bool GemmaTokenCallback(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern int GemmaGenerate(
            IntPtr context,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
            [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
            int maxLength,
            GemmaTokenCallback callback,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern int GemmaCountTokens(
            IntPtr context,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string text);

        // Native log callback signature.
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate void GemmaLogCallback(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string message,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern void GemmaSetLogCallback(
            IntPtr context,
            GemmaLogCallback callback,
            IntPtr userData);

        /// <summary>
        /// Loads the native library if necessary and creates a native Gemma context.
        /// </summary>
        /// <param name="tokenizerPath">Path to the tokenizer file.</param>
        /// <param name="modelType">Model type string understood by the native loader.</param>
        /// <param name="weightsPath">Path to the weights file.</param>
        /// <param name="weightType">Weight type string understood by the native loader.</param>
        /// <param name="maxLength">Forwarded to GemmaCreate as its max_length argument.</param>
        /// <exception cref="ArgumentNullException">A string argument is null.</exception>
        /// <exception cref="DllNotFoundException">The native library could not be loaded.</exception>
        /// <exception cref="GemmaException">The native context could not be created.</exception>
        public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
        {
            // Null strings would be marshalled as NULL pointers into native code.
            if (tokenizerPath == null) throw new ArgumentNullException(nameof(tokenizerPath));
            if (modelType == null) throw new ArgumentNullException(nameof(modelType));
            if (weightsPath == null) throw new ArgumentNullException(nameof(weightsPath));
            if (weightType == null) throw new ArgumentNullException(nameof(weightType));

            EnsureLibraryLoaded();

            _context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
            if (_context == IntPtr.Zero)
            {
                // Nothing to release, so the finalizer has no work to do.
                GC.SuppressFinalize(this);
                throw new GemmaException("Failed to create Gemma context");
            }

            // Optional: route native log messages into managed code.
            //
            // GemmaLogCallback logCallback = (message, _) =>
            // {
            // #if UNITY_ENGINE
            //     Debug.Log($"Gemma: {message}");
            // #else
            //     Debug.WriteLine($"Gemma: {message}");
            // #endif
            // };
            // _logCallbackHandle = GCHandle.Alloc(logCallback);
            // GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
        }

        /// <summary>
        /// Returns the value reported by the native GemmaCountTokens call for
        /// <paramref name="prompt"/>; the native side reports errors as -1.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The instance was disposed.</exception>
        /// <exception cref="GemmaException">The native context is invalid.</exception>
        public int CountTokens(string prompt)
        {
            ThrowIfUnusable();
            return GemmaCountTokens(_context, prompt);
        }

        /// <summary>Generates text for <paramref name="prompt"/> without streaming.</summary>
        public string Generate(string prompt, int maxLength = 4096)
        {
            return Generate(prompt, null, maxLength);
        }

        /// <summary>
        /// Generates text for <paramref name="prompt"/>, optionally streaming each token
        /// to <paramref name="callback"/>.
        /// </summary>
        /// <param name="prompt">Prompt text passed to the native library.</param>
        /// <param name="callback">Optional per-token callback; return false to stop.</param>
        /// <param name="maxLength">Capacity of the output buffer handed to native code.</param>
        /// <returns>The generated text.</returns>
        /// <exception cref="ObjectDisposedException">The instance was disposed.</exception>
        /// <exception cref="GemmaException">The context is invalid or generation failed.</exception>
        public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
        {
            ThrowIfUnusable();

            var output = new StringBuilder(maxLength);

            GemmaTokenCallback nativeCallback = null;
            if (callback != null)
            {
                nativeCallback = (text, _) => callback(text);
            }

            int length;
            try
            {
                length = GemmaGenerate(_context, prompt, output, maxLength,
                    nativeCallback, IntPtr.Zero);
            }
            finally
            {
                // Keep the delegate reachable until the native call has returned. A local
                // reference replaces the former instance-level GCHandle, so overlapping or
                // repeated calls no longer share state.
                GC.KeepAlive(nativeCallback);
            }

            if (length < 0)
            {
                throw new GemmaException("Generation failed");
            }

            return output.ToString();
        }

        /// <summary>Destroys the native context and frees the log callback handle.</summary>
        public void Dispose()
        {
            ReleaseNativeResources();
            // Everything has been released; skip the finalizer.
            GC.SuppressFinalize(this);
        }

        ~Gemma()
        {
            ReleaseNativeResources();
        }

        // Shared by Dispose and the finalizer; safe to call more than once.
        private void ReleaseNativeResources()
        {
            if (_disposed)
            {
                return;
            }

            if (_context != IntPtr.Zero)
            {
                GemmaDestroy(_context);
                _context = IntPtr.Zero;
            }

            if (_logCallbackHandle.IsAllocated)
            {
                _logCallbackHandle.Free();
            }

            _disposed = true;
        }

        // Guards every public operation that needs a live native context.
        private void ThrowIfUnusable()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(nameof(Gemma));
            }

            if (_context == IntPtr.Zero)
            {
                throw new GemmaException("Gemma context is invalid");
            }
        }
    }
}

54
gemma/c_api.cc Normal file
View File

@ -0,0 +1,54 @@
#ifndef GEMMA_EXPORTS
#define GEMMA_EXPORTS
#endif
#include "gemma/c_api.h"
// necessary as the C API and GemmaContext effectively wrap up and re-use the
// code for the Gemma executable
#include "util/app.h"
extern "C" {
// Creates a GemmaContext for the given tokenizer/model/weights configuration.
// Returns nullptr when any string argument is null or when construction throws.
// The caller owns the result and must release it with GemmaDestroy.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length) {
  // These pointers are converted to std::string further down; doing that with a
  // null pointer is undefined behaviour, so reject them at the API boundary.
  if (!tokenizer_path || !model_type || !weights_path || !weight_type) {
    return nullptr;
  }
  try {
    // kludge: reuse the executable's AppArgs with fixed settings.
    gcpp::AppArgs app_args;
    app_args.Init();
    app_args.max_packages = 1;
    app_args.verbosity = 0;
    app_args.spin = gcpp::Tristate::kFalse;
    return new GemmaContext(tokenizer_path, model_type, weights_path,
                            weight_type, app_args, max_length);
  } catch (...) {
    // Exceptions must not cross the C boundary.
    return nullptr;
  }
}
// Releases a context obtained from GemmaCreate. Passing nullptr is a no-op
// because deleting a null pointer does nothing.
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
  gcpp::GemmaContext* context = static_cast<gcpp::GemmaContext*>(ctx);
  delete context;
}
// Forwards to GemmaContext::Generate. Returns -1 when ctx is null, otherwise
// whatever GemmaContext::Generate returns.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data) {
  if (ctx == nullptr) {
    return -1;
  }
  auto* context = static_cast<gcpp::GemmaContext*>(ctx);
  const int result =
      context->Generate(prompt, output, max_length, callback, user_data);
  return result;
}
// Forwards to GemmaContext::CountTokens. Returns -1 when either argument is
// null, otherwise whatever GemmaContext::CountTokens returns.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
  if (ctx == nullptr || text == nullptr) {
    return -1;
  }
  auto* context = static_cast<gcpp::GemmaContext*>(ctx);
  const int token_count = context->CountTokens(text);
  return token_count;
}
// Installs the log callback. SetLogCallback is a static member of
// GemmaContext, so ctx only serves as a validity check: a null ctx makes the
// call a no-op.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data) {
  if (ctx != nullptr) {
    gcpp::GemmaContext::SetLogCallback(callback, user_data);
  }
}
}

62
gemma/c_api.h Normal file
View File

@ -0,0 +1,62 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_C_API_H_
#define THIRD_PARTY_GEMMA_C_API_H_

// C-callable interface to GemmaContext.
//
// gemma/context.h is a C++ header, so it is only pulled in for C++ translation
// units. C consumers see GemmaContext as an opaque struct and get `bool` from
// <stdbool.h>.
#ifdef __cplusplus
#include "gemma/context.h"
#else
#include <stdbool.h>
#endif

// GEMMA_API marks the exported entry points: dllexport when building the
// library (GEMMA_EXPORTS defined), dllimport for Windows consumers, and default
// symbol visibility elsewhere.
#ifdef _WIN32
#ifdef GEMMA_EXPORTS
#define GEMMA_API __declspec(dllexport)
#else
#define GEMMA_API __declspec(dllimport)
#endif
#else
#define GEMMA_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

#ifdef __cplusplus
typedef gcpp::GemmaContext GemmaContext;
#else
typedef struct GemmaContext GemmaContext;
#endif

// Invoked with the text of a generated token; return false to stop generation.
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Invoked with a log message.
typedef void (*GemmaLogCallback)(const char* message, void* user_data);

// Creates a context; returns NULL on failure. Release with GemmaDestroy.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length);

// Destroys a context created by GemmaCreate.
GEMMA_API void GemmaDestroy(GemmaContext* ctx);

// Generates text for `prompt` into `output`. Returns the length of the
// generated text, or -1 on error.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data);

// Returns the number of tokens in `text`, or -1 on error.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);

// Installs the callback that receives log messages.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data);

#ifdef __cplusplus
}
#endif

#endif  // THIRD_PARTY_GEMMA_C_API_H_

130
gemma/context.cc Normal file
View File

@ -0,0 +1,130 @@
#include "gemma/context.h"
namespace gcpp {
void InitializeGemmaLibrary() {
AppArgs app;
app.Init();
app.max_packages = 1;
NestedPools pools = CreatePools(app);
Allocator::Init(pools.Topology());
}
// Definitions of the static logging members declared in GemmaContext. Both
// start out null, i.e. no log callback is installed until SetLogCallback is
// called.
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
void* GemmaContext::s_log_user_data = nullptr;
// Builds the thread pools, validates the loader configuration, initializes the
// cached inference arguments, then loads the model and creates its KV cache.
// Aborts the process (HWY_ABORT) if the loader configuration is invalid.
GemmaContext::GemmaContext(const char* tokenizer_path, const char* model_type,
                           const char* weights_path, const char* weight_type,
                           const AppArgs& app_args, int max_length)
    : pools(CreatePools(app_args)) {
  LoaderArgs loader_args(tokenizer_path, weights_path, model_type);
  loader_args.weight_type_str = weight_type;

  const char* validation_error = loader_args.Validate();
  if (validation_error != nullptr) {
    HWY_ABORT("Invalid loader configuration: %s", validation_error);
  }

  // Defaults for the inference arguments reused by every Generate call.
  inference_args.Init();
  inference_args.max_generated_tokens = max_length;
  inference_args.temperature = 0.7f;
  inference_args.top_k = 1;
  inference_args.deterministic = false;

  // The allocator is initialized before the model is allocated.
  Allocator::Init(pools.Topology());
  model = AllocateGemma(loader_args, pools);

  // NOTE(review): the second argument to KVCache::Create is a fixed 2048 and
  // does not follow max_length - confirm this is intended.
  kv_cache =
      std::make_unique<KVCache>(KVCache::Create(model->GetModelConfig(), 2048));
}
int GemmaContext::Generate(const char* prompt, char* output, int max_length,
GemmaTokenCallback callback, void* user_data) {
if (!model || !kv_cache || !prompt || !output || max_length <= 0) {
return -1;
}
try {
// Clear and reuse buffers
result_buffer.clear();
prompt_buffer.assign(prompt);
token_buffer.clear();
// The prompt is assumed to be already wrapped in the appropriate control
// tokens if necessary for an instruction tuned model, so we don't use
// WrapAndTokenize here
HWY_ASSERT(model->Tokenizer().Encode(prompt, &token_buffer));
// Both pre-trained and instruction-tuned require BOS as first token
if (token_buffer.at(0) != BOS_ID) {
token_buffer.insert(token_buffer.begin(), BOS_ID);
}
// Pass prompt_tokens to properly utilize KV cache for subsequent tokens
const size_t prompt_tokens = token_buffer.size();
size_t tokens_generated_this_turn = 0;
auto stream_token = [this, callback, user_data, prompt_tokens,
&tokens_generated_this_turn](int token, float) {
std::string token_text;
if (model->Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
// don't re-output the prompt tokens
if (tokens_generated_this_turn < prompt_tokens) {
++tokens_generated_this_turn;
return true;
}
// skip the end of turn token, this way we don't have to do string
// comparisons at the application level (is this a good idea?)
if (token == END_OF_TURN_ID) {
return false;
}
if (callback) {
if (!callback(token_text.c_str(), user_data)) {
return false;
}
}
result_buffer.append(token_text);
++tokens_generated_this_turn;
return true;
}
return false;
};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = 0,
.stream_token = stream_token,
.use_spinning = Tristate::kFalse};
inference_args.max_generated_tokens = max_length;
inference_args.CopyTo(runtime_config);
TimingInfo timing_info = {.verbosity = 0};
hwy::Span<const int> testspan(token_buffer.data(), token_buffer.size());
// Pass prompt_tokens to properly utilize KV cache for subsequent tokens
model->Generate(runtime_config, testspan, prompt_tokens, 0, *kv_cache,
timing_info);
if (result_buffer.length() >= static_cast<size_t>(max_length)) {
return -1;
}
strcpy(output, result_buffer.c_str());
return static_cast<int>(result_buffer.length());
} catch (...) {
return -1;
}
}
// Tokenizes `text` and returns the number of tokens, or -1 on error (no
// model, null text, tokenization failure, or an exception).
int GemmaContext::CountTokens(const char* text) {
  if (!model || !text) return -1;
  try {
    std::string text_str(text);
    std::vector<int> tokens;
    // Report tokenizer failure through the documented -1 instead of aborting
    // the host process.
    if (!model->Tokenizer().Encode(text_str, &tokens)) {
      return -1;
    }
    return static_cast<int>(tokens.size());
  } catch (...) {
    return -1;
  }
}
} // namespace gcpp

92
gemma/context.h Normal file
View File

@ -0,0 +1,92 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#include <memory>
#include <random>
#include <string>
#include <vector>
#ifdef _WIN32
#include <Windows.h>
#endif
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/threading.h"
namespace gcpp {
// Initialize global state needed by the library.
// Must be called before creating any Gemma instances.
// NOTE(review): GemmaContext's constructor also initializes the allocator
// itself; confirm whether callers that only use GemmaContext still need this.
void InitializeGemmaLibrary();
// Called with the text of each generated token; returning false stops
// generation.
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Called with a log message; user_data is the pointer given to SetLogCallback.
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
// Bundles a loaded Gemma model with its thread pools, KV cache and reusable
// buffers, and exposes text generation and token counting through plain C
// strings.
class GemmaContext {
public:
// Creates the pools from app_args and loads the model described by the
// tokenizer/model/weights arguments. max_length becomes the initial limit on
// generated tokens.
GemmaContext(const char* tokenizer_path, const char* model_type,
const char* weights_path, const char* weight_type,
const AppArgs& app_args, int max_length = 2048);
// Generates text for `prompt` into `output` (a buffer of `max_length` bytes).
// `callback`, if non-null, receives each generated token.
// Returns length of generated text, or -1 on error
int Generate(const char* prompt, char* output, int max_length,
GemmaTokenCallback callback, void* user_data);
// Returns number of tokens in text, or -1 on error
int CountTokens(const char* text);
// Installs the callback used by LogDebug. The callback is stored in static
// members, so it is shared by all GemmaContext instances.
static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
s_log_callback = callback;
s_log_user_data = user_data;
}
private:
// Declaration order matters: members are initialized in this order.
NestedPools pools;
std::unique_ptr<Gemma> model;
std::unique_ptr<KVCache> kv_cache;
// Buffers reused across Generate calls.
std::string prompt_buffer;
std::string result_buffer;
std::vector<int> token_buffer;
// Cached args
InferenceArgs inference_args;
// NOTE(review): appears not to be assigned from the constructor's app_args
// parameter - confirm whether this member is needed.
AppArgs app_args;
// Random generator handed to the runtime config during generation.
std::mt19937 gen;
// Process-wide log callback and its user data (set via SetLogCallback).
static GemmaLogCallback s_log_callback;
static void* s_log_user_data;
// Sends `message` to the installed log callback. Without a callback the
// message goes to OutputDebugStringA on Windows and is dropped elsewhere.
static void LogDebug(const char* message) {
if (s_log_callback) {
s_log_callback(message, s_log_user_data);
} else {
#ifdef _WIN32
OutputDebugStringA(message);
#endif
}
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_

View File

@ -31,6 +31,9 @@ namespace gcpp {
constexpr int EOS_ID = 1; constexpr int EOS_ID = 1;
constexpr int BOS_ID = 2; constexpr int BOS_ID = 2;
// The tokenizer's end of turn token id. Useful when dealing with
// instruction-tuned prompts, e.g. to stop streaming once a turn is finished.
// NOTE(review): 107 is presumably the id of the end-of-turn control token in
// the Gemma vocabulary - confirm against the tokenizer model.
constexpr int END_OF_TURN_ID = 107;
class GemmaTokenizer { class GemmaTokenizer {
public: public:
GemmaTokenizer(); GemmaTokenizer();