mirror of https://github.com/google/gemma.cpp.git
cleanup, new conversation methods, bugfixes
- chore: unused parameters cleaned up - bugfix: explicitly use hwy::Span in GenerateInternal() to prevent runtime crashes due to memory layout incompatibility - bugfix: explicit nullptr check in LogDebug - chore: length-related parameters renamed for clarity - feature: SaveConversation() can be optionally used to save copy of a conversation that ResetConversation() will rewind to upon request, rather than just an empty KV cache - feature: GetCurrentConversation() can be used to query the current conversation's name PiperOrigin-RevId: 755873147
This commit is contained in:
parent
e9ecb7794d
commit
20757046db
|
|
@ -1,3 +1,18 @@
|
||||||
|
// Copyright 2025 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Diagnostics;
|
using System.Diagnostics;
|
||||||
using System.Runtime.InteropServices;
|
using System.Runtime.InteropServices;
|
||||||
|
|
@ -35,7 +50,7 @@ namespace GemmaCpp
|
||||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
|
||||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
|
||||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
|
||||||
int maxLength);
|
int maxGeneratedTokens);
|
||||||
|
|
||||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||||
private static extern void GemmaDestroy(IntPtr context);
|
private static extern void GemmaDestroy(IntPtr context);
|
||||||
|
|
@ -56,7 +71,7 @@ namespace GemmaCpp
|
||||||
IntPtr context,
|
IntPtr context,
|
||||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||||
[Out] byte[] output,
|
[Out] byte[] output,
|
||||||
int maxLength,
|
int maxOutputChars,
|
||||||
GemmaTokenCallback callback,
|
GemmaTokenCallback callback,
|
||||||
IntPtr userData);
|
IntPtr userData);
|
||||||
|
|
||||||
|
|
@ -68,7 +83,7 @@ namespace GemmaCpp
|
||||||
int image_width, // Added dimension
|
int image_width, // Added dimension
|
||||||
int image_height, // Added dimension
|
int image_height, // Added dimension
|
||||||
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
|
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
|
||||||
int maxLength,
|
int maxOutputChars,
|
||||||
GemmaTokenCallback callback,
|
GemmaTokenCallback callback,
|
||||||
IntPtr userData);
|
IntPtr userData);
|
||||||
|
|
||||||
|
|
@ -120,6 +135,13 @@ namespace GemmaCpp
|
||||||
IntPtr context,
|
IntPtr context,
|
||||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
|
||||||
|
[return: MarshalAs(UnmanagedType.LPUTF8Str)] // Marshal the const char* return value as a string
|
||||||
|
private static extern string GemmaGetCurrentConversation(IntPtr context);
|
||||||
|
|
||||||
|
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
|
||||||
|
private static extern void GemmaSaveConversation(IntPtr context);
|
||||||
|
|
||||||
// Native callback delegate type
|
// Native callback delegate type
|
||||||
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||||
private delegate void GemmaLogCallback(
|
private delegate void GemmaLogCallback(
|
||||||
|
|
@ -135,9 +157,9 @@ namespace GemmaCpp
|
||||||
private GCHandle _logCallbackHandle;
|
private GCHandle _logCallbackHandle;
|
||||||
private bool _loggingEnabled = false;
|
private bool _loggingEnabled = false;
|
||||||
|
|
||||||
public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
|
public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
|
||||||
{
|
{
|
||||||
_context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
|
_context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
|
||||||
if (_context == IntPtr.Zero)
|
if (_context == IntPtr.Zero)
|
||||||
{
|
{
|
||||||
throw new GemmaException("Failed to create Gemma context");
|
throw new GemmaException("Failed to create Gemma context");
|
||||||
|
|
@ -281,6 +303,31 @@ namespace GemmaCpp
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public string GetCurrentConversation()
|
||||||
|
{
|
||||||
|
if (_disposed)
|
||||||
|
throw new ObjectDisposedException(nameof(Gemma));
|
||||||
|
|
||||||
|
if (_context == IntPtr.Zero)
|
||||||
|
throw new GemmaException("Gemma context is invalid");
|
||||||
|
|
||||||
|
string currentConversation = GemmaGetCurrentConversation(_context); // Call P/Invoke method
|
||||||
|
Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
|
||||||
|
return currentConversation;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void SaveConversation()
|
||||||
|
{
|
||||||
|
if (_disposed)
|
||||||
|
throw new ObjectDisposedException(nameof(Gemma));
|
||||||
|
|
||||||
|
if (_context == IntPtr.Zero)
|
||||||
|
throw new GemmaException("Gemma context is invalid");
|
||||||
|
|
||||||
|
GemmaSaveConversation(_context);
|
||||||
|
Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
|
||||||
|
}
|
||||||
|
|
||||||
public int CountTokens(string prompt)
|
public int CountTokens(string prompt)
|
||||||
{
|
{
|
||||||
if (_disposed)
|
if (_disposed)
|
||||||
|
|
@ -292,12 +339,12 @@ namespace GemmaCpp
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
public string Generate(string prompt, int maxLength = 4096)
|
public string Generate(string prompt, int maxOutputChars = 4096)
|
||||||
{
|
{
|
||||||
return Generate(prompt, null, maxLength);
|
return Generate(prompt, null, maxOutputChars);
|
||||||
}
|
}
|
||||||
|
|
||||||
public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
|
public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
|
||||||
{
|
{
|
||||||
if (_disposed)
|
if (_disposed)
|
||||||
throw new ObjectDisposedException(nameof(Gemma));
|
throw new ObjectDisposedException(nameof(Gemma));
|
||||||
|
|
@ -305,7 +352,7 @@ namespace GemmaCpp
|
||||||
if (_context == IntPtr.Zero)
|
if (_context == IntPtr.Zero)
|
||||||
throw new GemmaException("Gemma context is invalid");
|
throw new GemmaException("Gemma context is invalid");
|
||||||
|
|
||||||
var outputBuffer = new byte[maxLength * 4]; // Allow for worst case UTF-8 size
|
var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size
|
||||||
GemmaTokenCallback nativeCallback = null;
|
GemmaTokenCallback nativeCallback = null;
|
||||||
|
|
||||||
// Track token count for debugging
|
// Track token count for debugging
|
||||||
|
|
@ -327,7 +374,7 @@ namespace GemmaCpp
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
int length = GemmaGenerate(_context, prompt, outputBuffer, maxLength,
|
int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
|
||||||
nativeCallback, IntPtr.Zero);
|
nativeCallback, IntPtr.Zero);
|
||||||
|
|
||||||
if (length < 0)
|
if (length < 0)
|
||||||
|
|
@ -346,13 +393,13 @@ namespace GemmaCpp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxLength = 4096)
|
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096)
|
||||||
{
|
{
|
||||||
// Pass width and height to the overloaded method
|
// Pass width and height to the overloaded method
|
||||||
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxLength);
|
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
|
||||||
}
|
}
|
||||||
|
|
||||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxLength = 4096)
|
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
|
||||||
{
|
{
|
||||||
if (_disposed)
|
if (_disposed)
|
||||||
throw new ObjectDisposedException(nameof(Gemma));
|
throw new ObjectDisposedException(nameof(Gemma));
|
||||||
|
|
@ -369,7 +416,7 @@ namespace GemmaCpp
|
||||||
if (imageData.Length < imageWidth * imageHeight * 3)
|
if (imageData.Length < imageWidth * imageHeight * 3)
|
||||||
throw new ArgumentException("Image data array is too small for the specified dimensions");
|
throw new ArgumentException("Image data array is too small for the specified dimensions");
|
||||||
|
|
||||||
var output = new StringBuilder(maxLength);
|
var output = new StringBuilder(maxOutputChars);
|
||||||
GemmaTokenCallback nativeCallback = null;
|
GemmaTokenCallback nativeCallback = null;
|
||||||
|
|
||||||
if (callback != null)
|
if (callback != null)
|
||||||
|
|
@ -386,7 +433,7 @@ namespace GemmaCpp
|
||||||
IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
|
IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
|
||||||
|
|
||||||
// Pass image dimensions to the native call
|
// Pass image dimensions to the native call
|
||||||
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxLength,
|
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
|
||||||
nativeCallback, IntPtr.Zero);
|
nativeCallback, IntPtr.Zero);
|
||||||
|
|
||||||
if (length < 0)
|
if (length < 0)
|
||||||
|
|
|
||||||
|
|
@ -22,37 +22,38 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||||
const char* model_type,
|
|
||||||
const char* weights_path,
|
const char* weights_path,
|
||||||
const char* weight_type, int max_length) {
|
int max_generated_tokens) {
|
||||||
try {
|
try {
|
||||||
GemmaContext* ctx = GemmaContext::Create(
|
GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path,
|
||||||
tokenizer_path, model_type, weights_path, weight_type, max_length);
|
max_generated_tokens);
|
||||||
return ctx;
|
return ctx;
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; }
|
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
|
||||||
|
delete ctx;
|
||||||
|
}
|
||||||
|
|
||||||
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||||
int max_length, GemmaTokenCallback callback,
|
int max_output_chars, GemmaTokenCallback callback,
|
||||||
void* user_data) {
|
void* user_data) {
|
||||||
if (!ctx) return -1;
|
if (!ctx) return -1;
|
||||||
return ctx->Generate(prompt, output, max_length, callback, user_data);
|
return ctx->Generate(prompt, output, max_output_chars, callback, user_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
||||||
const void* image_data, int image_width,
|
const void* image_data, int image_width,
|
||||||
int image_height, char* output,
|
int image_height, char* output,
|
||||||
int max_length,
|
int max_output_chars,
|
||||||
GemmaTokenCallback callback,
|
GemmaTokenCallback callback,
|
||||||
void* user_data) {
|
void* user_data) {
|
||||||
if (!ctx) return -1;
|
if (!ctx) return -1;
|
||||||
|
|
||||||
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
|
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
|
||||||
output, max_length, callback, user_data);
|
output, max_output_chars, callback, user_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
|
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
|
||||||
|
|
@ -125,4 +126,14 @@ GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
||||||
if (!ctx || !conversation_name) return 0;
|
if (!ctx || !conversation_name) return 0;
|
||||||
return ctx->HasConversation(conversation_name) ? 1 : 0;
|
return ctx->HasConversation(conversation_name) ? 1 : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
|
||||||
|
if (!ctx) return nullptr;
|
||||||
|
return ctx->GetCurrentConversation();
|
||||||
|
}
|
||||||
|
|
||||||
|
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
|
||||||
|
if (!ctx) return;
|
||||||
|
ctx->SaveConversation();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,18 +42,16 @@ typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||||
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||||
|
|
||||||
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||||
const char* model_type,
|
|
||||||
const char* weights_path,
|
const char* weights_path,
|
||||||
const char* weight_type, int max_length);
|
int max_generated_tokens);
|
||||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
|
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
|
||||||
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||||
int max_length, GemmaTokenCallback callback,
|
int max_output_chars, GemmaTokenCallback callback,
|
||||||
void* user_data);
|
void* user_data);
|
||||||
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
||||||
const void* image_data, // Renamed param
|
const void* image_data, int image_width,
|
||||||
int image_width, // Added dimension
|
int image_height, char* output,
|
||||||
int image_height, // Added dimension
|
int max_output_chars,
|
||||||
char* output, int max_length,
|
|
||||||
GemmaTokenCallback callback,
|
GemmaTokenCallback callback,
|
||||||
void* user_data);
|
void* user_data);
|
||||||
|
|
||||||
|
|
@ -67,17 +65,19 @@ GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
|
||||||
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
|
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
|
||||||
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
|
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
|
||||||
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
|
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
|
||||||
GEMMA_API void GemmaResetConversation(GemmaContext* ctx); // Renamed
|
GEMMA_API void GemmaResetConversation(GemmaContext* ctx);
|
||||||
|
|
||||||
// Conversation management functions (renamed)
|
// Conversation management functions (renamed)
|
||||||
GEMMA_API int GemmaCreateConversation(
|
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
|
||||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
const char* conversation_name);
|
||||||
GEMMA_API int GemmaSwitchConversation(
|
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
|
||||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
const char* conversation_name);
|
||||||
GEMMA_API int GemmaDeleteConversation(
|
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
|
||||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
const char* conversation_name);
|
||||||
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
||||||
const char* conversation_name); // Renamed
|
const char* conversation_name);
|
||||||
|
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
|
||||||
|
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -44,23 +44,36 @@ namespace gcpp {
|
||||||
// ConversationData constructor implementation
|
// ConversationData constructor implementation
|
||||||
ConversationData::ConversationData(const ModelConfig& model_config,
|
ConversationData::ConversationData(const ModelConfig& model_config,
|
||||||
size_t prefill_tbatch_size)
|
size_t prefill_tbatch_size)
|
||||||
: kv_cache(std::make_unique<KVCache>(
|
: model_config_ref_(model_config),
|
||||||
|
prefill_tbatch_size_(prefill_tbatch_size),
|
||||||
|
kv_cache(std::make_unique<KVCache>(
|
||||||
KVCache::Create(model_config, prefill_tbatch_size))),
|
KVCache::Create(model_config, prefill_tbatch_size))),
|
||||||
abs_pos(0) {}
|
abs_pos(0) {}
|
||||||
|
|
||||||
|
// ConversationData copy constructor implementation
|
||||||
|
ConversationData::ConversationData(const ConversationData& other)
|
||||||
|
: model_config_ref_(other.model_config_ref_),
|
||||||
|
prefill_tbatch_size_(other.prefill_tbatch_size_),
|
||||||
|
kv_cache(nullptr),
|
||||||
|
abs_pos(other.abs_pos) {
|
||||||
|
if (other.kv_cache) {
|
||||||
|
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
|
||||||
|
other.model_config_ref_, other.prefill_tbatch_size_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize static members
|
// Initialize static members
|
||||||
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
|
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
|
||||||
void* GemmaContext::s_log_user_data = nullptr;
|
void* GemmaContext::s_log_user_data = nullptr;
|
||||||
|
|
||||||
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
||||||
const char* ignored1,
|
|
||||||
const char* weights_path,
|
const char* weights_path,
|
||||||
const char* ignored2, int max_length) {
|
int max_generated_tokens) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "Creating GemmaContext with tokenizer_path: "
|
ss << "Creating GemmaContext with tokenizer_path: "
|
||||||
<< (tokenizer_path ? tokenizer_path : "null")
|
<< (tokenizer_path ? tokenizer_path : "null")
|
||||||
<< ", weights_path: " << (weights_path ? weights_path : "null")
|
<< ", weights_path: " << (weights_path ? weights_path : "null")
|
||||||
<< ", max_length: " << max_length;
|
<< ", max_generated_tokens: " << max_generated_tokens;
|
||||||
LogDebug(ss.str().c_str());
|
LogDebug(ss.str().c_str());
|
||||||
|
|
||||||
ThreadingArgs threading_args;
|
ThreadingArgs threading_args;
|
||||||
|
|
@ -73,27 +86,30 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
||||||
LogDebug("Initializing inference args");
|
LogDebug("Initializing inference args");
|
||||||
InferenceArgs inference_args;
|
InferenceArgs inference_args;
|
||||||
inference_args.Init();
|
inference_args.Init();
|
||||||
inference_args.max_generated_tokens = max_length;
|
inference_args.max_generated_tokens = max_generated_tokens;
|
||||||
inference_args.temperature = 0.7f;
|
inference_args.temperature = 0.7f;
|
||||||
inference_args.top_k = 1;
|
inference_args.top_k = 1;
|
||||||
inference_args.deterministic = false;
|
inference_args.deterministic = false;
|
||||||
|
|
||||||
ss.str("");
|
ss.str("");
|
||||||
ss << "Inference args initialized with max_tokens: " << max_length
|
ss << "Inference args initialized with max_tokens: " << max_generated_tokens
|
||||||
<< ", temperature: " << inference_args.temperature
|
<< ", temperature: " << inference_args.temperature
|
||||||
<< ", top_k: " << inference_args.top_k << ", deterministic: "
|
<< ", top_k: " << inference_args.top_k << ", deterministic: "
|
||||||
<< (inference_args.deterministic ? "true" : "false");
|
<< (inference_args.deterministic ? "true" : "false");
|
||||||
LogDebug(ss.str().c_str());
|
LogDebug(ss.str().c_str());
|
||||||
|
|
||||||
return new GemmaContext(loader, inference_args, threading_args, max_length);
|
return new GemmaContext(loader, inference_args, threading_args,
|
||||||
|
max_generated_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
GemmaContext::GemmaContext(const LoaderArgs& loader,
|
GemmaContext::GemmaContext(const LoaderArgs& loader,
|
||||||
const InferenceArgs& inference_args,
|
const InferenceArgs& inference_args,
|
||||||
const ThreadingArgs& threading_args, int max_length)
|
const ThreadingArgs& threading_args,
|
||||||
|
int max_generated_tokens)
|
||||||
: inference_args(inference_args),
|
: inference_args(inference_args),
|
||||||
threading_args(threading_args),
|
threading_args(threading_args),
|
||||||
matmul_env(MakeMatMulEnv(threading_args)),
|
matmul_env(MakeMatMulEnv(threading_args)),
|
||||||
|
active_conversation_name("default"),
|
||||||
model(loader, matmul_env) {
|
model(loader, matmul_env) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
|
||||||
|
|
@ -114,7 +130,8 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
||||||
int GemmaContext::GenerateInternal(const char* prompt_string,
|
int GemmaContext::GenerateInternal(const char* prompt_string,
|
||||||
const void* image_data, int image_width,
|
const void* image_data, int image_width,
|
||||||
int image_height, char* output,
|
int image_height, char* output,
|
||||||
int max_length, GemmaTokenCallback callback,
|
int max_output_chars,
|
||||||
|
GemmaTokenCallback callback,
|
||||||
void* user_data) {
|
void* user_data) {
|
||||||
PROFILER_ZONE("Gen.Internal");
|
PROFILER_ZONE("Gen.Internal");
|
||||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||||
|
|
@ -224,8 +241,12 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a span from the prompt vector - Generate() expects a hwy::Span,
|
||||||
|
// which has a different memory footprint to that of a std::vector.
|
||||||
|
hwy::Span<const int> prompt_span(prompt.data(), prompt.size());
|
||||||
|
|
||||||
// Pass the KVCache object by reference from the active conversation
|
// Pass the KVCache object by reference from the active conversation
|
||||||
model.Generate(runtime_config, prompt, active_conversation->abs_pos,
|
model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
|
||||||
prefix_end, *(active_conversation->kv_cache), timing_info);
|
prefix_end, *(active_conversation->kv_cache), timing_info);
|
||||||
|
|
||||||
// prepare for next turn
|
// prepare for next turn
|
||||||
|
|
@ -251,26 +272,26 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy result buffer to output C-string (ensure null termination)
|
// Copy result buffer to output C-string (ensure null termination)
|
||||||
strncpy(output, result_buffer.c_str(), max_length - 1);
|
strncpy(output, result_buffer.c_str(), max_output_chars - 1);
|
||||||
output[max_length - 1] = '\0'; // Explicit null termination
|
output[max_output_chars - 1] = '\0';
|
||||||
|
|
||||||
return static_cast<int>(strlen(output)); // Return length of the C-string
|
return static_cast<int>(strlen(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Public Generate method (wrapper for text-only)
|
// Public Generate method (wrapper for text-only)
|
||||||
int GemmaContext::Generate(const char* prompt_string, char* output,
|
int GemmaContext::Generate(const char* prompt_string, char* output,
|
||||||
int max_length, GemmaTokenCallback callback,
|
int max_output_chars, GemmaTokenCallback callback,
|
||||||
void* user_data) {
|
void* user_data) {
|
||||||
// Call the internal implementation with null image_data and 0 dimensions
|
// Call the internal implementation with null image_data and 0 dimensions
|
||||||
return GenerateInternal(prompt_string, nullptr, 0, 0, output, max_length,
|
return GenerateInternal(prompt_string, nullptr, 0, 0, output,
|
||||||
callback, user_data);
|
max_output_chars, callback, user_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Public GenerateMultimodal method (wrapper)
|
// Public GenerateMultimodal method (wrapper)
|
||||||
int GemmaContext::GenerateMultimodal(const char* prompt_string,
|
int GemmaContext::GenerateMultimodal(const char* prompt_string,
|
||||||
const void* image_data, int image_width,
|
const void* image_data, int image_width,
|
||||||
int image_height, // Added dimensions
|
int image_height, char* output,
|
||||||
char* output, int max_length,
|
int max_output_chars,
|
||||||
GemmaTokenCallback callback,
|
GemmaTokenCallback callback,
|
||||||
void* user_data) {
|
void* user_data) {
|
||||||
if (image_data == nullptr) {
|
if (image_data == nullptr) {
|
||||||
|
|
@ -283,7 +304,7 @@ int GemmaContext::GenerateMultimodal(const char* prompt_string,
|
||||||
}
|
}
|
||||||
|
|
||||||
return GenerateInternal(prompt_string, image_data, image_width, image_height,
|
return GenerateInternal(prompt_string, image_data, image_width, image_height,
|
||||||
output, max_length, callback, user_data);
|
output, max_output_chars, callback, user_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
int GemmaContext::CountTokens(const char* text) {
|
int GemmaContext::CountTokens(const char* text) {
|
||||||
|
|
@ -320,4 +341,9 @@ int GemmaContext::CountTokens(const char* text) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the name of the currently active conversation
|
||||||
|
const char* GemmaContext::GetCurrentConversation() {
|
||||||
|
return active_conversation_name.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -43,12 +43,17 @@ struct KVCache;
|
||||||
|
|
||||||
// Struct to hold data for a single conversation thread
|
// Struct to hold data for a single conversation thread
|
||||||
struct ConversationData {
|
struct ConversationData {
|
||||||
|
public:
|
||||||
|
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
|
||||||
|
ConversationData(const ConversationData& other);
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ModelConfig& model_config_ref_;
|
||||||
|
size_t prefill_tbatch_size_;
|
||||||
|
|
||||||
|
public:
|
||||||
std::unique_ptr<KVCache> kv_cache;
|
std::unique_ptr<KVCache> kv_cache;
|
||||||
size_t abs_pos = 0;
|
size_t abs_pos = 0;
|
||||||
|
|
||||||
// Constructor to initialize kv_cache (requires KVCache definition or forward
|
|
||||||
// declaration)
|
|
||||||
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||||
|
|
@ -57,20 +62,20 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||||
class GemmaContext {
|
class GemmaContext {
|
||||||
private:
|
private:
|
||||||
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
|
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
|
||||||
const ThreadingArgs& threading_args, int max_length);
|
const ThreadingArgs& threading_args, int max_generated_tokens);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static GemmaContext* Create(const char* tokenizer_path, const char* ignored1,
|
static GemmaContext* Create(const char* tokenizer_path,
|
||||||
const char* weights_path, const char* ignored2,
|
const char* weights_path,
|
||||||
int max_length);
|
int max_generated_tokens);
|
||||||
|
|
||||||
// Returns length of generated text, or -1 on error
|
// Returns length of generated text, or -1 on error
|
||||||
int Generate(const char* prompt_string, char* output, int max_length,
|
int Generate(const char* prompt_string, char* output, int max_output_chars,
|
||||||
GemmaTokenCallback callback, void* user_data);
|
GemmaTokenCallback callback, void* user_data);
|
||||||
// Returns length of generated text, or -1 on error
|
// Returns length of generated text, or -1 on error
|
||||||
int GenerateMultimodal(const char* prompt_string, const void* image_data,
|
int GenerateMultimodal(const char* prompt_string, const void* image_data,
|
||||||
int image_width, int image_height, char* output,
|
int image_width, int image_height, char* output,
|
||||||
int max_length, GemmaTokenCallback callback,
|
int max_output_chars, GemmaTokenCallback callback,
|
||||||
void* user_data);
|
void* user_data);
|
||||||
|
|
||||||
// Returns number of tokens in text, or -1 on error
|
// Returns number of tokens in text, or -1 on error
|
||||||
|
|
@ -122,15 +127,71 @@ class GemmaContext {
|
||||||
LogDebug("Setting prefill_tbatch_size to configured value");
|
LogDebug("Setting prefill_tbatch_size to configured value");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SaveConversation() {
|
||||||
|
if (!active_conversation || active_conversation_name.empty()) {
|
||||||
|
if (!active_conversation) {
|
||||||
|
LogDebug("SaveConversation: No active conversation to save.");
|
||||||
|
} else { // active_conversation_name must be empty
|
||||||
|
LogDebug(
|
||||||
|
"SaveConversation: Active conversation name is empty. Cannot "
|
||||||
|
"save.");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string log_msg = "SaveConversation: Attempting to save '";
|
||||||
|
log_msg += active_conversation_name;
|
||||||
|
log_msg += "' to prewarmed_cache.";
|
||||||
|
LogDebug(log_msg.c_str());
|
||||||
|
|
||||||
|
// Create a deep copy of the active_conversation.
|
||||||
|
// The ConversationData copy constructor handles the deep copy of KVCache.
|
||||||
|
auto conversation_copy =
|
||||||
|
std::make_shared<ConversationData>(*active_conversation);
|
||||||
|
|
||||||
|
// Store the deep copy in prewarmed_cache.
|
||||||
|
// If a conversation with the same name already exists, it will be
|
||||||
|
// overwritten. std::shared_ptr will handle the destruction of the old
|
||||||
|
// object if it's being replaced.
|
||||||
|
prewarmed_cache[active_conversation_name] = conversation_copy;
|
||||||
|
|
||||||
|
log_msg = "SaveConversation: Successfully saved '";
|
||||||
|
log_msg += active_conversation_name;
|
||||||
|
log_msg += "' to prewarmed_cache.";
|
||||||
|
LogDebug(log_msg.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
// Reset the currently active conversation
|
// Reset the currently active conversation
|
||||||
void ResetConversation() {
|
void ResetConversation() {
|
||||||
if (active_conversation) {
|
if (active_conversation) {
|
||||||
LogDebug("Resetting active conversation");
|
std::string log_prefix = "ResetConversation ('";
|
||||||
|
log_prefix += active_conversation_name.empty() ? "[unnamed]"
|
||||||
|
: active_conversation_name;
|
||||||
|
log_prefix += "'): ";
|
||||||
|
LogDebug((log_prefix + "Attempting to reset.").c_str());
|
||||||
|
// Attempt to restore from prewarmed_cache first, regardless of name.
|
||||||
|
auto it = prewarmed_cache.find(active_conversation_name);
|
||||||
|
if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
|
||||||
|
// Found in prewarmed_cache and the cached entry is valid.
|
||||||
|
LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
|
||||||
|
.c_str());
|
||||||
|
active_conversation->abs_pos = it->second->abs_pos;
|
||||||
|
// Perform a deep copy of the KVCache from the prewarmed version.
|
||||||
|
active_conversation->kv_cache =
|
||||||
|
std::make_unique<KVCache>(it->second->kv_cache->Copy(
|
||||||
|
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
||||||
|
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
|
||||||
|
.c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not found in prewarmed_cache or prewarmed_cache entry is invalid,
|
||||||
|
// rewind to initial state.
|
||||||
active_conversation->abs_pos = 0;
|
active_conversation->abs_pos = 0;
|
||||||
// Replace the cache within the current ConversationData object
|
// Replace the cache within the current ConversationData object
|
||||||
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
|
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
|
||||||
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
||||||
LogDebug("Active conversation reset");
|
|
||||||
|
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
||||||
} else {
|
} else {
|
||||||
LogDebug("Cannot reset conversation: active_conversation is null");
|
LogDebug("Cannot reset conversation: active_conversation is null");
|
||||||
}
|
}
|
||||||
|
|
@ -160,6 +221,7 @@ class GemmaContext {
|
||||||
}
|
}
|
||||||
LogDebug("Switching active conversation");
|
LogDebug("Switching active conversation");
|
||||||
active_conversation = it->second;
|
active_conversation = it->second;
|
||||||
|
active_conversation_name = conversation_name;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -183,6 +245,12 @@ class GemmaContext {
|
||||||
|
|
||||||
LogDebug("Deleting conversation");
|
LogDebug("Deleting conversation");
|
||||||
conversation_cache.erase(it);
|
conversation_cache.erase(it);
|
||||||
|
|
||||||
|
auto it2 = prewarmed_cache.find(name);
|
||||||
|
if (it2 != prewarmed_cache.end()) {
|
||||||
|
prewarmed_cache.erase(it2);
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -192,13 +260,16 @@ class GemmaContext {
|
||||||
return conversation_cache.count(name);
|
return conversation_cache.count(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the name of the currently active conversation
|
||||||
|
const char* GetCurrentConversation();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Internal implementation shared by Generate and GenerateMultimodal
|
// Internal implementation shared by Generate and GenerateMultimodal
|
||||||
int GenerateInternal(const char* prompt_string,
|
int GenerateInternal(const char* prompt_string,
|
||||||
const void* image_data, // Null for text-only generation
|
const void* image_data, // Null for text-only generation
|
||||||
int image_width, // Added dimension (0 if no image)
|
int image_width,
|
||||||
int image_height, // Added dimension (0 if no image)
|
int image_height,
|
||||||
char* output, int max_length,
|
char* output, int max_output_chars,
|
||||||
GemmaTokenCallback callback, void* user_data);
|
GemmaTokenCallback callback, void* user_data);
|
||||||
|
|
||||||
// Pointer to the currently active conversation's data
|
// Pointer to the currently active conversation's data
|
||||||
|
|
@ -207,6 +278,8 @@ class GemmaContext {
|
||||||
// Cache of all named conversations
|
// Cache of all named conversations
|
||||||
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
|
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
|
||||||
conversation_cache;
|
conversation_cache;
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
|
||||||
|
prewarmed_cache;
|
||||||
|
|
||||||
// Buffers (potentially could be moved into ConversationData if needed
|
// Buffers (potentially could be moved into ConversationData if needed
|
||||||
// per-conversation)
|
// per-conversation)
|
||||||
|
|
@ -219,6 +292,8 @@ class GemmaContext {
|
||||||
ThreadingArgs threading_args;
|
ThreadingArgs threading_args;
|
||||||
MatMulEnv matmul_env;
|
MatMulEnv matmul_env;
|
||||||
|
|
||||||
|
std::string active_conversation_name;
|
||||||
|
|
||||||
// Model itself (don't move this, needs to be below the args above)
|
// Model itself (don't move this, needs to be below the args above)
|
||||||
Gemma model;
|
Gemma model;
|
||||||
|
|
||||||
|
|
@ -232,7 +307,7 @@ class GemmaContext {
|
||||||
// Use logging helper method to print messages into a managed callback if
|
// Use logging helper method to print messages into a managed callback if
|
||||||
// necessary
|
// necessary
|
||||||
static void LogDebug(const char* message) {
|
static void LogDebug(const char* message) {
|
||||||
if (s_log_callback) {
|
if (s_log_callback != nullptr) {
|
||||||
s_log_callback(message, s_log_user_data);
|
s_log_callback(message, s_log_user_data);
|
||||||
} else {
|
} else {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue