Add C API and C# interop files

This change adds a basic C API that allows access to Gemma functionality from other programming languages. The functionality is exposed via a shared library (DLL on Windows), with C++ interfaces and a basic C# interop wrapper included.

To build the DLL, use the `windows-dll` preset, which includes the C and C++ sources as follows:
```
cmake --preset windows-dll
cmake --build --config Release --preset windows-dll -j 4
```
This should generate a `gemma.dll` in `<build-dir>/Release`.

To build for a non-Windows OS, equivalent shared-library build configuration (linker flags and symbol visibility) will need to be added to produce the target platform's shared-library format (e.g. `.so` or `.dylib`).

PiperOrigin-RevId: 750246272
This commit is contained in:
The gemma.cpp Authors 2025-04-22 10:35:12 -07:00 committed by Copybara-Service
parent f20da328de
commit ba10c88a94
11 changed files with 1327 additions and 8 deletions

View File

@ -428,6 +428,40 @@ cc_library(
], ],
) )
# Shared-library flavor of the Gemma API: bundles the C bindings
# (c_api / context) that expose Gemma functionality to other languages
# (see the C# wrapper in GemmaInterop.cs).
cc_library(
name = "gemma_shared_lib",
srcs = [
"gemma/bindings/c_api.cc",
"gemma/bindings/context.cc",
],
hdrs = [
"gemma/bindings/c_api.h",
"gemma/bindings/context.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
deps = [
":allocator",
":basics",
":benchmark_helper",
":common",
":gemma_args",
":gemma_lib",
":kv_cache",
":mat",
":ops",
":threading",
":threading_context",
":tokenizer",
":weights",
"//compression:shared",
"//paligemma:image",
"@highway//:hwy",
],
)
cc_library( cc_library(
name = "cross_entropy", name = "cross_entropy",
srcs = ["evals/cross_entropy.cc"], srcs = ["evals/cross_entropy.cc"],
@ -465,6 +499,7 @@ cc_library(
":gemma_lib", ":gemma_lib",
":ops", ":ops",
":threading_context", ":threading_context",
":tokenizer",
"@google_benchmark//:benchmark", "@google_benchmark//:benchmark",
"//compression:compress", "//compression:compress",
"@highway//:hwy", "@highway//:hwy",
@ -522,6 +557,7 @@ cc_binary(
":gemma_lib", ":gemma_lib",
":ops", ":ops",
":threading_context", ":threading_context",
":tokenizer",
"//compression:shared", "//compression:shared",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",

View File

@ -39,6 +39,7 @@ set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL) FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(benchmark) FetchContent_MakeAvailable(benchmark)
# Base source files
set(SOURCES set(SOURCES
compression/blob_store.cc compression/blob_store.cc
compression/blob_store.h compression/blob_store.h
@ -115,6 +116,17 @@ set(SOURCES
util/topology.h util/topology.h
) )
# The C API binding sources are compiled only into the DLL build; the static
# library / executable builds do not need them.
if(BUILD_GEMMA_DLL)
  foreach(binding_file IN ITEMS context.h context.cc c_api.h c_api.cc)
    list(APPEND SOURCES gemma/bindings/${binding_file})
  endforeach()
  message(STATUS "Including C API files for DLL build")
endif()
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release") set(CMAKE_BUILD_TYPE "Release")
endif() endif()
@ -134,6 +146,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>) target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS libgemma DESTINATION lib) install(TARGETS libgemma DESTINATION lib)
# Shared library target for C# interop (gemma.dll / gemma.so).
# Only built when BUILD_GEMMA_DLL is ON (e.g. via the windows-dll preset);
# ${SOURCES} then also contains gemma/bindings/{context,c_api}.{h,cc}.
if(BUILD_GEMMA_DLL)
  add_library(gemma_shared SHARED ${SOURCES})
  # Empty PREFIX so the artifact is "gemma.dll"/"gemma.so", matching the
  # module name used by the C# [DllImport("gemma")] declarations.
  set_target_properties(gemma_shared PROPERTIES
    PREFIX ""
    OUTPUT_NAME "gemma"
    CXX_STANDARD 17
    CXX_STANDARD_REQUIRED ON
    POSITION_INDEPENDENT_CODE ON
  )
  # Consumers need the project root (for "gemma/..." headers) and the
  # sentencepiece sources fetched via FetchContent.
  target_include_directories(gemma_shared PUBLIC
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${sentencepiece_SOURCE_DIR}
  )
  # Whole-archive linking keeps all symbols from these static libraries in
  # the shared library instead of letting the linker drop unreferenced ones.
  target_link_libraries(gemma_shared PRIVATE
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
  )
  # GEMMA_EXPORTS switches GEMMA_API to __declspec(dllexport); see
  # gemma/bindings/c_api.h.
  target_compile_definitions(gemma_shared
    PRIVATE
    GEMMA_EXPORTS
    $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
  )
  target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
  install(TARGETS gemma_shared DESTINATION lib)
  # Fix: the C API header lives under gemma/bindings/, not gemma/ (it would
  # previously fail to install because gemma/c_api.h does not exist).
  install(FILES gemma/bindings/c_api.h DESTINATION include/gemma)
  # NOTE(review): assuming the C# wrapper also lives under gemma/bindings/ --
  # confirm the actual location of GemmaInterop.cs in the tree.
  install(FILES gemma/bindings/GemmaInterop.cs DESTINATION include/gemma)
endif()
# Executable Target # Executable Target
add_executable(gemma gemma/run.cc) add_executable(gemma gemma/run.cc)

View File

@ -31,6 +31,24 @@
"lhs": "${hostSystemName}", "lhs": "${hostSystemName}",
"rhs": "Windows" "rhs": "Windows"
} }
},
{
"name": "windows-dll",
"inherits": "__defaults__",
"displayName": "Windows DLL",
"description": "Visual Studio 2022 with Clang/LLVM frontend (DLL build)",
"generator": "Visual Studio 17 2022",
"toolset": "ClangCL",
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Windows"
},
"cacheVariables": {
"BUILD_SHARED_LIBS": "OFF",
"CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS": "ON",
"BUILD_GEMMA_DLL": "ON"
}
} }
], ],
"buildPresets": [ "buildPresets": [
@ -54,6 +72,15 @@
"displayName": "Windows", "displayName": "Windows",
"configuration": "Release", "configuration": "Release",
"configurePreset": "windows" "configurePreset": "windows"
},
{
"name": "windows-dll",
"displayName": "Windows DLL",
"configuration": "Release",
"configurePreset": "windows-dll",
"targets": [
"gemma_shared"
]
} }
] ]
} }

View File

@ -25,6 +25,7 @@
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/base.h" #include "hwy/base.h"
@ -54,8 +55,9 @@ class GemmaEnv {
size_t MaxGeneratedTokens() const { size_t MaxGeneratedTokens() const {
return runtime_config_.max_generated_tokens; return runtime_config_.max_generated_tokens;
} }
void SetMaxGeneratedTokens(size_t max_generated_tokens) { void SetMaxGeneratedTokens(int max_generated_tokens) {
runtime_config_.max_generated_tokens = max_generated_tokens; runtime_config_.max_generated_tokens =
static_cast<size_t>(max_generated_tokens);
} }
std::vector<int> Tokenize(const std::string& input) const { std::vector<int> Tokenize(const std::string& input) const {

View File

@ -0,0 +1,426 @@
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
namespace GemmaCpp
{
/// <summary>Thrown when a native gemma.dll call fails or the context is invalid.</summary>
public class GemmaException : Exception
{
public GemmaException(string message) : base(message) { }
}
/// <summary>
/// Managed wrapper around the native gemma shared library's C API. Owns an
/// opaque native GemmaContext handle and releases it via Dispose/finalizer.
/// NOTE(review): not safe for concurrent calls on one instance --
/// _callbackHandle is shared by Generate and GenerateMultimodal.
/// </summary>
public class Gemma : IDisposable
{
// Opaque native handle from GemmaCreate; IntPtr.Zero once destroyed.
private IntPtr _context;
private bool _disposed;
// Optional: Allow setting DLL path.
// NOTE(review): must be assigned before the first use of this class -- the
// static constructor below reads it when the type is initialized.
public static string DllPath { get; set; } = "gemma.dll";
[DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
private static extern IntPtr LoadLibrary(string lpFileName);
static Gemma()
{
// Load DLL from specified path so later [DllImport("gemma")] binds to it.
if (LoadLibrary(DllPath) == IntPtr.Zero)
{
throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
}
}
// Creates a native context; returns IntPtr.Zero on failure (see c_api.cc).
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr GemmaCreate(
[MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
int maxLength);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaDestroy(IntPtr context);
// Delegate type for token callbacks; return false to stop generation.
public delegate bool TokenCallback(string token);
// Keep delegate alive for duration of calls (prevents GC of the thunk
// while native code may still invoke it).
private GCHandle _callbackHandle;
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate bool GemmaTokenCallback(
[MarshalAs(UnmanagedType.LPUTF8Str)] string text,
IntPtr userData);
// Returns the generated text length in bytes, or negative on error.
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern int GemmaGenerate(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
[Out] byte[] output,
int maxLength,
GemmaTokenCallback callback,
IntPtr userData);
// NOTE(review): output is marshaled as a StringBuilder here while
// GemmaGenerate uses byte[]; StringBuilder capacity is in chars but the
// native side receives max_length as a byte budget -- verify multi-byte
// UTF-8 output cannot overflow the buffer.
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern int GemmaGenerateMultimodal(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
IntPtr image_data, // Renamed param to match C API
int image_width, // Added dimension
int image_height, // Added dimension
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
int maxLength,
GemmaTokenCallback callback,
IntPtr userData);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern int GemmaCountTokens(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string text);
// Configuration function imports
// NOTE(review): GemmaSetMaxGeneratedTokens and GemmaSetPrefillTbatchSize are
// imported but no public wrapper method exposes them below.
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetMaxGeneratedTokens(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetMultiturn(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetTemperature(IntPtr context, float value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetTopK(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetDeterministic(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetPrefillTbatchSize(IntPtr context, int value);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaResetConversation")]
private static extern void GemmaResetConversation(IntPtr context);
// Conversation management function imports (each returns 1 on success, 0 on failure)
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaCreateConversation")]
private static extern int GemmaCreateConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSwitchConversation")]
private static extern int GemmaSwitchConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaDeleteConversation")]
private static extern int GemmaDeleteConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaHasConversation")]
private static extern int GemmaHasConversation(
IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
// Native callback delegate type for library debug logging
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate void GemmaLogCallback(
[MarshalAs(UnmanagedType.LPUTF8Str)] string message,
IntPtr userData);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaSetLogCallback(
IntPtr context,
GemmaLogCallback callback,
IntPtr userData);
// Keeps the logging delegate alive while registered with native code.
private GCHandle _logCallbackHandle;
private bool _loggingEnabled = false;
/// <summary>
/// Creates a Gemma instance backed by a native context.
/// Throws GemmaException if the native context cannot be created.
/// </summary>
public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
{
_context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
if (_context == IntPtr.Zero)
{
throw new GemmaException("Failed to create Gemma context");
}
}
/// <summary>Enables or disables forwarding native debug logs to Debug.WriteLine.</summary>
public void EnableLogging(bool enable = true)
{
if (enable && !_loggingEnabled)
{
GemmaLogCallback logCallback = (message, _) =>
{
Debug.WriteLine($"Gemma: {message}");
};
// Pin the delegate so the native side can call it after this method returns.
_logCallbackHandle = GCHandle.Alloc(logCallback);
GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
_loggingEnabled = true;
}
else if (!enable && _loggingEnabled)
{
if (_logCallbackHandle.IsAllocated)
_logCallbackHandle.Free();
GemmaSetLogCallback(_context, null, IntPtr.Zero);
_loggingEnabled = false;
}
}
// Configuration methods
/// <summary>Enables/disables multiturn mode (keeping KV-cache state between turns).</summary>
public void SetMultiturn(bool enable)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetMultiturn(_context, enable ? 1 : 0);
Debug.WriteLine($"Gemma: Set multiturn to {(enable ? "enabled" : "disabled")}");
}
/// <summary>Sets the sampling temperature.</summary>
public void SetTemperature(float temperature)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetTemperature(_context, temperature);
Debug.WriteLine($"Gemma: Set temperature to {temperature}");
}
/// <summary>Sets the top-k sampling parameter.</summary>
public void SetTopK(int topK)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetTopK(_context, topK);
Debug.WriteLine($"Gemma: Set topK to {topK}");
}
/// <summary>Enables/disables deterministic sampling.</summary>
public void SetDeterministic(bool deterministic)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSetDeterministic(_context, deterministic ? 1 : 0);
Debug.WriteLine($"Gemma: Set deterministic to {(deterministic ? "true" : "false")}");
}
/// <summary>Resets the currently-active conversation's state.</summary>
public void ResetConversation()
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaResetConversation(_context); // Call P/Invoke method
Debug.WriteLine("Gemma: Reset active conversation");
}
// Conversation management methods
/// <summary>Creates a named conversation; returns false if the native call fails.</summary>
public bool CreateConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaCreateConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Create conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
return result;
}
/// <summary>Switches the active conversation; returns false if the native call fails.</summary>
public bool SwitchConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaSwitchConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Switch to conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
return result;
}
/// <summary>Deletes a named conversation; returns false if the native call fails.</summary>
public bool DeleteConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaDeleteConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Delete conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
return result;
}
/// <summary>Returns true if a conversation with the given name exists.</summary>
public bool HasConversation(string conversationName)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
bool result = GemmaHasConversation(_context, conversationName) != 0; // Call P/Invoke method
Debug.WriteLine($"Gemma: Has conversation '{conversationName}' - {result}");
return result;
}
/// <summary>Tokenizes the prompt and returns the token count (negative on native error).</summary>
public int CountTokens(string prompt)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
int count = GemmaCountTokens(_context, prompt);
return count;
}
/// <summary>Generates a response to the prompt without per-token streaming.</summary>
public string Generate(string prompt, int maxLength = 4096)
{
return Generate(prompt, null, maxLength);
}
/// <summary>
/// Generates a response, optionally streaming tokens to <paramref name="callback"/>.
/// The callback may return false to stop generation early.
/// </summary>
public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
var outputBuffer = new byte[maxLength * 4]; // Allow for worst case UTF-8 size
GemmaTokenCallback nativeCallback = null;
// Track token count for debugging
int tokenCount = 0;
if (callback != null)
{
nativeCallback = (text, _) =>
{
tokenCount++;
// Log token for debugging
Debug.WriteLine($"Token {tokenCount}: '{text}'");
// Pass token to user callback
return callback(text);
};
_callbackHandle = GCHandle.Alloc(nativeCallback);
}
try
{
int length = GemmaGenerate(_context, prompt, outputBuffer, maxLength,
nativeCallback, IntPtr.Zero);
if (length < 0)
throw new GemmaException("Generation failed");
Debug.WriteLine($"Generation complete: {tokenCount} tokens processed, result length: {length}");
// Convert the byte buffer to a string using UTF-8 encoding
string result = Encoding.UTF8.GetString(outputBuffer, 0, length);
return result;
}
finally
{
if (_callbackHandle.IsAllocated)
_callbackHandle.Free();
}
}
/// <summary>Generates a response to a prompt plus an RGB float image (no streaming).</summary>
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxLength = 4096)
{
// Pass width and height to the overloaded method
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxLength);
}
/// <summary>
/// Generates a response to a prompt plus an RGB float image, optionally
/// streaming tokens to <paramref name="callback"/>. imageData must hold at
/// least width*height*3 floats.
/// </summary>
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxLength = 4096)
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
if (imageData == null || imageData.Length == 0)
throw new ArgumentException("Image data cannot be null or empty", nameof(imageData));
if (imageWidth <= 0 || imageHeight <= 0)
throw new ArgumentException("Image dimensions must be positive");
if (imageData.Length < imageWidth * imageHeight * 3)
throw new ArgumentException("Image data array is too small for the specified dimensions");
var output = new StringBuilder(maxLength);
GemmaTokenCallback nativeCallback = null;
if (callback != null)
{
nativeCallback = (text, _) => callback(text);
_callbackHandle = GCHandle.Alloc(nativeCallback);
}
// Pin the image data so it doesn't move during the native call
GCHandle imageHandle = GCHandle.Alloc(imageData, GCHandleType.Pinned);
try
{
IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
// Pass image dimensions to the native call
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxLength,
nativeCallback, IntPtr.Zero);
if (length < 0)
throw new GemmaException("Multimodal generation failed");
return output.ToString();
}
finally
{
imageHandle.Free();
if (_callbackHandle.IsAllocated)
_callbackHandle.Free();
}
}
/// <summary>
/// Releases the native context and any registered log callback handle.
/// NOTE(review): standard Dispose pattern would also call
/// GC.SuppressFinalize(this) here -- consider adding it.
/// </summary>
public void Dispose()
{
if (!_disposed)
{
if (_context != IntPtr.Zero)
{
GemmaDestroy(_context);
_context = IntPtr.Zero;
}
if (_logCallbackHandle.IsAllocated)
_logCallbackHandle.Free();
_disposed = true;
}
}
// Finalizer safety net: frees the native context if Dispose was never called.
~Gemma()
{
Dispose();
}
}
}

128
gemma/bindings/c_api.cc Normal file
View File

@ -0,0 +1,128 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GEMMA_EXPORTS
#define GEMMA_EXPORTS
#endif
#include "gemma/bindings/c_api.h"
extern "C" {
// Creates a GemmaContext from model files. Returns nullptr on any failure
// (all exceptions are swallowed so none cross the C ABI boundary).
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
const char* model_type,
const char* weights_path,
const char* weight_type, int max_length) {
try {
GemmaContext* ctx = GemmaContext::Create(
tokenizer_path, model_type, weights_path, weight_type, max_length);
return ctx;
} catch (...) {
return nullptr;
}
}
// Destroys a context created by GemmaCreate. Safe to call with nullptr.
GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; }
// Generates text for `prompt` into `output` (up to max_length). Returns the
// result from GemmaContext::Generate, or -1 if ctx is null.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
int max_length, GemmaTokenCallback callback,
void* user_data) {
if (!ctx) return -1;
return ctx->Generate(prompt, output, max_length, callback, user_data);
}
// Multimodal variant: also takes raw image data plus its dimensions.
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
const void* image_data, int image_width,
int image_height, char* output,
int max_length,
GemmaTokenCallback callback,
void* user_data) {
if (!ctx) return -1;
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
output, max_length, callback, user_data);
}
// Returns the token count of `text`, or -1 on null arguments.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
if (!ctx || !text) return -1;
return ctx->CountTokens(text);
}
// Registers a callback to receive library debug log messages.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
void* user_data) {
if (!ctx) return;
ctx->SetLogCallback(callback, user_data);
}
// Configuration functions implementation: thin null-checked forwarders to
// the corresponding GemmaContext setters.
GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetMaxGeneratedTokens(value);
}
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetMultiturn(value);
}
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value) {
if (!ctx) return;
ctx->SetTemperature(value);
}
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetTopK(value);
}
// Nonzero value enables deterministic sampling.
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetDeterministic(value != 0);
}
GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value) {
if (!ctx) return;
ctx->SetPrefillTbatchSize(value);
}
// Resets the active conversation's state.
GEMMA_API void GemmaResetConversation(GemmaContext* ctx) { // Renamed function
if (!ctx) return;
ctx->ResetConversation();
}
// Conversation management: each returns 1 on success, 0 on failure or on
// null arguments.
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->CreateConversation(conversation_name) ? 1 : 0;
}
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->SwitchConversation(conversation_name) ? 1 : 0;
}
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->DeleteConversation(conversation_name) ? 1 : 0;
}
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
const char* conversation_name) {
if (!ctx || !conversation_name) return 0;
return ctx->HasConversation(conversation_name) ? 1 : 0;
}
}

86
gemma/bindings/c_api.h Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_C_API_H_
#define THIRD_PARTY_GEMMA_C_API_H_

// C API for gemma.cpp, exported from the shared library built with
// BUILD_GEMMA_DLL and consumed from other languages (see GemmaInterop.cs).

#include "gemma/bindings/context.h"

// Fix: `bool` in the C branch of the callback typedefs below requires
// <stdbool.h> for C consumers prior to C23.
#ifndef __cplusplus
#include <stdbool.h>
#endif

// GEMMA_API is dllexport while building the DLL (GEMMA_EXPORTS defined),
// dllimport for Windows consumers, and default ELF visibility elsewhere.
#ifdef _WIN32
#ifdef GEMMA_EXPORTS
#define GEMMA_API __declspec(dllexport)
#else
#define GEMMA_API __declspec(dllimport)
#endif
#else
#define GEMMA_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

// Opaque context handle (gcpp::GemmaContext under the hood).
#ifdef __cplusplus
typedef gcpp::GemmaContext GemmaContext;
#else
typedef struct GemmaContext GemmaContext;
#endif

// Invoked once per generated token; return false to stop generation.
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Receives debug log messages from the library.
typedef void (*GemmaLogCallback)(const char* message, void* user_data);

// Creates a context from model files. Returns NULL on failure.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length);
// Destroys a context created by GemmaCreate. Safe to call with NULL.
GEMMA_API void GemmaDestroy(GemmaContext* ctx);

// Generates text for `prompt` into `output` (up to max_length bytes).
// Returns the generated length, or a negative value on error.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data);
// Multimodal variant: also takes raw image data plus its dimensions.
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
                                      const void* image_data, // Renamed param
                                      int image_width,        // Added dimension
                                      int image_height,       // Added dimension
                                      char* output, int max_length,
                                      GemmaTokenCallback callback,
                                      void* user_data);
// Returns the token count of `text`, or a negative value on error.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data);

// Configuration functions
// Fix: GemmaSetMaxGeneratedTokens and GemmaSetPrefillTbatchSize are
// implemented in c_api.cc and called via P/Invoke, but were missing from
// this header.
GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value);
GEMMA_API void GemmaResetConversation(GemmaContext* ctx); // Renamed

// Conversation management functions (renamed). Each returns 1 on success,
// 0 on failure.
GEMMA_API int GemmaCreateConversation(
    GemmaContext* ctx, const char* conversation_name);  // Renamed
GEMMA_API int GemmaSwitchConversation(
    GemmaContext* ctx, const char* conversation_name);  // Renamed
GEMMA_API int GemmaDeleteConversation(
    GemmaContext* ctx, const char* conversation_name);  // Renamed
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
                                   const char* conversation_name);  // Renamed

#ifdef __cplusplus
}
#endif
#endif  // THIRD_PARTY_GEMMA_C_API_H_

331
gemma/bindings/context.cc Normal file
View File

@ -0,0 +1,331 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/bindings/context.h"
#include <cstddef>
#include <cstring>
#include <memory>
#include <sstream>
#include <vector>
#include "evals/benchmark_helper.h" // InitGenerator
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#ifdef _WIN32
#include <Windows.h>
#endif
#include "gemma/kv_cache.h"
#include "paligemma/image.h"
namespace gcpp {
// ConversationData constructor: allocates a fresh KVCache sized for the model
// config and starts the conversation at absolute position 0.
ConversationData::ConversationData(const ModelConfig& model_config,
size_t prefill_tbatch_size)
: kv_cache(std::make_unique<KVCache>(
KVCache::Create(model_config, prefill_tbatch_size))),
abs_pos(0) {}
// Initialize static members used by the process-wide log callback
// (shared across all GemmaContext instances).
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
void* GemmaContext::s_log_user_data = nullptr;
// Factory for GemmaContext. Validates the loader configuration and seeds
// default inference settings before constructing the context.
// NOTE(review): HWY_ABORT on invalid loader config terminates the process;
// callers expecting GemmaCreate() to return NULL only get that behavior for
// exceptions, not for this abort -- confirm this is intended.
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
const char* model_type,
const char* weights_path,
const char* weight_type, int max_length) {
std::stringstream ss;
ss << "Creating GemmaContext with tokenizer_path: "
<< (tokenizer_path ? tokenizer_path : "null")
<< ", model_type: " << (model_type ? model_type : "null")
<< ", weights_path: " << (weights_path ? weights_path : "null")
<< ", weight_type: " << (weight_type ? weight_type : "null")
<< ", max_length: " << max_length;
LogDebug(ss.str().c_str());
ThreadingArgs threading_args;
// Disable spin-waiting for worker threads.
threading_args.spin = gcpp::Tristate::kFalse;
LoaderArgs loader(tokenizer_path, weights_path, model_type);
loader.weight_type_str = weight_type;
LogDebug("LoaderArgs created");
if (const char* error = loader.Validate()) {
ss.str("");
ss << "Invalid loader configuration: " << error;
LogDebug(ss.str().c_str());
HWY_ABORT("Invalid loader configuration: %s", error);
}
LogDebug("Loader validated successfully");
// Initialize cached args with defaults; these can be changed later via the
// GemmaSet* C API functions.
LogDebug("Initializing inference args");
InferenceArgs inference_args;
inference_args.Init();
inference_args.max_generated_tokens = max_length;
inference_args.temperature = 0.7f;
inference_args.top_k = 1;
inference_args.deterministic = false;
ss.str("");
ss << "Inference args initialized with max_tokens: " << max_length
<< ", temperature: " << inference_args.temperature
<< ", top_k: " << inference_args.top_k << ", deterministic: "
<< (inference_args.deterministic ? "true" : "false");
LogDebug(ss.str().c_str());
// Ownership of the returned pointer passes to the caller (GemmaDestroy).
return new GemmaContext(loader, inference_args, threading_args, max_length);
}
// Constructs the context: builds the MatMul environment, loads the model,
// and seeds the conversation cache with a "default" conversation that is
// also made active.
GemmaContext::GemmaContext(const LoaderArgs& loader,
const InferenceArgs& inference_args,
const ThreadingArgs& threading_args, int max_length)
: inference_args(inference_args),
threading_args(threading_args),
matmul_env(MakeMatMulEnv(threading_args)),
model(CreateGemma(loader, matmul_env)) {
// NOTE(review): `ss` is declared but never used in this constructor body.
std::stringstream ss;
LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>(
model.GetModelConfig(), inference_args.prefill_tbatch_size);
LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]");
// Store the shared_ptr in the map under the "default" key
conversation_cache["default"] = active_conversation;
LogDebug("GemmaContext constructor completed");
}
// Internal implementation shared by Generate and GenerateMultimodal
int GemmaContext::GenerateInternal(const char* prompt_string,
const void* image_data, int image_width,
int image_height, char* output,
int max_length, GemmaTokenCallback callback,
void* user_data) {
PROFILER_ZONE("Gen.Internal");
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
std::stringstream ss;
result_buffer.clear();
InitGenerator(inference_args, gen);
// Ensure we have an active conversation
if (!active_conversation || !active_conversation->kv_cache) {
LogDebug("Generate called with null active_conversation or kv_cache");
return -1;
}
// callback function invoked for each generated token.
auto stream_token = [&, callback, user_data](int token, float) {
// Use abs_pos from the active conversation
++(active_conversation->abs_pos);
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
if (in_prompt || model.GetModelConfig().IsEOS(token)) {
return true;
}
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
}
// if we have a managed callback, pass it the token text
if (callback) {
if (!callback(token_text.c_str(), user_data)) {
LogDebug("Callback returned false, stopping generation");
return false;
}
}
result_buffer.append(token_text);
return true;
};
// set up runtime config
TimingInfo timing_info = {};
RuntimeConfig runtime_config = {.gen = &gen,
.stream_token = stream_token,
.use_spinning = threading_args.spin};
inference_args.CopyTo(runtime_config);
size_t prefix_end = 0;
// generate
std::vector<int> prompt;
ImageTokens image_tokens;
if (image_data != nullptr) {
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
image_tokens =
ImageTokens(model.Env().ctx.allocator,
Extents2D(model.GetModelConfig().vit_config.seq_len /
(pool_dim * pool_dim),
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
Image image;
image.Set(image_width, image_height, static_cast<const float*>(image_data));
// We may need to resize the supplied image depending on whether we're using
// PaliGemma or Gemma 3.
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
// Use the existing runtime_config defined earlier in the function.
// RuntimeConfig runtime_config = { ... }; // This was already defined
double image_tokens_start = hwy::platform::Now();
// Pass the populated image object to GenerateImageTokens
model.GenerateImageTokens(runtime_config, image, image_tokens);
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
ss.str("");
ss << "\n\n[ Timing info ] Image token generation took: ";
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
LogDebug(ss.str().c_str());
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
model.Info(), active_conversation->abs_pos,
prompt_string, image_tokens.BatchSize());
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size;
} else {
// Text-only case (original logic)
// Use abs_pos from the active conversation
prompt =
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
active_conversation->abs_pos, prompt_string);
prompt_size = prompt.size();
}
// Check if prompt generation failed (e.g., multimodal not implemented yet)
if (prompt.empty() && image_data != nullptr) {
// Already logged the error, just ensure we don't proceed.
return -1;
}
// Pass the KVCache object by reference from the active conversation
model.Generate(runtime_config, prompt, active_conversation->abs_pos,
prefix_end, *(active_conversation->kv_cache), timing_info);
// prepare for next turn
if (!inference_args.multiturn ||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position.
active_conversation->abs_pos = 0;
InitGenerator(inference_args, gen);
} else {
// Multi-turn Gemma: Rewind position in the active conversation
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
// https://arxiv.org/pdf/2408.00118
// Or we have hit max_generated_tokens, then the last token will be lost.
// (We could store it in stream_token, and then prepend to the next turn,
// but it's not worth the complexity, as multi-turn with max_generated is
// not a common use case.)
// In either case, we need to rewind the active conversation's abs_pos by
// one.
HWY_ASSERT(active_conversation->abs_pos > 0);
active_conversation->abs_pos--;
}
// Copy result buffer to output C-string (ensure null termination)
strncpy(output, result_buffer.c_str(), max_length - 1);
output[max_length - 1] = '\0'; // Explicit null termination
return static_cast<int>(strlen(output)); // Return length of the C-string
}
// Public text-only generation entry point.
//
// Forwards to GenerateInternal with no image payload. Returns the length of
// the generated text written to `output`, or -1 on error.
int GemmaContext::Generate(const char* prompt_string, char* output,
                           int max_length, GemmaTokenCallback callback,
                           void* user_data) {
  // Text-only: no image pointer and zero-sized dimensions.
  return GenerateInternal(prompt_string, /*image_data=*/nullptr,
                          /*image_width=*/0, /*image_height=*/0, output,
                          max_length, callback, user_data);
}
// Public multimodal (text + image) generation entry point.
//
// Requires a non-null image payload; callers wanting text-only generation
// must use Generate() instead. Returns the length of the generated text
// written to `output`, or -1 on error.
int GemmaContext::GenerateMultimodal(const char* prompt_string,
                                     const void* image_data, int image_width,
                                     int image_height, char* output,
                                     int max_length,
                                     GemmaTokenCallback callback,
                                     void* user_data) {
  // Reject a missing image here rather than forwarding: this entry point's
  // contract is "multimodal", so a null image is a caller error.
  if (image_data == nullptr) {
    LogDebug(
        "GenerateMultimodal called with null image_data. Use Generate for "
        "text-only.");
    return -1;
  }
  return GenerateInternal(prompt_string, image_data, image_width, image_height,
                          output, max_length, callback, user_data);
}
// Tokenizes `text` with the model's tokenizer and returns the token count,
// or -1 if `text` is null or an exception escapes tokenization.
//
// Exceptions are swallowed (with a log message) so that no C++ exception
// crosses the C API boundary.
int GemmaContext::CountTokens(const char* text) {
  LogDebug("CountTokens method started");
  std::stringstream ss;
  ss << "CountTokens called with text: '" << (text ? text : "null") << "'";
  LogDebug(ss.str().c_str());
  if (!text) {
    // Null input: nothing to tokenize. (The previous nested re-check of
    // `!text` inside this branch was always true and has been removed.)
    LogDebug("CountTokens failed: Invalid parameters");
    return -1;
  }
  try {
    std::vector<int> tokens;
    // Encode returns false on tokenizer failure; treat that as fatal.
    HWY_ASSERT(model.Tokenizer().Encode(std::string(text), &tokens));
    ss.str("");
    ss << "Text tokenized into " << tokens.size() << " tokens";
    LogDebug(ss.str().c_str());
    LogDebug("CountTokens completed successfully");
    return static_cast<int>(tokens.size());
  } catch (...) {
    LogDebug("Unknown exception in CountTokens");
    return -1;
  }
}
} // namespace gcpp

249
gemma/bindings/context.h Normal file
View File

@ -0,0 +1,249 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#include <memory> // For std::shared_ptr, std::make_shared
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
// Logging
#ifdef _WIN32
#include <windows.h>
#else
#include <stdio.h>
#endif
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "ops/matmul.h" // MatMulEnv
#include "hwy/base.h"
#include "hwy/highway.h"
namespace gcpp {
// Forward declaration - use 'struct' to match definition tag
struct KVCache;
// Struct to hold data for a single conversation thread
struct ConversationData {
  // Per-conversation KV cache; replaced wholesale when the conversation is
  // reset (see GemmaContext::ResetConversation).
  std::unique_ptr<KVCache> kv_cache;
  // Absolute token position reached so far in this conversation.
  size_t abs_pos = 0;
  // Constructor to initialize kv_cache (requires KVCache definition or forward
  // declaration)
  ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
};
// Streaming callback invoked once per generated token with its decoded text;
// returning false stops generation early.
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Optional log sink for diagnostic messages (see GemmaContext::SetLogCallback).
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
// C++ facade behind the C API: owns the model, the per-conversation KV
// caches, and the generation settings exposed through the C bindings.
class GemmaContext {
 private:
  // Construction goes through Create(); the constructor is defined in
  // context.cc.
  GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
               const ThreadingArgs& threading_args, int max_length);

 public:
  // Factory: loads tokenizer and weights and returns a ready context.
  static GemmaContext* Create(const char* tokenizer_path,
                              const char* model_type, const char* weights_path,
                              const char* weight_type, int max_length);

  // Returns length of generated text, or -1 on error
  int Generate(const char* prompt_string, char* output, int max_length,
               GemmaTokenCallback callback, void* user_data);

  // Returns length of generated text, or -1 on error
  int GenerateMultimodal(const char* prompt_string, const void* image_data,
                         int image_width, int image_height, char* output,
                         int max_length, GemmaTokenCallback callback,
                         void* user_data);

  // Returns number of tokens in text, or -1 on error
  int CountTokens(const char* text);

