mirror of https://github.com/google/gemma.cpp.git
Add C API and C# interop files
This change adds a basic C API that allows access to Gemma functionality from other programming languages. The functionality is exposed via a shared library (DLL on Windows), with C++ interfaces and a basic C# interop wrapper included. To build the DLL, use the `windows-dll` preset, which includes the C and C++ sources as follows: ``` cmake --preset windows-dll cmake --build --config Release --preset windows-dll -j 4 ``` This should generate a `gemma.dll` in `<build-dir>/Release`. To build for non-Windows, the appropriate C++ DLL linking will need to be done to generate a shared library for the target OS. PiperOrigin-RevId: 750246272
This commit is contained in:
parent
f20da328de
commit
ba10c88a94
36
BUILD.bazel
36
BUILD.bazel
|
|
@ -428,6 +428,40 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_shared_lib",
|
||||
srcs = [
|
||||
"gemma/bindings/c_api.cc",
|
||||
"gemma/bindings/context.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/bindings/c_api.h",
|
||||
"gemma/bindings/context.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
":ops",
|
||||
":threading",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
":weights",
|
||||
"//compression:shared",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cross_entropy",
|
||||
srcs = ["evals/cross_entropy.cc"],
|
||||
|
|
@ -465,6 +499,7 @@ cc_library(
|
|||
":gemma_lib",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"@google_benchmark//:benchmark",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -522,6 +557,7 @@ cc_binary(
|
|||
":gemma_lib",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"//compression:shared",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
|
|||
FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(benchmark)
|
||||
|
||||
# Base source files
|
||||
set(SOURCES
|
||||
compression/blob_store.cc
|
||||
compression/blob_store.h
|
||||
|
|
@ -115,6 +116,17 @@ set(SOURCES
|
|||
util/topology.h
|
||||
)
|
||||
|
||||
# Add C API sources only when building DLL
|
||||
if(BUILD_GEMMA_DLL)
|
||||
list(APPEND SOURCES
|
||||
gemma/bindings/context.h
|
||||
gemma/bindings/context.cc
|
||||
gemma/bindings/c_api.h
|
||||
gemma/bindings/c_api.cc
|
||||
)
|
||||
message(STATUS "Including C API files for DLL build")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
endif()
|
||||
|
|
@ -134,6 +146,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
|
|||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
install(TARGETS libgemma DESTINATION lib)
|
||||
|
||||
# Shared library target for C# interop
|
||||
if(BUILD_GEMMA_DLL)
|
||||
add_library(gemma_shared SHARED ${SOURCES})
|
||||
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gemma_shared PROPERTIES
|
||||
PREFIX ""
|
||||
OUTPUT_NAME "gemma"
|
||||
)
|
||||
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(gemma_shared PUBLIC ./)
|
||||
target_link_libraries(gemma_shared PRIVATE
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
|
||||
)
|
||||
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(gemma_shared
|
||||
PRIVATE
|
||||
GEMMA_EXPORTS
|
||||
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
|
||||
)
|
||||
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
install(TARGETS gemma_shared DESTINATION lib)
|
||||
install(FILES gemma/c_api.h DESTINATION include/gemma)
|
||||
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)
|
||||
endif()
|
||||
|
||||
# Executable Target
|
||||
|
||||
add_executable(gemma gemma/run.cc)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,24 @@
|
|||
"lhs": "${hostSystemName}",
|
||||
"rhs": "Windows"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "windows-dll",
|
||||
"inherits": "__defaults__",
|
||||
"displayName": "Windows DLL",
|
||||
"description": "Visual Studio 2022 with Clang/LLVM frontend (DLL build)",
|
||||
"generator": "Visual Studio 17 2022",
|
||||
"toolset": "ClangCL",
|
||||
"condition": {
|
||||
"type": "equals",
|
||||
"lhs": "${hostSystemName}",
|
||||
"rhs": "Windows"
|
||||
},
|
||||
"cacheVariables": {
|
||||
"BUILD_SHARED_LIBS": "OFF",
|
||||
"CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS": "ON",
|
||||
"BUILD_GEMMA_DLL": "ON"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
|
|
@ -54,6 +72,15 @@
|
|||
"displayName": "Windows",
|
||||
"configuration": "Release",
|
||||
"configurePreset": "windows"
|
||||
},
|
||||
{
|
||||
"name": "windows-dll",
|
||||
"displayName": "Windows DLL",
|
||||
"configuration": "Release",
|
||||
"configurePreset": "windows-dll",
|
||||
"targets": [
|
||||
"gemma_shared"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "ops/matmul.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -54,8 +55,9 @@ class GemmaEnv {
|
|||
size_t MaxGeneratedTokens() const {
|
||||
return runtime_config_.max_generated_tokens;
|
||||
}
|
||||
void SetMaxGeneratedTokens(size_t max_generated_tokens) {
|
||||
runtime_config_.max_generated_tokens = max_generated_tokens;
|
||||
void SetMaxGeneratedTokens(int max_generated_tokens) {
|
||||
runtime_config_.max_generated_tokens =
|
||||
static_cast<size_t>(max_generated_tokens);
|
||||
}
|
||||
|
||||
std::vector<int> Tokenize(const std::string& input) const {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,426 @@
|
|||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
namespace GemmaCpp
|
||||
{
|
||||
public class GemmaException : Exception
|
||||
{
|
||||
public GemmaException(string message) : base(message) { }
|
||||
}
|
||||
|
||||
public class Gemma : IDisposable
|
||||
{
|
||||
private IntPtr _context;
|
||||
private bool _disposed;
|
||||
|
||||
// Optional: Allow setting DLL path
|
||||
public static string DllPath { get; set; } = "gemma.dll";
|
||||
|
||||
[DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
|
||||
private static extern IntPtr LoadLibrary(string lpFileName);
|
||||
|
||||
static Gemma()
|
||||
{
|
||||
// Load DLL from specified path
|
||||
if (LoadLibrary(DllPath) == IntPtr.Zero)
|
||||
{
|
||||
throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
|
||||
}
|
||||
}
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern IntPtr GemmaCreate(
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
|
||||
int maxLength);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaDestroy(IntPtr context);
|
||||
|
||||
// Delegate type for token callbacks
|
||||
public delegate bool TokenCallback(string token);
|
||||
|
||||
// Keep delegate alive for duration of calls
|
||||
private GCHandle _callbackHandle;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||
private delegate bool GemmaTokenCallback(
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string text,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern int GemmaGenerate(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||
[Out] byte[] output,
|
||||
int maxLength,
|
||||
GemmaTokenCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern int GemmaGenerateMultimodal(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||
IntPtr image_data, // Renamed param to match C API
|
||||
int image_width, // Added dimension
|
||||
int image_height, // Added dimension
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
|
||||
int maxLength,
|
||||
GemmaTokenCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern int GemmaCountTokens(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string text);
|
||||
|
||||
// Configuration function imports
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetMaxGeneratedTokens(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetMultiturn(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetTemperature(IntPtr context, float value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetTopK(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetDeterministic(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetPrefillTbatchSize(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaResetConversation")]
|
||||
private static extern void GemmaResetConversation(IntPtr context);
|
||||
|
||||
// Conversation management function imports
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaCreateConversation")]
|
||||
private static extern int GemmaCreateConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSwitchConversation")]
|
||||
private static extern int GemmaSwitchConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaDeleteConversation")]
|
||||
private static extern int GemmaDeleteConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaHasConversation")]
|
||||
private static extern int GemmaHasConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
// Native callback delegate type
|
||||
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||
private delegate void GemmaLogCallback(
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string message,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetLogCallback(
|
||||
IntPtr context,
|
||||
GemmaLogCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
private GCHandle _logCallbackHandle;
|
||||
private bool _loggingEnabled = false;
|
||||
|
||||
public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
|
||||
{
|
||||
_context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
|
||||
if (_context == IntPtr.Zero)
|
||||
{
|
||||
throw new GemmaException("Failed to create Gemma context");
|
||||
}
|
||||
}
|
||||
|
||||
// Enable debug logging
|
||||
public void EnableLogging(bool enable = true)
|
||||
{
|
||||
if (enable && !_loggingEnabled)
|
||||
{
|
||||
GemmaLogCallback logCallback = (message, _) =>
|
||||
{
|
||||
Debug.WriteLine($"Gemma: {message}");
|
||||
};
|
||||
_logCallbackHandle = GCHandle.Alloc(logCallback);
|
||||
GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
|
||||
_loggingEnabled = true;
|
||||
}
|
||||
else if (!enable && _loggingEnabled)
|
||||
{
|
||||
if (_logCallbackHandle.IsAllocated)
|
||||
_logCallbackHandle.Free();
|
||||
GemmaSetLogCallback(_context, null, IntPtr.Zero);
|
||||
_loggingEnabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration methods
|
||||
public void SetMultiturn(bool enable)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
GemmaSetMultiturn(_context, enable ? 1 : 0);
|
||||
Debug.WriteLine($"Gemma: Set multiturn to {(enable ? "enabled" : "disabled")}");
|
||||
}
|
||||
|
||||
public void SetTemperature(float temperature)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
GemmaSetTemperature(_context, temperature);
|
||||
Debug.WriteLine($"Gemma: Set temperature to {temperature}");
|
||||
}
|
||||
|
||||
public void SetTopK(int topK)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
GemmaSetTopK(_context, topK);
|
||||
Debug.WriteLine($"Gemma: Set topK to {topK}");
|
||||
}
|
||||
|
||||
public void SetDeterministic(bool deterministic)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
GemmaSetDeterministic(_context, deterministic ? 1 : 0);
|
||||
Debug.WriteLine($"Gemma: Set deterministic to {(deterministic ? "true" : "false")}");
|
||||
}
|
||||
|
||||
// Renamed public method
|
||||
public void ResetConversation()
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
GemmaResetConversation(_context); // Call P/Invoke method
|
||||
Debug.WriteLine("Gemma: Reset active conversation");
|
||||
}
|
||||
|
||||
// Conversation management methods
|
||||
public bool CreateConversation(string conversationName)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
bool result = GemmaCreateConversation(_context, conversationName) != 0; // Call P/Invoke method
|
||||
Debug.WriteLine($"Gemma: Create conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
|
||||
return result;
|
||||
}
|
||||
|
||||
public bool SwitchConversation(string conversationName)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
bool result = GemmaSwitchConversation(_context, conversationName) != 0; // Call P/Invoke method
|
||||
Debug.WriteLine($"Gemma: Switch to conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
|
||||
return result;
|
||||
}
|
||||
|
||||
public bool DeleteConversation(string conversationName)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
bool result = GemmaDeleteConversation(_context, conversationName) != 0; // Call P/Invoke method
|
||||
Debug.WriteLine($"Gemma: Delete conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
|
||||
return result;
|
||||
}
|
||||
|
||||
public bool HasConversation(string conversationName)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
bool result = GemmaHasConversation(_context, conversationName) != 0; // Call P/Invoke method
|
||||
Debug.WriteLine($"Gemma: Has conversation '{conversationName}' - {result}");
|
||||
return result;
|
||||
}
|
||||
|
||||
public int CountTokens(string prompt)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
int count = GemmaCountTokens(_context, prompt);
|
||||
return count;
|
||||
}
|
||||
|
||||
public string Generate(string prompt, int maxLength = 4096)
|
||||
{
|
||||
return Generate(prompt, null, maxLength);
|
||||
}
|
||||
|
||||
public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
var outputBuffer = new byte[maxLength * 4]; // Allow for worst case UTF-8 size
|
||||
GemmaTokenCallback nativeCallback = null;
|
||||
|
||||
// Track token count for debugging
|
||||
int tokenCount = 0;
|
||||
|
||||
if (callback != null)
|
||||
{
|
||||
nativeCallback = (text, _) =>
|
||||
{
|
||||
tokenCount++;
|
||||
// Log token for debugging
|
||||
Debug.WriteLine($"Token {tokenCount}: '{text}'");
|
||||
|
||||
// Pass token to user callback
|
||||
return callback(text);
|
||||
};
|
||||
_callbackHandle = GCHandle.Alloc(nativeCallback);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
int length = GemmaGenerate(_context, prompt, outputBuffer, maxLength,
|
||||
nativeCallback, IntPtr.Zero);
|
||||
|
||||
if (length < 0)
|
||||
throw new GemmaException("Generation failed");
|
||||
|
||||
Debug.WriteLine($"Generation complete: {tokenCount} tokens processed, result length: {length}");
|
||||
|
||||
// Convert the byte buffer to a string using UTF-8 encoding
|
||||
string result = Encoding.UTF8.GetString(outputBuffer, 0, length);
|
||||
return result;
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (_callbackHandle.IsAllocated)
|
||||
_callbackHandle.Free();
|
||||
}
|
||||
}
|
||||
|
||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxLength = 4096)
|
||||
{
|
||||
// Pass width and height to the overloaded method
|
||||
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxLength);
|
||||
}
|
||||
|
||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxLength = 4096)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
if (imageData == null || imageData.Length == 0)
|
||||
throw new ArgumentException("Image data cannot be null or empty", nameof(imageData));
|
||||
|
||||
if (imageWidth <= 0 || imageHeight <= 0)
|
||||
throw new ArgumentException("Image dimensions must be positive");
|
||||
|
||||
if (imageData.Length < imageWidth * imageHeight * 3)
|
||||
throw new ArgumentException("Image data array is too small for the specified dimensions");
|
||||
|
||||
var output = new StringBuilder(maxLength);
|
||||
GemmaTokenCallback nativeCallback = null;
|
||||
|
||||
if (callback != null)
|
||||
{
|
||||
nativeCallback = (text, _) => callback(text);
|
||||
_callbackHandle = GCHandle.Alloc(nativeCallback);
|
||||
}
|
||||
|
||||
// Pin the image data so it doesn't move during the native call
|
||||
GCHandle imageHandle = GCHandle.Alloc(imageData, GCHandleType.Pinned);
|
||||
|
||||
try
|
||||
{
|
||||
IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
|
||||
|
||||
// Pass image dimensions to the native call
|
||||
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxLength,
|
||||
nativeCallback, IntPtr.Zero);
|
||||
|
||||
if (length < 0)
|
||||
throw new GemmaException("Multimodal generation failed");
|
||||
|
||||
return output.ToString();
|
||||
}
|
||||
finally
|
||||
{
|
||||
imageHandle.Free();
|
||||
|
||||
if (_callbackHandle.IsAllocated)
|
||||
_callbackHandle.Free();
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
if (_context != IntPtr.Zero)
|
||||
{
|
||||
GemmaDestroy(_context);
|
||||
_context = IntPtr.Zero;
|
||||
}
|
||||
if (_logCallbackHandle.IsAllocated)
|
||||
_logCallbackHandle.Free();
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
|
||||
~Gemma()
|
||||
{
|
||||
Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GEMMA_EXPORTS
|
||||
#define GEMMA_EXPORTS
|
||||
#endif
|
||||
|
||||
#include "gemma/bindings/c_api.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||
const char* model_type,
|
||||
const char* weights_path,
|
||||
const char* weight_type, int max_length) {
|
||||
try {
|
||||
GemmaContext* ctx = GemmaContext::Create(
|
||||
tokenizer_path, model_type, weights_path, weight_type, max_length);
|
||||
return ctx;
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; }
|
||||
|
||||
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
if (!ctx) return -1;
|
||||
return ctx->Generate(prompt, output, max_length, callback, user_data);
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, char* output,
|
||||
int max_length,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
if (!ctx) return -1;
|
||||
|
||||
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
|
||||
output, max_length, callback, user_data);
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
|
||||
if (!ctx || !text) return -1;
|
||||
return ctx->CountTokens(text);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
|
||||
void* user_data) {
|
||||
if (!ctx) return;
|
||||
ctx->SetLogCallback(callback, user_data);
|
||||
}
|
||||
|
||||
// Configuration functions implementation
|
||||
GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value) {
|
||||
if (!ctx) return;
|
||||
ctx->SetMaxGeneratedTokens(value);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value) {
|
||||
if (!ctx) return;
|
||||
ctx->SetMultiturn(value);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value) {
|
||||
if (!ctx) return;
|
||||
ctx->SetTemperature(value);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value) {
|
||||
if (!ctx) return;
|
||||
ctx->SetTopK(value);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value) {
|
||||
if (!ctx) return;
|
||||
ctx->SetDeterministic(value != 0);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value) {
|
||||
if (!ctx) return;
|
||||
ctx->SetPrefillTbatchSize(value);
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaResetConversation(GemmaContext* ctx) { // Renamed function
|
||||
if (!ctx) return;
|
||||
ctx->ResetConversation();
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
|
||||
const char* conversation_name) {
|
||||
if (!ctx || !conversation_name) return 0;
|
||||
return ctx->CreateConversation(conversation_name) ? 1 : 0;
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
|
||||
const char* conversation_name) {
|
||||
if (!ctx || !conversation_name) return 0;
|
||||
return ctx->SwitchConversation(conversation_name) ? 1 : 0;
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
|
||||
const char* conversation_name) {
|
||||
if (!ctx || !conversation_name) return 0;
|
||||
return ctx->DeleteConversation(conversation_name) ? 1 : 0;
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
||||
const char* conversation_name) {
|
||||
if (!ctx || !conversation_name) return 0;
|
||||
return ctx->HasConversation(conversation_name) ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_C_API_H_
|
||||
#define THIRD_PARTY_GEMMA_C_API_H_
|
||||
|
||||
#include "gemma/bindings/context.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifdef GEMMA_EXPORTS
|
||||
#define GEMMA_API __declspec(dllexport)
|
||||
#else
|
||||
#define GEMMA_API __declspec(dllimport)
|
||||
#endif
|
||||
#else
|
||||
#define GEMMA_API __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
typedef gcpp::GemmaContext GemmaContext;
|
||||
#else
|
||||
typedef struct GemmaContext GemmaContext;
|
||||
#endif
|
||||
|
||||
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||
|
||||
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||
const char* model_type,
|
||||
const char* weights_path,
|
||||
const char* weight_type, int max_length);
|
||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
|
||||
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
void* user_data);
|
||||
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
||||
const void* image_data, // Renamed param
|
||||
int image_width, // Added dimension
|
||||
int image_height, // Added dimension
|
||||
char* output, int max_length,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data);
|
||||
|
||||
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);
|
||||
|
||||
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
|
||||
void* user_data);
|
||||
|
||||
// Configuration functions
|
||||
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
|
||||
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
|
||||
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
|
||||
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
|
||||
GEMMA_API void GemmaResetConversation(GemmaContext* ctx); // Renamed
|
||||
|
||||
// Conversation management functions (renamed)
|
||||
GEMMA_API int GemmaCreateConversation(
|
||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
||||
GEMMA_API int GemmaSwitchConversation(
|
||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
||||
GEMMA_API int GemmaDeleteConversation(
|
||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
||||
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
||||
const char* conversation_name); // Renamed
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_C_API_H_
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/bindings/context.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h" // InitGenerator
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "paligemma/image.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// ConversationData constructor implementation
|
||||
ConversationData::ConversationData(const ModelConfig& model_config,
|
||||
size_t prefill_tbatch_size)
|
||||
: kv_cache(std::make_unique<KVCache>(
|
||||
KVCache::Create(model_config, prefill_tbatch_size))),
|
||||
abs_pos(0) {}
|
||||
|
||||
// Initialize static members
|
||||
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
|
||||
void* GemmaContext::s_log_user_data = nullptr;
|
||||
|
||||
// Factory used by the C API: assembles loader/threading/inference
// configuration from the raw string arguments, validates the loader, and
// returns a heap-allocated context (caller owns it). Aborts via HWY_ABORT
// when the loader configuration is invalid.
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
                                   const char* model_type,
                                   const char* weights_path,
                                   const char* weight_type, int max_length) {
  std::stringstream creation_log;
  creation_log << "Creating GemmaContext with tokenizer_path: "
               << (tokenizer_path ? tokenizer_path : "null")
               << ", model_type: " << (model_type ? model_type : "null")
               << ", weights_path: " << (weights_path ? weights_path : "null")
               << ", weight_type: " << (weight_type ? weight_type : "null")
               << ", max_length: " << max_length;
  LogDebug(creation_log.str().c_str());

  // Spinning is disabled for library embedding; callers may be UI threads.
  ThreadingArgs threading_args;
  threading_args.spin = gcpp::Tristate::kFalse;

  LoaderArgs loader(tokenizer_path, weights_path, model_type);
  loader.weight_type_str = weight_type;
  LogDebug("LoaderArgs created");

  if (const char* error = loader.Validate()) {
    std::stringstream error_log;
    error_log << "Invalid loader configuration: " << error;
    LogDebug(error_log.str().c_str());
    HWY_ABORT("Invalid loader configuration: %s", error);
  }
  LogDebug("Loader validated successfully");

  // Initialize cached args
  LogDebug("Initializing inference args");
  InferenceArgs inference_args;
  inference_args.Init();
  inference_args.max_generated_tokens = max_length;
  inference_args.temperature = 0.7f;
  inference_args.top_k = 1;
  inference_args.deterministic = false;

  std::stringstream args_log;
  args_log << "Inference args initialized with max_tokens: " << max_length
           << ", temperature: " << inference_args.temperature
           << ", top_k: " << inference_args.top_k << ", deterministic: "
           << (inference_args.deterministic ? "true" : "false");
  LogDebug(args_log.str().c_str());

  return new GemmaContext(loader, inference_args, threading_args, max_length);
}
|
||||
|
||||
// Constructs the context: caches the argument structs, builds the MatMul
// environment, loads the model weights, then creates the "default"
// conversation and makes it the active one.
// Note: member order matters — matmul_env must be initialized before model.
GemmaContext::GemmaContext(const LoaderArgs& loader,
                           const InferenceArgs& inference_args,
                           const ThreadingArgs& threading_args, int max_length)
    : inference_args(inference_args),
      threading_args(threading_args),
      matmul_env(MakeMatMulEnv(threading_args)),
      model(CreateGemma(loader, matmul_env)) {
  // (Removed an unused local std::stringstream that was never written to.)
  LogDebug("Creating initial ConversationData");
  // Create the initial ConversationData object using make_shared
  active_conversation = std::make_shared<ConversationData>(
      model.GetModelConfig(), inference_args.prefill_tbatch_size);

  LogDebug(
      "Storing initial ConversationData in conversation_cache[\"default\"]");
  // Store the shared_ptr in the map under the "default" key so it can be
  // retrieved again via SwitchConversation.
  conversation_cache["default"] = active_conversation;

  LogDebug("GemmaContext constructor completed");
}
|
||||
|
||||
// Internal implementation shared by Generate and GenerateMultimodal
|
||||
int GemmaContext::GenerateInternal(const char* prompt_string,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
PROFILER_ZONE("Gen.Internal");
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
std::stringstream ss;
|
||||
result_buffer.clear();
|
||||
|
||||
InitGenerator(inference_args, gen);
|
||||
|
||||
// Ensure we have an active conversation
|
||||
if (!active_conversation || !active_conversation->kv_cache) {
|
||||
LogDebug("Generate called with null active_conversation or kv_cache");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&, callback, user_data](int token, float) {
|
||||
// Use abs_pos from the active conversation
|
||||
++(active_conversation->abs_pos);
|
||||
const bool in_prompt = tokens_generated_this_turn < prompt_size;
|
||||
const bool first_response_token = tokens_generated_this_turn == prompt_size;
|
||||
++tokens_generated_this_turn;
|
||||
if (in_prompt || model.GetModelConfig().IsEOS(token)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string token_text;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
if (first_response_token) {
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
}
|
||||
|
||||
// if we have a managed callback, pass it the token text
|
||||
if (callback) {
|
||||
if (!callback(token_text.c_str(), user_data)) {
|
||||
LogDebug("Callback returned false, stopping generation");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
result_buffer.append(token_text);
|
||||
return true;
|
||||
};
|
||||
|
||||
// set up runtime config
|
||||
TimingInfo timing_info = {};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.use_spinning = threading_args.spin};
|
||||
inference_args.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
// generate
|
||||
std::vector<int> prompt;
|
||||
ImageTokens image_tokens;
|
||||
if (image_data != nullptr) {
|
||||
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||
image_tokens =
|
||||
ImageTokens(model.Env().ctx.allocator,
|
||||
Extents2D(model.GetModelConfig().vit_config.seq_len /
|
||||
(pool_dim * pool_dim),
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
|
||||
Image image;
|
||||
image.Set(image_width, image_height, static_cast<const float*>(image_data));
|
||||
|
||||
// We may need to resize the supplied image depending on whether we're using
|
||||
// PaliGemma or Gemma 3.
|
||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
|
||||
// Use the existing runtime_config defined earlier in the function.
|
||||
// RuntimeConfig runtime_config = { ... }; // This was already defined
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
// Pass the populated image object to GenerateImageTokens
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
|
||||
ss.str("");
|
||||
ss << "\n\n[ Timing info ] Image token generation took: ";
|
||||
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
||||
model.Info(), active_conversation->abs_pos,
|
||||
prompt_string, image_tokens.BatchSize());
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
prefix_end = prompt_size;
|
||||
} else {
|
||||
// Text-only case (original logic)
|
||||
// Use abs_pos from the active conversation
|
||||
prompt =
|
||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||
active_conversation->abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
}
|
||||
|
||||
// Check if prompt generation failed (e.g., multimodal not implemented yet)
|
||||
if (prompt.empty() && image_data != nullptr) {
|
||||
// Already logged the error, just ensure we don't proceed.
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Pass the KVCache object by reference from the active conversation
|
||||
model.Generate(runtime_config, prompt, active_conversation->abs_pos,
|
||||
prefix_end, *(active_conversation->kv_cache), timing_info);
|
||||
|
||||
// prepare for next turn
|
||||
if (!inference_args.multiturn ||
|
||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
// If not multiturn, or Paligemma (which handles turns differently),
|
||||
// reset the *active* conversation's position.
|
||||
active_conversation->abs_pos = 0;
|
||||
InitGenerator(inference_args, gen);
|
||||
} else {
|
||||
// Multi-turn Gemma: Rewind position in the active conversation
|
||||
// The last token was either EOS, then it should be ignored because it is
|
||||
// never part of the dialog, see Table 5 in the Gemma-2 paper:
|
||||
// https://arxiv.org/pdf/2408.00118
|
||||
// Or we have hit max_generated_tokens, then the last token will be lost.
|
||||
// (We could store it in stream_token, and then prepend to the next turn,
|
||||
// but it's not worth the complexity, as multi-turn with max_generated is
|
||||
// not a common use case.)
|
||||
// In either case, we need to rewind the active conversation's abs_pos by
|
||||
// one.
|
||||
HWY_ASSERT(active_conversation->abs_pos > 0);
|
||||
active_conversation->abs_pos--;
|
||||
}
|
||||
|
||||
// Copy result buffer to output C-string (ensure null termination)
|
||||
strncpy(output, result_buffer.c_str(), max_length - 1);
|
||||
output[max_length - 1] = '\0'; // Explicit null termination
|
||||
|
||||
return static_cast<int>(strlen(output)); // Return length of the C-string
|
||||
}
|
||||
|
||||
// Public Generate method (wrapper for text-only)
|
||||
int GemmaContext::Generate(const char* prompt_string, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
// Call the internal implementation with null image_data and 0 dimensions
|
||||
return GenerateInternal(prompt_string, nullptr, 0, 0, output, max_length,
|
||||
callback, user_data);
|
||||
}
|
||||
|
||||
// Public GenerateMultimodal method (wrapper)
|
||||
int GemmaContext::GenerateMultimodal(const char* prompt_string,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, // Added dimensions
|
||||
char* output, int max_length,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
if (image_data == nullptr) {
|
||||
LogDebug(
|
||||
"GenerateMultimodal called with null image_data. Use Generate for "
|
||||
"text-only.");
|
||||
// Or potentially call GenerateInternal with null image_data anyway?
|
||||
// Returning error seems safer.
|
||||
return -1;
|
||||
}
|
||||
|
||||
return GenerateInternal(prompt_string, image_data, image_width, image_height,
|
||||
output, max_length, callback, user_data);
|
||||
}
|
||||
|
||||
int GemmaContext::CountTokens(const char* text) {
|
||||
LogDebug("CountTokens method started");
|
||||
std::stringstream ss;
|
||||
ss << "CountTokens called with text: '" << (text ? text : "null") << "'";
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
if (!text) {
|
||||
LogDebug("CountTokens failed: Invalid parameters");
|
||||
if (!text) LogDebug(" text is null");
|
||||
return -1;
|
||||
}
|
||||
|
||||
try {
|
||||
LogDebug("Creating text string");
|
||||
std::string text_str(text);
|
||||
|
||||
LogDebug("Creating tokens vector");
|
||||
std::vector<int> tokens;
|
||||
|
||||
LogDebug("Encoding text to tokens");
|
||||
HWY_ASSERT(model.Tokenizer().Encode(text_str, &tokens));
|
||||
|
||||
ss.str("");
|
||||
ss << "Text tokenized into " << tokens.size() << " tokens";
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
LogDebug("CountTokens completed successfully");
|
||||
return static_cast<int>(tokens.size());
|
||||
} catch (...) {
|
||||
LogDebug("Unknown exception in CountTokens");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||
|
||||
#include <memory> // For std::shared_ptr, std::make_shared
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// Logging
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <stdio.h>
|
||||
#endif
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Forward declaration - use 'struct' to match definition tag
struct KVCache;

// Struct to hold data for a single conversation thread
// Each conversation owns its own KV cache and decode position, so multiple
// independent chats can share one loaded model inside a GemmaContext.
struct ConversationData {
  // Owning pointer to this conversation's attention key/value cache.
  std::unique_ptr<KVCache> kv_cache;
  // Absolute position: number of tokens decoded so far in this conversation.
  size_t abs_pos = 0;

  // Constructor to initialize kv_cache (requires KVCache definition or forward
  // declaration)
  ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
};
|
||||
|
||||
// Streaming callback: invoked once per generated token with its decoded
// text; returning false stops generation early.
using GemmaTokenCallback = bool (*)(const char* text, void* user_data);
// Logging callback: receives diagnostic messages from the context.
using GemmaLogCallback = void (*)(const char* message, void* user_data);
|
||||
|
||||
class GemmaContext {
|
||||
private:
|
||||
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
|
||||
const ThreadingArgs& threading_args, int max_length);
|
||||
|
||||
public:
|
||||
static GemmaContext* Create(const char* tokenizer_path,
|
||||
const char* model_type, const char* weights_path,
|
||||
const char* weight_type, int max_length);
|
||||
|
||||
// Returns length of generated text, or -1 on error
|
||||
int Generate(const char* prompt_string, char* output, int max_length,
|
||||
GemmaTokenCallback callback, void* user_data);
|
||||
// Returns length of generated text, or -1 on error
|
||||
int GenerateMultimodal(const char* prompt_string, const void* image_data,
|
||||
int image_width, int image_height, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
void* user_data);
|
||||
|
||||
// Returns number of tokens in text, or -1 on error
|
||||
int CountTokens(const char* text);
|
||||
|
||||
// Add new method to set logger
|
||||
static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
|
||||
s_log_callback = callback;
|
||||
s_log_user_data = user_data;
|
||||
}
|
||||
|
||||
// Set max generated tokens
|
||||
void SetMaxGeneratedTokens(size_t value) {
|
||||
inference_args.max_generated_tokens = value;
|
||||
LogDebug("Setting max_generated_tokens to configured value");
|
||||
}
|
||||
|
||||
// Set multiturn flag (0 = disabled, 1 = enabled)
|
||||
void SetMultiturn(int value) {
|
||||
inference_args.multiturn = value;
|
||||
LogDebug("Setting multiturn to configured value");
|
||||
}
|
||||
|
||||
// Set temperature for token generation
|
||||
void SetTemperature(float value) {
|
||||
inference_args.temperature = value;
|
||||
LogDebug("Setting temperature to configured value");
|
||||
}
|
||||
|
||||
// Set top_k parameter for sampling
|
||||
void SetTopK(int value) {
|
||||
inference_args.top_k = value;
|
||||
LogDebug("Setting top_k to configured value");
|
||||
}
|
||||
|
||||
// Set deterministic flag
|
||||
void SetDeterministic(bool value) {
|
||||
inference_args.deterministic = value;
|
||||
// Reset the random number generator for deterministic generation
|
||||
if (value) {
|
||||
gen.seed(0x87654321);
|
||||
}
|
||||
LogDebug("Setting deterministic flag to configured value");
|
||||
}
|
||||
|
||||
// Set prefill_tbatch_size
|
||||
void SetPrefillTbatchSize(size_t value) {
|
||||
inference_args.prefill_tbatch_size = value;
|
||||
LogDebug("Setting prefill_tbatch_size to configured value");
|
||||
}
|
||||
|
||||
// Reset the currently active conversation
|
||||
void ResetConversation() {
|
||||
if (active_conversation) {
|
||||
LogDebug("Resetting active conversation");
|
||||
active_conversation->abs_pos = 0;
|
||||
// Replace the cache within the current ConversationData object
|
||||
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
||||
LogDebug("Active conversation reset");
|
||||
} else {
|
||||
LogDebug("Cannot reset conversation: active_conversation is null");
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new named conversation
|
||||
bool CreateConversation(const char* conversation_name) {
|
||||
std::string name(conversation_name);
|
||||
if (conversation_cache.count(name)) {
|
||||
LogDebug("Conversation already exists");
|
||||
return false;
|
||||
}
|
||||
LogDebug("Creating new conversation");
|
||||
// Create a new ConversationData object using make_shared
|
||||
conversation_cache[name] = std::make_shared<ConversationData>(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Switch to a named conversation
|
||||
bool SwitchConversation(const char* conversation_name) {
|
||||
std::string name(conversation_name);
|
||||
auto it = conversation_cache.find(name);
|
||||
if (it == conversation_cache.end()) {
|
||||
LogDebug("Conversation not found");
|
||||
return false;
|
||||
}
|
||||
LogDebug("Switching active conversation");
|
||||
active_conversation = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Delete a named conversation
|
||||
bool DeleteConversation(const char* conversation_name) {
|
||||
std::string name(conversation_name);
|
||||
auto it = conversation_cache.find(name);
|
||||
|
||||
if (it == conversation_cache.end()) {
|
||||
LogDebug("Conversation not found for deletion");
|
||||
return false;
|
||||
}
|
||||
if (name == "default") {
|
||||
LogDebug("Cannot delete the default conversation");
|
||||
return false;
|
||||
}
|
||||
if (it->second == active_conversation) {
|
||||
LogDebug("Cannot delete the currently active conversation");
|
||||
return false;
|
||||
}
|
||||
|
||||
LogDebug("Deleting conversation");
|
||||
conversation_cache.erase(it);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if a named conversation exists
|
||||
bool HasConversation(const char* conversation_name) {
|
||||
std::string name(conversation_name);
|
||||
return conversation_cache.count(name);
|
||||
}
|
||||
|
||||
private:
|
||||
// Internal implementation shared by Generate and GenerateMultimodal
|
||||
int GenerateInternal(const char* prompt_string,
|
||||
const void* image_data, // Null for text-only generation
|
||||
int image_width, // Added dimension (0 if no image)
|
||||
int image_height, // Added dimension (0 if no image)
|
||||
char* output, int max_length,
|
||||
GemmaTokenCallback callback, void* user_data);
|
||||
|
||||
// Pointer to the currently active conversation's data
|
||||
std::shared_ptr<ConversationData> active_conversation;
|
||||
|
||||
// Cache of all named conversations
|
||||
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
|
||||
conversation_cache;
|
||||
|
||||
// Buffers (potentially could be moved into ConversationData if needed
|
||||
// per-conversation)
|
||||
std::string prompt_buffer;
|
||||
std::string result_buffer;
|
||||
std::vector<int> token_buffer;
|
||||
|
||||
// Cached args (remain global for the context)
|
||||
InferenceArgs inference_args;
|
||||
ThreadingArgs threading_args;
|
||||
MatMulEnv matmul_env;
|
||||
|
||||
// Model itself (don't move this, needs to be below the args above)
|
||||
Gemma model;
|
||||
|
||||
// Random generator (remains global for the context)
|
||||
std::mt19937 gen;
|
||||
|
||||
// Static members for logging
|
||||
static GemmaLogCallback s_log_callback;
|
||||
static void* s_log_user_data;
|
||||
|
||||
// Use logging helper method to print messages into a managed callback if
|
||||
// necessary
|
||||
static void LogDebug(const char* message) {
|
||||
if (s_log_callback) {
|
||||
s_log_callback(message, s_log_user_data);
|
||||
} else {
|
||||
#ifdef _WIN32
|
||||
OutputDebugStringA(message);
|
||||
#else
|
||||
printf("%s", message);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||
|
|
@ -271,12 +271,6 @@ class Gemma {
|
|||
ModelWeightsStorage model_;
|
||||
};
|
||||
|
||||
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
|
||||
// and `pos`, the number of tokens decoded so far; returns the corresponding
|
||||
// tokens. Asserts that tokenization is successful.
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const ModelInfo& info, size_t pos,
|
||||
std::string& prompt);
|
||||
void RangeChecks(const ModelConfig& weights_config,
|
||||
size_t& max_generated_tokens, size_t prompt_size);
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@
|
|||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
|
|||
Loading…
Reference in New Issue