mirror of https://github.com/google/gemma.cpp.git
Adds:
- GemmaContext class that exposes Gemma functionality - C API that uses GemmaContext - C# interop class in GemmaInterop.cs - New END_OF_TURN_ID in tokenizer.h, useful when dealing with instruction-tuned prompts PiperOrigin-RevId: 730754638
This commit is contained in:
parent
b3b4b9f92f
commit
1f916b686b
38
BUILD.bazel
38
BUILD.bazel
|
|
@ -356,6 +356,44 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gemma_shared_lib",
|
||||||
|
srcs = [
|
||||||
|
"gemma/c_api.cc",
|
||||||
|
"gemma/context.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"gemma/activations.h",
|
||||||
|
"gemma/c_api.h",
|
||||||
|
"gemma/context.h",
|
||||||
|
"gemma/gemma.h",
|
||||||
|
],
|
||||||
|
exec_properties = {
|
||||||
|
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||||
|
"mem": "28g",
|
||||||
|
},
|
||||||
|
deps = [
|
||||||
|
":allocator",
|
||||||
|
":app",
|
||||||
|
":basics",
|
||||||
|
":common",
|
||||||
|
":kv_cache",
|
||||||
|
":ops",
|
||||||
|
":threading",
|
||||||
|
":tokenizer",
|
||||||
|
":weights",
|
||||||
|
"//compression:compress",
|
||||||
|
"//compression:io",
|
||||||
|
"//compression:sfp",
|
||||||
|
"//paligemma:image",
|
||||||
|
"@highway//:bit_set",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:nanobenchmark", # timer
|
||||||
|
"@highway//:profiler",
|
||||||
|
"@highway//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cross_entropy",
|
name = "cross_entropy",
|
||||||
srcs = ["evals/cross_entropy.cc"],
|
srcs = ["evals/cross_entropy.cc"],
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,17 @@ set(SOURCES
|
||||||
util/threading.h
|
util/threading.h
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add C API sources only when building DLL
|
||||||
|
if(BUILD_GEMMA_DLL)
|
||||||
|
list(APPEND SOURCES
|
||||||
|
gemma/context.h
|
||||||
|
gemma/context.cc
|
||||||
|
gemma/c_api.h
|
||||||
|
gemma/c_api.cc
|
||||||
|
)
|
||||||
|
message(STATUS "Including C API files for DLL build")
|
||||||
|
endif()
|
||||||
|
|
||||||
if(NOT CMAKE_BUILD_TYPE)
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
set(CMAKE_BUILD_TYPE "Release")
|
set(CMAKE_BUILD_TYPE "Release")
|
||||||
endif()
|
endif()
|
||||||
|
|
@ -129,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
|
||||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||||
install(TARGETS libgemma DESTINATION lib)
|
install(TARGETS libgemma DESTINATION lib)
|
||||||
|
|
||||||
|
# Shared Library Target for C# interop
|
||||||
|
if(BUILD_GEMMA_DLL)
|
||||||
|
add_library(gemma_shared SHARED ${SOURCES})
|
||||||
|
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
|
||||||
|
set_target_properties(gemma_shared PROPERTIES
|
||||||
|
PREFIX ""
|
||||||
|
OUTPUT_NAME "gemma"
|
||||||
|
)
|
||||||
|
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
|
target_include_directories(gemma_shared PUBLIC ./)
|
||||||
|
target_link_libraries(gemma_shared PRIVATE
|
||||||
|
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
|
||||||
|
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
|
||||||
|
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
|
||||||
|
)
|
||||||
|
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||||
|
target_compile_definitions(gemma_shared
|
||||||
|
PRIVATE
|
||||||
|
GEMMA_EXPORTS
|
||||||
|
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
|
||||||
|
)
|
||||||
|
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||||
|
install(TARGETS gemma_shared DESTINATION lib)
|
||||||
|
install(FILES gemma/c_api.h DESTINATION include/gemma)
|
||||||
|
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Executable Target
|
# Executable Target
|
||||||
|
|
||||||
add_executable(gemma gemma/run.cc)
|
add_executable(gemma gemma/run.cc)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,175 @@
|
||||||
|
using System;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Runtime.InteropServices;
|
||||||
|
using System.Text;
|
||||||
|
namespace GemmaCpp
|
||||||
|
{
|
||||||
|
public class GemmaException : Exception
|
||||||
|
{
|
||||||
|
public GemmaException(string message) : base(message) { }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class Gemma : IDisposable
|
||||||
|
{
|
||||||
|
private IntPtr _context;
|
||||||
|
private bool _disposed;
|
||||||
|
|
||||||
|
// Optional: Allow setting DLL path
|
||||||
|
public static string DllPath { get; set; } = "gemma.dll";
|
||||||
|
|
||||||
|
[DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
|
||||||
|
private static extern IntPtr LoadLibrary(string lpFileName);
|
||||||
|
|
||||||
|
static Gemma()
|
||||||
|
{
|
||||||
|
// Load DLL from specified path
|
||||||
|
if (LoadLibrary(DllPath) == IntPtr.Zero)
|
||||||
|
{
|
||||||
|
throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
private static extern IntPtr GemmaCreate(
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
|
||||||
|
int maxLength);
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
private static extern void GemmaDestroy(IntPtr context);
|
||||||
|
|
||||||
|
// Delegate type for token callbacks
|
||||||
|
public delegate bool TokenCallback(string token);
|
||||||
|
|
||||||
|
// Keep delegate alive for duration of calls
|
||||||
|
private GCHandle _callbackHandle;
|
||||||
|
|
||||||
|
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||||
|
private delegate bool GemmaTokenCallback(
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string text,
|
||||||
|
IntPtr userData);
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
private static extern int GemmaGenerate(
|
||||||
|
IntPtr context,
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
|
||||||
|
int maxLength,
|
||||||
|
GemmaTokenCallback callback,
|
||||||
|
IntPtr userData);
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
private static extern int GemmaCountTokens(
|
||||||
|
IntPtr context,
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string text);
|
||||||
|
|
||||||
|
// Native callback delegate type
|
||||||
|
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||||
|
private delegate void GemmaLogCallback(
|
||||||
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string message,
|
||||||
|
IntPtr userData);
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
private static extern void GemmaSetLogCallback(
|
||||||
|
IntPtr context,
|
||||||
|
GemmaLogCallback callback,
|
||||||
|
IntPtr userData);
|
||||||
|
|
||||||
|
private GCHandle _logCallbackHandle;
|
||||||
|
|
||||||
|
public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
|
||||||
|
{
|
||||||
|
_context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
|
||||||
|
if (_context == IntPtr.Zero)
|
||||||
|
{
|
||||||
|
throw new GemmaException("Failed to create Gemma context");
|
||||||
|
}
|
||||||
|
|
||||||
|
// optionally: set up logging
|
||||||
|
/*
|
||||||
|
GemmaLogCallback logCallback = (message, _) =>
|
||||||
|
{
|
||||||
|
#if UNITY_ENGINE
|
||||||
|
Debug.Log($"Gemma: {message}");
|
||||||
|
#else
|
||||||
|
Debug.WriteLine($"Gemma: {message}");
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
_logCallbackHandle = GCHandle.Alloc(logCallback);
|
||||||
|
GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
public int CountTokens(string prompt)
|
||||||
|
{
|
||||||
|
if (_disposed)
|
||||||
|
throw new ObjectDisposedException(nameof(Gemma));
|
||||||
|
|
||||||
|
if (_context == IntPtr.Zero)
|
||||||
|
throw new GemmaException("Gemma context is invalid");
|
||||||
|
int count = GemmaCountTokens(_context, prompt);
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
public string Generate(string prompt, int maxLength = 4096)
|
||||||
|
{
|
||||||
|
return Generate(prompt, null, maxLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
|
||||||
|
{
|
||||||
|
if (_disposed)
|
||||||
|
throw new ObjectDisposedException(nameof(Gemma));
|
||||||
|
|
||||||
|
if (_context == IntPtr.Zero)
|
||||||
|
throw new GemmaException("Gemma context is invalid");
|
||||||
|
|
||||||
|
var output = new StringBuilder(maxLength);
|
||||||
|
GemmaTokenCallback nativeCallback = null;
|
||||||
|
|
||||||
|
if (callback != null)
|
||||||
|
{
|
||||||
|
nativeCallback = (text, _) => callback(text);
|
||||||
|
_callbackHandle = GCHandle.Alloc(nativeCallback);
|
||||||
|
}
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
int length = GemmaGenerate(_context, prompt, output, maxLength,
|
||||||
|
nativeCallback, IntPtr.Zero);
|
||||||
|
|
||||||
|
if (length < 0)
|
||||||
|
throw new GemmaException("Generation failed");
|
||||||
|
|
||||||
|
return output.ToString();
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
if (_callbackHandle.IsAllocated)
|
||||||
|
_callbackHandle.Free();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
if (!_disposed)
|
||||||
|
{
|
||||||
|
if (_context != IntPtr.Zero)
|
||||||
|
{
|
||||||
|
GemmaDestroy(_context);
|
||||||
|
_context = IntPtr.Zero;
|
||||||
|
}
|
||||||
|
if (_logCallbackHandle.IsAllocated)
|
||||||
|
_logCallbackHandle.Free();
|
||||||
|
_disposed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~Gemma()
|
||||||
|
{
|
||||||
|
Dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
#ifndef GEMMA_EXPORTS
|
||||||
|
#define GEMMA_EXPORTS
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "gemma/c_api.h"
|
||||||
|
|
||||||
|
// necessary as the C API and GemmaContext effectively wrap up and re-use the
|
||||||
|
// code for the Gemma executable
|
||||||
|
#include "util/app.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||||
|
const char* model_type,
|
||||||
|
const char* weights_path,
|
||||||
|
const char* weight_type, int max_length) {
|
||||||
|
try {
|
||||||
|
// kludge
|
||||||
|
gcpp::AppArgs app_args;
|
||||||
|
app_args.Init();
|
||||||
|
app_args.max_packages = 1;
|
||||||
|
app_args.verbosity = 0;
|
||||||
|
app_args.spin = gcpp::Tristate::kFalse;
|
||||||
|
|
||||||
|
return new GemmaContext(tokenizer_path, model_type, weights_path,
|
||||||
|
weight_type, app_args, max_length);
|
||||||
|
} catch (...) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
|
||||||
|
delete static_cast<gcpp::GemmaContext*>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||||
|
int max_length, GemmaTokenCallback callback,
|
||||||
|
void* user_data) {
|
||||||
|
if (!ctx) return -1;
|
||||||
|
return static_cast<gcpp::GemmaContext*>(ctx)->Generate(
|
||||||
|
prompt, output, max_length, callback, user_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
|
||||||
|
if (!ctx || !text) return -1;
|
||||||
|
return static_cast<gcpp::GemmaContext*>(ctx)->CountTokens(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
|
||||||
|
void* user_data) {
|
||||||
|
if (!ctx) return;
|
||||||
|
static_cast<gcpp::GemmaContext*>(ctx)->SetLogCallback(callback, user_data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_C_API_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_C_API_H_
|
||||||
|
|
||||||
|
#include "gemma/context.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#ifdef GEMMA_EXPORTS
|
||||||
|
#define GEMMA_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define GEMMA_API __declspec(dllimport)
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#define GEMMA_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
typedef gcpp::GemmaContext GemmaContext;
|
||||||
|
#else
|
||||||
|
typedef struct GemmaContext GemmaContext;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||||
|
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||||
|
|
||||||
|
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||||
|
const char* model_type,
|
||||||
|
const char* weights_path,
|
||||||
|
const char* weight_type, int max_length);
|
||||||
|
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
|
||||||
|
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||||
|
int max_length, GemmaTokenCallback callback,
|
||||||
|
void* user_data);
|
||||||
|
|
||||||
|
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);
|
||||||
|
|
||||||
|
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
|
||||||
|
void* user_data);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_C_API_H_
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
#include "gemma/context.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
void InitializeGemmaLibrary() {
|
||||||
|
AppArgs app;
|
||||||
|
app.Init();
|
||||||
|
app.max_packages = 1;
|
||||||
|
NestedPools pools = CreatePools(app);
|
||||||
|
Allocator::Init(pools.Topology());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize static members
|
||||||
|
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
|
||||||
|
void* GemmaContext::s_log_user_data = nullptr;
|
||||||
|
|
||||||
|
GemmaContext::GemmaContext(const char* tokenizer_path, const char* model_type,
|
||||||
|
const char* weights_path, const char* weight_type,
|
||||||
|
const AppArgs& app_args, int max_length)
|
||||||
|
: pools(CreatePools(app_args)) {
|
||||||
|
LoaderArgs loader(tokenizer_path, weights_path, model_type);
|
||||||
|
loader.weight_type_str = weight_type;
|
||||||
|
|
||||||
|
if (const char* error = loader.Validate()) {
|
||||||
|
HWY_ABORT("Invalid loader configuration: %s", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize cached args
|
||||||
|
inference_args.Init();
|
||||||
|
inference_args.max_generated_tokens = max_length;
|
||||||
|
inference_args.temperature = 0.7f;
|
||||||
|
inference_args.top_k = 1;
|
||||||
|
inference_args.deterministic = false;
|
||||||
|
|
||||||
|
Allocator::Init(pools.Topology());
|
||||||
|
model = AllocateGemma(loader, pools);
|
||||||
|
kv_cache =
|
||||||
|
std::make_unique<KVCache>(KVCache::Create(model->GetModelConfig(), 2048));
|
||||||
|
}
|
||||||
|
|
||||||
|
int GemmaContext::Generate(const char* prompt, char* output, int max_length,
|
||||||
|
GemmaTokenCallback callback, void* user_data) {
|
||||||
|
if (!model || !kv_cache || !prompt || !output || max_length <= 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Clear and reuse buffers
|
||||||
|
result_buffer.clear();
|
||||||
|
prompt_buffer.assign(prompt);
|
||||||
|
token_buffer.clear();
|
||||||
|
|
||||||
|
// The prompt is assumed to be already wrapped in the appropriate control
|
||||||
|
// tokens if necessary for an instruction tuned model, so we don't use
|
||||||
|
// WrapAndTokenize here
|
||||||
|
HWY_ASSERT(model->Tokenizer().Encode(prompt, &token_buffer));
|
||||||
|
|
||||||
|
// Both pre-trained and instruction-tuned require BOS as first token
|
||||||
|
if (token_buffer.at(0) != BOS_ID) {
|
||||||
|
token_buffer.insert(token_buffer.begin(), BOS_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass prompt_tokens to properly utilize KV cache for subsequent tokens
|
||||||
|
const size_t prompt_tokens = token_buffer.size();
|
||||||
|
size_t tokens_generated_this_turn = 0;
|
||||||
|
|
||||||
|
auto stream_token = [this, callback, user_data, prompt_tokens,
|
||||||
|
&tokens_generated_this_turn](int token, float) {
|
||||||
|
std::string token_text;
|
||||||
|
if (model->Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
|
||||||
|
// don't re-output the prompt tokens
|
||||||
|
if (tokens_generated_this_turn < prompt_tokens) {
|
||||||
|
++tokens_generated_this_turn;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// skip the end of turn token, this way we don't have to do string
|
||||||
|
// comparisons at the application level (is this a good idea?)
|
||||||
|
if (token == END_OF_TURN_ID) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (callback) {
|
||||||
|
if (!callback(token_text.c_str(), user_data)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result_buffer.append(token_text);
|
||||||
|
++tokens_generated_this_turn;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
RuntimeConfig runtime_config = {.gen = &gen,
|
||||||
|
.verbosity = 0,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.use_spinning = Tristate::kFalse};
|
||||||
|
inference_args.max_generated_tokens = max_length;
|
||||||
|
inference_args.CopyTo(runtime_config);
|
||||||
|
|
||||||
|
TimingInfo timing_info = {.verbosity = 0};
|
||||||
|
hwy::Span<const int> testspan(token_buffer.data(), token_buffer.size());
|
||||||
|
|
||||||
|
// Pass prompt_tokens to properly utilize KV cache for subsequent tokens
|
||||||
|
model->Generate(runtime_config, testspan, prompt_tokens, 0, *kv_cache,
|
||||||
|
timing_info);
|
||||||
|
|
||||||
|
if (result_buffer.length() >= static_cast<size_t>(max_length)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
strcpy(output, result_buffer.c_str());
|
||||||
|
return static_cast<int>(result_buffer.length());
|
||||||
|
} catch (...) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int GemmaContext::CountTokens(const char* text) {
|
||||||
|
if (!model || !text) return -1;
|
||||||
|
try {
|
||||||
|
std::string text_str(text);
|
||||||
|
std::vector<int> tokens;
|
||||||
|
HWY_ASSERT(model->Tokenizer().Encode(text_str, &tokens));
|
||||||
|
return static_cast<int>(tokens.size());
|
||||||
|
} catch (...) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <Windows.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/app.h"
|
||||||
|
#include "util/threading.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Initialize global state needed by the library.
|
||||||
|
// Must be called before creating any Gemma instances.
|
||||||
|
void InitializeGemmaLibrary();
|
||||||
|
|
||||||
|
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||||
|
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||||
|
|
||||||
|
class GemmaContext {
|
||||||
|
public:
|
||||||
|
GemmaContext(const char* tokenizer_path, const char* model_type,
|
||||||
|
const char* weights_path, const char* weight_type,
|
||||||
|
const AppArgs& app_args, int max_length = 2048);
|
||||||
|
|
||||||
|
// Returns length of generated text, or -1 on error
|
||||||
|
int Generate(const char* prompt, char* output, int max_length,
|
||||||
|
GemmaTokenCallback callback, void* user_data);
|
||||||
|
|
||||||
|
// Returns number of tokens in text, or -1 on error
|
||||||
|
int CountTokens(const char* text);
|
||||||
|
|
||||||
|
// Add new method to set logger
|
||||||
|
static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
|
||||||
|
s_log_callback = callback;
|
||||||
|
s_log_user_data = user_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
NestedPools pools;
|
||||||
|
std::unique_ptr<Gemma> model;
|
||||||
|
std::unique_ptr<KVCache> kv_cache;
|
||||||
|
std::string prompt_buffer;
|
||||||
|
std::string result_buffer;
|
||||||
|
std::vector<int> token_buffer;
|
||||||
|
|
||||||
|
// Cached args
|
||||||
|
InferenceArgs inference_args;
|
||||||
|
AppArgs app_args;
|
||||||
|
std::mt19937 gen;
|
||||||
|
|
||||||
|
// Add static members for logging
|
||||||
|
static GemmaLogCallback s_log_callback;
|
||||||
|
static void* s_log_user_data;
|
||||||
|
|
||||||
|
// Use logging helper method to print messages into a managed callback if
|
||||||
|
// necessary
|
||||||
|
static void LogDebug(const char* message) {
|
||||||
|
if (s_log_callback) {
|
||||||
|
s_log_callback(message, s_log_user_data);
|
||||||
|
} else {
|
||||||
|
#ifdef _WIN32
|
||||||
|
OutputDebugStringA(message);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||||
|
|
@ -31,6 +31,9 @@ namespace gcpp {
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
constexpr int BOS_ID = 2;
|
constexpr int BOS_ID = 2;
|
||||||
|
|
||||||
|
// The tokenizer's end of turn token id.
|
||||||
|
constexpr int END_OF_TURN_ID = 107;
|
||||||
|
|
||||||
class GemmaTokenizer {
|
class GemmaTokenizer {
|
||||||
public:
|
public:
|
||||||
GemmaTokenizer();
|
GemmaTokenizer();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue