cleanup, new conversation methods, bugfixes

- chore: unused parameters cleaned up
- bugfix: explicitly use hwy::Span in GenerateInternal() to prevent runtime crashes due to memory layout incompatibility
- bugfix: explicit nullptr check in LogDebug
- chore: length-related parameters renamed for clarity
- feature: SaveConversation() can be optionally used to save copy of a conversation that ResetConversation() will rewind to upon request, rather than just an empty KV cache
- feature: GetCurrentConversation() can be used to query the current conversation's name

PiperOrigin-RevId: 755873147
This commit is contained in:
The gemma.cpp Authors 2025-05-07 08:52:04 -07:00 committed by Copybara-Service
parent e9ecb7794d
commit 20757046db
5 changed files with 233 additions and 74 deletions

View File

@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
@ -35,7 +50,7 @@ namespace GemmaCpp
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType, [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath, [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType, [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
int maxLength); int maxGeneratedTokens);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)] [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
private static extern void GemmaDestroy(IntPtr context); private static extern void GemmaDestroy(IntPtr context);
@ -56,7 +71,7 @@ namespace GemmaCpp
IntPtr context, IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt, [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
[Out] byte[] output, [Out] byte[] output,
int maxLength, int maxOutputChars,
GemmaTokenCallback callback, GemmaTokenCallback callback,
IntPtr userData); IntPtr userData);
@ -68,7 +83,7 @@ namespace GemmaCpp
int image_width, // Added dimension int image_width, // Added dimension
int image_height, // Added dimension int image_height, // Added dimension
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
int maxLength, int maxOutputChars,
GemmaTokenCallback callback, GemmaTokenCallback callback,
IntPtr userData); IntPtr userData);
@ -120,6 +135,13 @@ namespace GemmaCpp
IntPtr context, IntPtr context,
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName); [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
[return: MarshalAs(UnmanagedType.LPUTF8Str)] // Marshal the const char* return value as a string
private static extern string GemmaGetCurrentConversation(IntPtr context);
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
private static extern void GemmaSaveConversation(IntPtr context);
// Native callback delegate type // Native callback delegate type
[UnmanagedFunctionPointer(CallingConvention.Cdecl)] [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate void GemmaLogCallback( private delegate void GemmaLogCallback(
@ -135,9 +157,9 @@ namespace GemmaCpp
private GCHandle _logCallbackHandle; private GCHandle _logCallbackHandle;
private bool _loggingEnabled = false; private bool _loggingEnabled = false;
public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192) public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
{ {
_context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength); _context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
if (_context == IntPtr.Zero) if (_context == IntPtr.Zero)
{ {
throw new GemmaException("Failed to create Gemma context"); throw new GemmaException("Failed to create Gemma context");
@ -281,6 +303,31 @@ namespace GemmaCpp
return result; return result;
} }
public string GetCurrentConversation()
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
string currentConversation = GemmaGetCurrentConversation(_context); // Call P/Invoke method
Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
return currentConversation;
}
public void SaveConversation()
{
if (_disposed)
throw new ObjectDisposedException(nameof(Gemma));
if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid");
GemmaSaveConversation(_context);
Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
}
public int CountTokens(string prompt) public int CountTokens(string prompt)
{ {
if (_disposed) if (_disposed)
@ -292,12 +339,12 @@ namespace GemmaCpp
return count; return count;
} }
public string Generate(string prompt, int maxLength = 4096) public string Generate(string prompt, int maxOutputChars = 4096)
{ {
return Generate(prompt, null, maxLength); return Generate(prompt, null, maxOutputChars);
} }
public string Generate(string prompt, TokenCallback callback, int maxLength = 4096) public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
{ {
if (_disposed) if (_disposed)
throw new ObjectDisposedException(nameof(Gemma)); throw new ObjectDisposedException(nameof(Gemma));
@ -305,7 +352,7 @@ namespace GemmaCpp
if (_context == IntPtr.Zero) if (_context == IntPtr.Zero)
throw new GemmaException("Gemma context is invalid"); throw new GemmaException("Gemma context is invalid");
var outputBuffer = new byte[maxLength * 4]; // Allow for worst case UTF-8 size var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size
GemmaTokenCallback nativeCallback = null; GemmaTokenCallback nativeCallback = null;
// Track token count for debugging // Track token count for debugging
@ -327,7 +374,7 @@ namespace GemmaCpp
try try
{ {
int length = GemmaGenerate(_context, prompt, outputBuffer, maxLength, int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
nativeCallback, IntPtr.Zero); nativeCallback, IntPtr.Zero);
if (length < 0) if (length < 0)
@ -346,13 +393,13 @@ namespace GemmaCpp
} }
} }
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxLength = 4096) public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096)
{ {
// Pass width and height to the overloaded method // Pass width and height to the overloaded method
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxLength); return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
} }
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxLength = 4096) public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
{ {
if (_disposed) if (_disposed)
throw new ObjectDisposedException(nameof(Gemma)); throw new ObjectDisposedException(nameof(Gemma));
@ -369,7 +416,7 @@ namespace GemmaCpp
if (imageData.Length < imageWidth * imageHeight * 3) if (imageData.Length < imageWidth * imageHeight * 3)
throw new ArgumentException("Image data array is too small for the specified dimensions"); throw new ArgumentException("Image data array is too small for the specified dimensions");
var output = new StringBuilder(maxLength); var output = new StringBuilder(maxOutputChars);
GemmaTokenCallback nativeCallback = null; GemmaTokenCallback nativeCallback = null;
if (callback != null) if (callback != null)
@ -386,7 +433,7 @@ namespace GemmaCpp
IntPtr imagePtr = imageHandle.AddrOfPinnedObject(); IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
// Pass image dimensions to the native call // Pass image dimensions to the native call
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxLength, int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
nativeCallback, IntPtr.Zero); nativeCallback, IntPtr.Zero);
if (length < 0) if (length < 0)

View File

@ -22,37 +22,38 @@
extern "C" { extern "C" {
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path, GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
const char* model_type,
const char* weights_path, const char* weights_path,
const char* weight_type, int max_length) { int max_generated_tokens) {
try { try {
GemmaContext* ctx = GemmaContext::Create( GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path,
tokenizer_path, model_type, weights_path, weight_type, max_length); max_generated_tokens);
return ctx; return ctx;
} catch (...) { } catch (...) {
return nullptr; return nullptr;
} }
} }
GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; } GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
delete ctx;
}
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output, GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
int max_length, GemmaTokenCallback callback, int max_output_chars, GemmaTokenCallback callback,
void* user_data) { void* user_data) {
if (!ctx) return -1; if (!ctx) return -1;
return ctx->Generate(prompt, output, max_length, callback, user_data); return ctx->Generate(prompt, output, max_output_chars, callback, user_data);
} }
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt, GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
const void* image_data, int image_width, const void* image_data, int image_width,
int image_height, char* output, int image_height, char* output,
int max_length, int max_output_chars,
GemmaTokenCallback callback, GemmaTokenCallback callback,
void* user_data) { void* user_data) {
if (!ctx) return -1; if (!ctx) return -1;
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height, return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
output, max_length, callback, user_data); output, max_output_chars, callback, user_data);
} }
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) { GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
@ -125,4 +126,14 @@ GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
if (!ctx || !conversation_name) return 0; if (!ctx || !conversation_name) return 0;
return ctx->HasConversation(conversation_name) ? 1 : 0; return ctx->HasConversation(conversation_name) ? 1 : 0;
} }
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
if (!ctx) return nullptr;
return ctx->GetCurrentConversation();
}
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
if (!ctx) return;
ctx->SaveConversation();
}
} }

View File

@ -42,18 +42,16 @@ typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
typedef void (*GemmaLogCallback)(const char* message, void* user_data); typedef void (*GemmaLogCallback)(const char* message, void* user_data);
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path, GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
const char* model_type,
const char* weights_path, const char* weights_path,
const char* weight_type, int max_length); int max_generated_tokens);
GEMMA_API void GemmaDestroy(GemmaContext* ctx); GEMMA_API void GemmaDestroy(GemmaContext* ctx);
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output, GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
int max_length, GemmaTokenCallback callback, int max_output_chars, GemmaTokenCallback callback,
void* user_data); void* user_data);
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt, GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
const void* image_data, // Renamed param const void* image_data, int image_width,
int image_width, // Added dimension int image_height, char* output,
int image_height, // Added dimension int max_output_chars,
char* output, int max_length,
GemmaTokenCallback callback, GemmaTokenCallback callback,
void* user_data); void* user_data);
@ -67,17 +65,19 @@ GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value); GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value); GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value); GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
GEMMA_API void GemmaResetConversation(GemmaContext* ctx); // Renamed GEMMA_API void GemmaResetConversation(GemmaContext* ctx);
// Conversation management functions (renamed) // Conversation management functions (renamed)
GEMMA_API int GemmaCreateConversation( GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
GemmaContext* ctx, const char* conversation_name); // Renamed const char* conversation_name);
GEMMA_API int GemmaSwitchConversation( GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
GemmaContext* ctx, const char* conversation_name); // Renamed const char* conversation_name);
GEMMA_API int GemmaDeleteConversation( GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
GemmaContext* ctx, const char* conversation_name); // Renamed const char* conversation_name);
GEMMA_API int GemmaHasConversation(GemmaContext* ctx, GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
const char* conversation_name); // Renamed const char* conversation_name);
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -44,23 +44,36 @@ namespace gcpp {
// ConversationData constructor implementation // ConversationData constructor implementation
ConversationData::ConversationData(const ModelConfig& model_config, ConversationData::ConversationData(const ModelConfig& model_config,
size_t prefill_tbatch_size) size_t prefill_tbatch_size)
: kv_cache(std::make_unique<KVCache>( : model_config_ref_(model_config),
prefill_tbatch_size_(prefill_tbatch_size),
kv_cache(std::make_unique<KVCache>(
KVCache::Create(model_config, prefill_tbatch_size))), KVCache::Create(model_config, prefill_tbatch_size))),
abs_pos(0) {} abs_pos(0) {}
// ConversationData copy constructor implementation
ConversationData::ConversationData(const ConversationData& other)
: model_config_ref_(other.model_config_ref_),
prefill_tbatch_size_(other.prefill_tbatch_size_),
kv_cache(nullptr),
abs_pos(other.abs_pos) {
if (other.kv_cache) {
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
other.model_config_ref_, other.prefill_tbatch_size_));
}
}
// Initialize static members // Initialize static members
GemmaLogCallback GemmaContext::s_log_callback = nullptr; GemmaLogCallback GemmaContext::s_log_callback = nullptr;
void* GemmaContext::s_log_user_data = nullptr; void* GemmaContext::s_log_user_data = nullptr;
GemmaContext* GemmaContext::Create(const char* tokenizer_path, GemmaContext* GemmaContext::Create(const char* tokenizer_path,
const char* ignored1,
const char* weights_path, const char* weights_path,
const char* ignored2, int max_length) { int max_generated_tokens) {
std::stringstream ss; std::stringstream ss;
ss << "Creating GemmaContext with tokenizer_path: " ss << "Creating GemmaContext with tokenizer_path: "
<< (tokenizer_path ? tokenizer_path : "null") << (tokenizer_path ? tokenizer_path : "null")
<< ", weights_path: " << (weights_path ? weights_path : "null") << ", weights_path: " << (weights_path ? weights_path : "null")
<< ", max_length: " << max_length; << ", max_generated_tokens: " << max_generated_tokens;
LogDebug(ss.str().c_str()); LogDebug(ss.str().c_str());
ThreadingArgs threading_args; ThreadingArgs threading_args;
@ -73,27 +86,30 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
LogDebug("Initializing inference args"); LogDebug("Initializing inference args");
InferenceArgs inference_args; InferenceArgs inference_args;
inference_args.Init(); inference_args.Init();
inference_args.max_generated_tokens = max_length; inference_args.max_generated_tokens = max_generated_tokens;
inference_args.temperature = 0.7f; inference_args.temperature = 0.7f;
inference_args.top_k = 1; inference_args.top_k = 1;
inference_args.deterministic = false; inference_args.deterministic = false;
ss.str(""); ss.str("");
ss << "Inference args initialized with max_tokens: " << max_length ss << "Inference args initialized with max_tokens: " << max_generated_tokens
<< ", temperature: " << inference_args.temperature << ", temperature: " << inference_args.temperature
<< ", top_k: " << inference_args.top_k << ", deterministic: " << ", top_k: " << inference_args.top_k << ", deterministic: "
<< (inference_args.deterministic ? "true" : "false"); << (inference_args.deterministic ? "true" : "false");
LogDebug(ss.str().c_str()); LogDebug(ss.str().c_str());
return new GemmaContext(loader, inference_args, threading_args, max_length); return new GemmaContext(loader, inference_args, threading_args,
max_generated_tokens);
} }
GemmaContext::GemmaContext(const LoaderArgs& loader, GemmaContext::GemmaContext(const LoaderArgs& loader,
const InferenceArgs& inference_args, const InferenceArgs& inference_args,
const ThreadingArgs& threading_args, int max_length) const ThreadingArgs& threading_args,
int max_generated_tokens)
: inference_args(inference_args), : inference_args(inference_args),
threading_args(threading_args), threading_args(threading_args),
matmul_env(MakeMatMulEnv(threading_args)), matmul_env(MakeMatMulEnv(threading_args)),
active_conversation_name("default"),
model(loader, matmul_env) { model(loader, matmul_env) {
std::stringstream ss; std::stringstream ss;
@ -114,7 +130,8 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
int GemmaContext::GenerateInternal(const char* prompt_string, int GemmaContext::GenerateInternal(const char* prompt_string,
const void* image_data, int image_width, const void* image_data, int image_width,
int image_height, char* output, int image_height, char* output,
int max_length, GemmaTokenCallback callback, int max_output_chars,
GemmaTokenCallback callback,
void* user_data) { void* user_data) {
PROFILER_ZONE("Gen.Internal"); PROFILER_ZONE("Gen.Internal");
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
@ -224,8 +241,12 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
return -1; return -1;
} }
// Create a span from the prompt vector - Generate() expects a hwy::Span,
// which has a different memory footprint to that of a std::vector.
hwy::Span<const int> prompt_span(prompt.data(), prompt.size());
// Pass the KVCache object by reference from the active conversation // Pass the KVCache object by reference from the active conversation
model.Generate(runtime_config, prompt, active_conversation->abs_pos, model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
prefix_end, *(active_conversation->kv_cache), timing_info); prefix_end, *(active_conversation->kv_cache), timing_info);
// prepare for next turn // prepare for next turn
@ -251,26 +272,26 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
} }
// Copy result buffer to output C-string (ensure null termination) // Copy result buffer to output C-string (ensure null termination)
strncpy(output, result_buffer.c_str(), max_length - 1); strncpy(output, result_buffer.c_str(), max_output_chars - 1);
output[max_length - 1] = '\0'; // Explicit null termination output[max_output_chars - 1] = '\0';
return static_cast<int>(strlen(output)); // Return length of the C-string return static_cast<int>(strlen(output));
} }
// Public Generate method (wrapper for text-only) // Public Generate method (wrapper for text-only)
int GemmaContext::Generate(const char* prompt_string, char* output, int GemmaContext::Generate(const char* prompt_string, char* output,
int max_length, GemmaTokenCallback callback, int max_output_chars, GemmaTokenCallback callback,
void* user_data) { void* user_data) {
// Call the internal implementation with null image_data and 0 dimensions // Call the internal implementation with null image_data and 0 dimensions
return GenerateInternal(prompt_string, nullptr, 0, 0, output, max_length, return GenerateInternal(prompt_string, nullptr, 0, 0, output,
callback, user_data); max_output_chars, callback, user_data);
} }
// Public GenerateMultimodal method (wrapper) // Public GenerateMultimodal method (wrapper)
int GemmaContext::GenerateMultimodal(const char* prompt_string, int GemmaContext::GenerateMultimodal(const char* prompt_string,
const void* image_data, int image_width, const void* image_data, int image_width,
int image_height, // Added dimensions int image_height, char* output,
char* output, int max_length, int max_output_chars,
GemmaTokenCallback callback, GemmaTokenCallback callback,
void* user_data) { void* user_data) {
if (image_data == nullptr) { if (image_data == nullptr) {
@ -283,7 +304,7 @@ int GemmaContext::GenerateMultimodal(const char* prompt_string,
} }
return GenerateInternal(prompt_string, image_data, image_width, image_height, return GenerateInternal(prompt_string, image_data, image_width, image_height,
output, max_length, callback, user_data); output, max_output_chars, callback, user_data);
} }
int GemmaContext::CountTokens(const char* text) { int GemmaContext::CountTokens(const char* text) {
@ -320,4 +341,9 @@ int GemmaContext::CountTokens(const char* text) {
} }
} }
// Get the name of the currently active conversation
const char* GemmaContext::GetCurrentConversation() {
return active_conversation_name.c_str();
}
} // namespace gcpp } // namespace gcpp

View File

@ -43,12 +43,17 @@ struct KVCache;
// Struct to hold data for a single conversation thread // Struct to hold data for a single conversation thread
struct ConversationData { struct ConversationData {
public:
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
ConversationData(const ConversationData& other);
private:
const ModelConfig& model_config_ref_;
size_t prefill_tbatch_size_;
public:
std::unique_ptr<KVCache> kv_cache; std::unique_ptr<KVCache> kv_cache;
size_t abs_pos = 0; size_t abs_pos = 0;
// Constructor to initialize kv_cache (requires KVCache definition or forward
// declaration)
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
}; };
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data); typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
@ -57,20 +62,20 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
class GemmaContext { class GemmaContext {
private: private:
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args, GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
const ThreadingArgs& threading_args, int max_length); const ThreadingArgs& threading_args, int max_generated_tokens);
public: public:
static GemmaContext* Create(const char* tokenizer_path, const char* ignored1, static GemmaContext* Create(const char* tokenizer_path,
const char* weights_path, const char* ignored2, const char* weights_path,
int max_length); int max_generated_tokens);
// Returns length of generated text, or -1 on error // Returns length of generated text, or -1 on error
int Generate(const char* prompt_string, char* output, int max_length, int Generate(const char* prompt_string, char* output, int max_output_chars,
GemmaTokenCallback callback, void* user_data); GemmaTokenCallback callback, void* user_data);
// Returns length of generated text, or -1 on error // Returns length of generated text, or -1 on error
int GenerateMultimodal(const char* prompt_string, const void* image_data, int GenerateMultimodal(const char* prompt_string, const void* image_data,
int image_width, int image_height, char* output, int image_width, int image_height, char* output,
int max_length, GemmaTokenCallback callback, int max_output_chars, GemmaTokenCallback callback,
void* user_data); void* user_data);
// Returns number of tokens in text, or -1 on error // Returns number of tokens in text, or -1 on error
@ -122,15 +127,71 @@ class GemmaContext {
LogDebug("Setting prefill_tbatch_size to configured value"); LogDebug("Setting prefill_tbatch_size to configured value");
} }
void SaveConversation() {
if (!active_conversation || active_conversation_name.empty()) {
if (!active_conversation) {
LogDebug("SaveConversation: No active conversation to save.");
} else { // active_conversation_name must be empty
LogDebug(
"SaveConversation: Active conversation name is empty. Cannot "
"save.");
}
return;
}
std::string log_msg = "SaveConversation: Attempting to save '";
log_msg += active_conversation_name;
log_msg += "' to prewarmed_cache.";
LogDebug(log_msg.c_str());
// Create a deep copy of the active_conversation.
// The ConversationData copy constructor handles the deep copy of KVCache.
auto conversation_copy =
std::make_shared<ConversationData>(*active_conversation);
// Store the deep copy in prewarmed_cache.
// If a conversation with the same name already exists, it will be
// overwritten. std::shared_ptr will handle the destruction of the old
// object if it's being replaced.
prewarmed_cache[active_conversation_name] = conversation_copy;
log_msg = "SaveConversation: Successfully saved '";
log_msg += active_conversation_name;
log_msg += "' to prewarmed_cache.";
LogDebug(log_msg.c_str());
}
// Reset the currently active conversation // Reset the currently active conversation
void ResetConversation() { void ResetConversation() {
if (active_conversation) { if (active_conversation) {
LogDebug("Resetting active conversation"); std::string log_prefix = "ResetConversation ('";
log_prefix += active_conversation_name.empty() ? "[unnamed]"
: active_conversation_name;
log_prefix += "'): ";
LogDebug((log_prefix + "Attempting to reset.").c_str());
// Attempt to restore from prewarmed_cache first, regardless of name.
auto it = prewarmed_cache.find(active_conversation_name);
if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
// Found in prewarmed_cache and the cached entry is valid.
LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
.c_str());
active_conversation->abs_pos = it->second->abs_pos;
// Perform a deep copy of the KVCache from the prewarmed version.
active_conversation->kv_cache =
std::make_unique<KVCache>(it->second->kv_cache->Copy(
model.GetModelConfig(), inference_args.prefill_tbatch_size));
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
.c_str());
return;
}
// If not found in prewarmed_cache or prewarmed_cache entry is invalid,
// rewind to initial state.
active_conversation->abs_pos = 0; active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object // Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create( active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
model.GetModelConfig(), inference_args.prefill_tbatch_size)); model.GetModelConfig(), inference_args.prefill_tbatch_size));
LogDebug("Active conversation reset");
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else { } else {
LogDebug("Cannot reset conversation: active_conversation is null"); LogDebug("Cannot reset conversation: active_conversation is null");
} }
@ -160,6 +221,7 @@ class GemmaContext {
} }
LogDebug("Switching active conversation"); LogDebug("Switching active conversation");
active_conversation = it->second; active_conversation = it->second;
active_conversation_name = conversation_name;
return true; return true;
} }
@ -183,6 +245,12 @@ class GemmaContext {
LogDebug("Deleting conversation"); LogDebug("Deleting conversation");
conversation_cache.erase(it); conversation_cache.erase(it);
auto it2 = prewarmed_cache.find(name);
if (it2 != prewarmed_cache.end()) {
prewarmed_cache.erase(it2);
}
return true; return true;
} }
@ -192,13 +260,16 @@ class GemmaContext {
return conversation_cache.count(name); return conversation_cache.count(name);
} }
// Get the name of the currently active conversation
const char* GetCurrentConversation();
private: private:
// Internal implementation shared by Generate and GenerateMultimodal // Internal implementation shared by Generate and GenerateMultimodal
int GenerateInternal(const char* prompt_string, int GenerateInternal(const char* prompt_string,
const void* image_data, // Null for text-only generation const void* image_data, // Null for text-only generation
int image_width, // Added dimension (0 if no image) int image_width,
int image_height, // Added dimension (0 if no image) int image_height,
char* output, int max_length, char* output, int max_output_chars,
GemmaTokenCallback callback, void* user_data); GemmaTokenCallback callback, void* user_data);
// Pointer to the currently active conversation's data // Pointer to the currently active conversation's data
@ -207,6 +278,8 @@ class GemmaContext {
// Cache of all named conversations // Cache of all named conversations
std::unordered_map<std::string, std::shared_ptr<ConversationData>> std::unordered_map<std::string, std::shared_ptr<ConversationData>>
conversation_cache; conversation_cache;
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
prewarmed_cache;
// Buffers (potentially could be moved into ConversationData if needed // Buffers (potentially could be moved into ConversationData if needed
// per-conversation) // per-conversation)
@ -219,6 +292,8 @@ class GemmaContext {
ThreadingArgs threading_args; ThreadingArgs threading_args;
MatMulEnv matmul_env; MatMulEnv matmul_env;
std::string active_conversation_name;
// Model itself (don't move this, needs to be below the args above) // Model itself (don't move this, needs to be below the args above)
Gemma model; Gemma model;
@ -232,7 +307,7 @@ class GemmaContext {
// Use logging helper method to print messages into a managed callback if // Use logging helper method to print messages into a managed callback if
// necessary // necessary
static void LogDebug(const char* message) { static void LogDebug(const char* message) {
if (s_log_callback) { if (s_log_callback != nullptr) {
s_log_callback(message, s_log_user_data); s_log_callback(message, s_log_user_data);
} else { } else {
#ifdef _WIN32 #ifdef _WIN32