  // Installs a process-wide log sink used by LogDebug; pass nullptr to
  // restore the default (debugger/stdout) output.
  static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
    s_log_callback = callback;
    s_log_user_data = user_data;
  }

  // Set max generated tokens
  void SetMaxGeneratedTokens(size_t value) {
    inference_args.max_generated_tokens = value;
    LogDebug("Setting max_generated_tokens to configured value");
  }

  // Set multiturn flag (0 = disabled, 1 = enabled)
  void SetMultiturn(int value) {
    inference_args.multiturn = value;
    LogDebug("Setting multiturn to configured value");
  }

  // Set temperature for token generation
  void SetTemperature(float value) {
    inference_args.temperature = value;
    LogDebug("Setting temperature to configured value");
  }

  // Set top_k parameter for sampling
  void SetTopK(int value) {
    inference_args.top_k = value;
    LogDebug("Setting top_k to configured value");
  }

  // Set deterministic flag. When enabled, reseeds the RNG with a fixed
  // value so repeated runs sample identically.
  void SetDeterministic(bool value) {
    inference_args.deterministic = value;
    if (value) {
      gen.seed(0x87654321);
    }
    LogDebug("Setting deterministic flag to configured value");
  }

  // Set prefill_tbatch_size
  void SetPrefillTbatchSize(size_t value) {
    inference_args.prefill_tbatch_size = value;
    LogDebug("Setting prefill_tbatch_size to configured value");
  }

  // Reset the currently active conversation: rewind its position and
  // replace its KV cache so the next Generate starts a fresh dialog.
  void ResetConversation() {
    if (active_conversation) {
      LogDebug("Resetting active conversation");
      active_conversation->abs_pos = 0;
      // Replace the cache within the current ConversationData object.
      active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
          model.GetModelConfig(), inference_args.prefill_tbatch_size));
      LogDebug("Active conversation reset");
    } else {
      LogDebug("Cannot reset conversation: active_conversation is null");
    }
  }

  // Create a new named conversation; fails if the name already exists.
  // Does not switch to it (use SwitchConversation).
  bool CreateConversation(const char* conversation_name) {
    std::string name(conversation_name);
    if (conversation_cache.count(name)) {
      LogDebug("Conversation already exists");
      return false;
    }
    LogDebug("Creating new conversation");
    conversation_cache[name] = std::make_shared<ConversationData>(
        model.GetModelConfig(), inference_args.prefill_tbatch_size);
    return true;
  }

  // Switch the active conversation to a named one; fails if not found.
  bool SwitchConversation(const char* conversation_name) {
    std::string name(conversation_name);
    auto it = conversation_cache.find(name);
    if (it == conversation_cache.end()) {
      LogDebug("Conversation not found");
      return false;
    }
    LogDebug("Switching active conversation");
    active_conversation = it->second;
    return true;
  }

  // Delete a named conversation. The "default" conversation and the
  // currently active one are protected and cannot be deleted.
  bool DeleteConversation(const char* conversation_name) {
    std::string name(conversation_name);
    auto it = conversation_cache.find(name);
    if (it == conversation_cache.end()) {
      LogDebug("Conversation not found for deletion");
      return false;
    }
    if (name == "default") {
      LogDebug("Cannot delete the default conversation");
      return false;
    }
    if (it->second == active_conversation) {
      LogDebug("Cannot delete the currently active conversation");
      return false;
    }
    LogDebug("Deleting conversation");
    conversation_cache.erase(it);
    return true;
  }

  // Check if a named conversation exists.
  bool HasConversation(const char* conversation_name) {
    std::string name(conversation_name);
    // Explicit comparison instead of an implicit size_t -> bool narrowing.
    return conversation_cache.count(name) != 0;
  }

 private:
  // Internal implementation shared by Generate and GenerateMultimodal.
  // image_data is null (with zero dimensions) for text-only generation.
  int GenerateInternal(const char* prompt_string, const void* image_data,
                       int image_width, int image_height, char* output,
                       int max_length, GemmaTokenCallback callback,
                       void* user_data);

  // Pointer to the currently active conversation's data.
  std::shared_ptr<ConversationData> active_conversation;
  // Cache of all named conversations.
  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
      conversation_cache;

  // Scratch buffers reused across calls (potentially could be moved into
  // ConversationData if needed per-conversation).
  std::string prompt_buffer;
  std::string result_buffer;
  std::vector<int> token_buffer;

  // Cached args (remain global for the context).
  InferenceArgs inference_args;
  ThreadingArgs threading_args;
  MatMulEnv matmul_env;
  // Model itself (don't move this, needs to be below the args above).
  Gemma model;
  // Random generator (remains global for the context).
  std::mt19937 gen;

  // Static members for logging.
  static GemmaLogCallback s_log_callback;
  static void* s_log_user_data;

  // Routes a message to the registered callback if any; otherwise falls
  // back to the platform's debug output.
  static void LogDebug(const char* message) {
    if (s_log_callback) {
      s_log_callback(message, s_log_user_data);
    } else {
#ifdef _WIN32
      OutputDebugStringA(message);
#else
      printf("%s", message);
#endif
    }
  }
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_

View File

@ -271,12 +271,6 @@ class Gemma {
ModelWeightsStorage model_; ModelWeightsStorage model_;
}; };
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt);
void RangeChecks(const ModelConfig& weights_config, void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size); size_t& max_generated_tokens, size_t prompt_size);

View File

@ -28,6 +28,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/gemma.h" // Gemma #include "gemma/gemma.h" // Gemma
#include "gemma/gemma_args.h" #include "gemma/gemma_args.h"
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"