mirror of https://github.com/google/gemma.cpp.git
cleanup, new conversation methods, bugfixes
- chore: unused parameters cleaned up - bugfix: explicitly use hwy::Span in GenerateInternal() to prevent runtime crashes due to memory layout incompatibility - bugfix: explicit nullptr check in LogDebug - chore: length-related parameters renamed for clarity - feature: SaveConversation() can be optionally used to save copy of a conversation that ResetConversation() will rewind to upon request, rather than just an empty KV cache - feature: GetCurrentConversation() can be used to query the current conversation's name PiperOrigin-RevId: 755873147
This commit is contained in:
parent
e9ecb7794d
commit
20757046db
|
|
@ -1,3 +1,18 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.InteropServices;
|
||||
|
|
@ -35,7 +50,7 @@ namespace GemmaCpp
|
|||
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
|
||||
int maxLength);
|
||||
int maxGeneratedTokens);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaDestroy(IntPtr context);
|
||||
|
|
@ -56,7 +71,7 @@ namespace GemmaCpp
|
|||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||
[Out] byte[] output,
|
||||
int maxLength,
|
||||
int maxOutputChars,
|
||||
GemmaTokenCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
|
|
@ -68,7 +83,7 @@ namespace GemmaCpp
|
|||
int image_width, // Added dimension
|
||||
int image_height, // Added dimension
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
|
||||
int maxLength,
|
||||
int maxOutputChars,
|
||||
GemmaTokenCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
|
|
@ -120,6 +135,13 @@ namespace GemmaCpp
|
|||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
|
||||
[return: MarshalAs(UnmanagedType.LPUTF8Str)] // Marshal the const char* return value as a string
|
||||
private static extern string GemmaGetCurrentConversation(IntPtr context);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
|
||||
private static extern void GemmaSaveConversation(IntPtr context);
|
||||
|
||||
// Native callback delegate type
|
||||
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||
private delegate void GemmaLogCallback(
|
||||
|
|
@ -135,9 +157,9 @@ namespace GemmaCpp
|
|||
private GCHandle _logCallbackHandle;
|
||||
private bool _loggingEnabled = false;
|
||||
|
||||
public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
|
||||
public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
|
||||
{
|
||||
_context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
|
||||
_context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
|
||||
if (_context == IntPtr.Zero)
|
||||
{
|
||||
throw new GemmaException("Failed to create Gemma context");
|
||||
|
|
@ -281,6 +303,31 @@ namespace GemmaCpp
|
|||
return result;
|
||||
}
|
||||
|
||||
public string GetCurrentConversation()
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
string currentConversation = GemmaGetCurrentConversation(_context); // Call P/Invoke method
|
||||
Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
|
||||
return currentConversation;
|
||||
}
|
||||
|
||||
public void SaveConversation()
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
||||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
GemmaSaveConversation(_context);
|
||||
Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
|
||||
}
|
||||
|
||||
public int CountTokens(string prompt)
|
||||
{
|
||||
if (_disposed)
|
||||
|
|
@ -292,12 +339,12 @@ namespace GemmaCpp
|
|||
return count;
|
||||
}
|
||||
|
||||
public string Generate(string prompt, int maxLength = 4096)
|
||||
public string Generate(string prompt, int maxOutputChars = 4096)
|
||||
{
|
||||
return Generate(prompt, null, maxLength);
|
||||
return Generate(prompt, null, maxOutputChars);
|
||||
}
|
||||
|
||||
public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
|
||||
public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
|
@ -305,7 +352,7 @@ namespace GemmaCpp
|
|||
if (_context == IntPtr.Zero)
|
||||
throw new GemmaException("Gemma context is invalid");
|
||||
|
||||
var outputBuffer = new byte[maxLength * 4]; // Allow for worst case UTF-8 size
|
||||
var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size
|
||||
GemmaTokenCallback nativeCallback = null;
|
||||
|
||||
// Track token count for debugging
|
||||
|
|
@ -327,7 +374,7 @@ namespace GemmaCpp
|
|||
|
||||
try
|
||||
{
|
||||
int length = GemmaGenerate(_context, prompt, outputBuffer, maxLength,
|
||||
int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
|
||||
nativeCallback, IntPtr.Zero);
|
||||
|
||||
if (length < 0)
|
||||
|
|
@ -346,13 +393,13 @@ namespace GemmaCpp
|
|||
}
|
||||
}
|
||||
|
||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxLength = 4096)
|
||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096)
|
||||
{
|
||||
// Pass width and height to the overloaded method
|
||||
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxLength);
|
||||
return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
|
||||
}
|
||||
|
||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxLength = 4096)
|
||||
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
|
||||
{
|
||||
if (_disposed)
|
||||
throw new ObjectDisposedException(nameof(Gemma));
|
||||
|
|
@ -369,7 +416,7 @@ namespace GemmaCpp
|
|||
if (imageData.Length < imageWidth * imageHeight * 3)
|
||||
throw new ArgumentException("Image data array is too small for the specified dimensions");
|
||||
|
||||
var output = new StringBuilder(maxLength);
|
||||
var output = new StringBuilder(maxOutputChars);
|
||||
GemmaTokenCallback nativeCallback = null;
|
||||
|
||||
if (callback != null)
|
||||
|
|
@ -386,7 +433,7 @@ namespace GemmaCpp
|
|||
IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
|
||||
|
||||
// Pass image dimensions to the native call
|
||||
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxLength,
|
||||
int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
|
||||
nativeCallback, IntPtr.Zero);
|
||||
|
||||
if (length < 0)
|
||||
|
|
|
|||
|
|
@ -22,37 +22,38 @@
|
|||
extern "C" {
|
||||
|
||||
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||
const char* model_type,
|
||||
const char* weights_path,
|
||||
const char* weight_type, int max_length) {
|
||||
int max_generated_tokens) {
|
||||
try {
|
||||
GemmaContext* ctx = GemmaContext::Create(
|
||||
tokenizer_path, model_type, weights_path, weight_type, max_length);
|
||||
GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path,
|
||||
max_generated_tokens);
|
||||
return ctx;
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; }
|
||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
int max_output_chars, GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
if (!ctx) return -1;
|
||||
return ctx->Generate(prompt, output, max_length, callback, user_data);
|
||||
return ctx->Generate(prompt, output, max_output_chars, callback, user_data);
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, char* output,
|
||||
int max_length,
|
||||
int max_output_chars,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
if (!ctx) return -1;
|
||||
|
||||
return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
|
||||
output, max_length, callback, user_data);
|
||||
output, max_output_chars, callback, user_data);
|
||||
}
|
||||
|
||||
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
|
||||
|
|
@ -125,4 +126,14 @@ GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
|||
if (!ctx || !conversation_name) return 0;
|
||||
return ctx->HasConversation(conversation_name) ? 1 : 0;
|
||||
}
|
||||
|
||||
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
|
||||
if (!ctx) return nullptr;
|
||||
return ctx->GetCurrentConversation();
|
||||
}
|
||||
|
||||
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
|
||||
if (!ctx) return;
|
||||
ctx->SaveConversation();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,18 +42,16 @@ typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
|||
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||
|
||||
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
|
||||
const char* model_type,
|
||||
const char* weights_path,
|
||||
const char* weight_type, int max_length);
|
||||
int max_generated_tokens);
|
||||
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
|
||||
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
int max_output_chars, GemmaTokenCallback callback,
|
||||
void* user_data);
|
||||
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
|
||||
const void* image_data, // Renamed param
|
||||
int image_width, // Added dimension
|
||||
int image_height, // Added dimension
|
||||
char* output, int max_length,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, char* output,
|
||||
int max_output_chars,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data);
|
||||
|
||||
|
|
@ -67,17 +65,19 @@ GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
|
|||
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
|
||||
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
|
||||
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
|
||||
GEMMA_API void GemmaResetConversation(GemmaContext* ctx); // Renamed
|
||||
GEMMA_API void GemmaResetConversation(GemmaContext* ctx);
|
||||
|
||||
// Conversation management functions (renamed)
|
||||
GEMMA_API int GemmaCreateConversation(
|
||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
||||
GEMMA_API int GemmaSwitchConversation(
|
||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
||||
GEMMA_API int GemmaDeleteConversation(
|
||||
GemmaContext* ctx, const char* conversation_name); // Renamed
|
||||
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
|
||||
const char* conversation_name);
|
||||
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
|
||||
const char* conversation_name);
|
||||
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
|
||||
const char* conversation_name);
|
||||
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
|
||||
const char* conversation_name); // Renamed
|
||||
const char* conversation_name);
|
||||
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
|
||||
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,23 +44,36 @@ namespace gcpp {
|
|||
// ConversationData constructor implementation
|
||||
ConversationData::ConversationData(const ModelConfig& model_config,
|
||||
size_t prefill_tbatch_size)
|
||||
: kv_cache(std::make_unique<KVCache>(
|
||||
: model_config_ref_(model_config),
|
||||
prefill_tbatch_size_(prefill_tbatch_size),
|
||||
kv_cache(std::make_unique<KVCache>(
|
||||
KVCache::Create(model_config, prefill_tbatch_size))),
|
||||
abs_pos(0) {}
|
||||
|
||||
// ConversationData copy constructor implementation
|
||||
ConversationData::ConversationData(const ConversationData& other)
|
||||
: model_config_ref_(other.model_config_ref_),
|
||||
prefill_tbatch_size_(other.prefill_tbatch_size_),
|
||||
kv_cache(nullptr),
|
||||
abs_pos(other.abs_pos) {
|
||||
if (other.kv_cache) {
|
||||
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
|
||||
other.model_config_ref_, other.prefill_tbatch_size_));
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize static members
|
||||
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
|
||||
void* GemmaContext::s_log_user_data = nullptr;
|
||||
|
||||
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
||||
const char* ignored1,
|
||||
const char* weights_path,
|
||||
const char* ignored2, int max_length) {
|
||||
int max_generated_tokens) {
|
||||
std::stringstream ss;
|
||||
ss << "Creating GemmaContext with tokenizer_path: "
|
||||
<< (tokenizer_path ? tokenizer_path : "null")
|
||||
<< ", weights_path: " << (weights_path ? weights_path : "null")
|
||||
<< ", max_length: " << max_length;
|
||||
<< ", max_generated_tokens: " << max_generated_tokens;
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
ThreadingArgs threading_args;
|
||||
|
|
@ -73,27 +86,30 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
|
|||
LogDebug("Initializing inference args");
|
||||
InferenceArgs inference_args;
|
||||
inference_args.Init();
|
||||
inference_args.max_generated_tokens = max_length;
|
||||
inference_args.max_generated_tokens = max_generated_tokens;
|
||||
inference_args.temperature = 0.7f;
|
||||
inference_args.top_k = 1;
|
||||
inference_args.deterministic = false;
|
||||
|
||||
ss.str("");
|
||||
ss << "Inference args initialized with max_tokens: " << max_length
|
||||
ss << "Inference args initialized with max_tokens: " << max_generated_tokens
|
||||
<< ", temperature: " << inference_args.temperature
|
||||
<< ", top_k: " << inference_args.top_k << ", deterministic: "
|
||||
<< (inference_args.deterministic ? "true" : "false");
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
return new GemmaContext(loader, inference_args, threading_args, max_length);
|
||||
return new GemmaContext(loader, inference_args, threading_args,
|
||||
max_generated_tokens);
|
||||
}
|
||||
|
||||
GemmaContext::GemmaContext(const LoaderArgs& loader,
|
||||
const InferenceArgs& inference_args,
|
||||
const ThreadingArgs& threading_args, int max_length)
|
||||
const ThreadingArgs& threading_args,
|
||||
int max_generated_tokens)
|
||||
: inference_args(inference_args),
|
||||
threading_args(threading_args),
|
||||
matmul_env(MakeMatMulEnv(threading_args)),
|
||||
active_conversation_name("default"),
|
||||
model(loader, matmul_env) {
|
||||
std::stringstream ss;
|
||||
|
||||
|
|
@ -114,7 +130,8 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
|||
int GemmaContext::GenerateInternal(const char* prompt_string,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
int max_output_chars,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
PROFILER_ZONE("Gen.Internal");
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
|
|
@ -224,8 +241,12 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
return -1;
|
||||
}
|
||||
|
||||
// Create a span from the prompt vector - Generate() expects a hwy::Span,
|
||||
// which has a different memory footprint to that of a std::vector.
|
||||
hwy::Span<const int> prompt_span(prompt.data(), prompt.size());
|
||||
|
||||
// Pass the KVCache object by reference from the active conversation
|
||||
model.Generate(runtime_config, prompt, active_conversation->abs_pos,
|
||||
model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
|
||||
prefix_end, *(active_conversation->kv_cache), timing_info);
|
||||
|
||||
// prepare for next turn
|
||||
|
|
@ -251,26 +272,26 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
}
|
||||
|
||||
// Copy result buffer to output C-string (ensure null termination)
|
||||
strncpy(output, result_buffer.c_str(), max_length - 1);
|
||||
output[max_length - 1] = '\0'; // Explicit null termination
|
||||
strncpy(output, result_buffer.c_str(), max_output_chars - 1);
|
||||
output[max_output_chars - 1] = '\0';
|
||||
|
||||
return static_cast<int>(strlen(output)); // Return length of the C-string
|
||||
return static_cast<int>(strlen(output));
|
||||
}
|
||||
|
||||
// Public Generate method (wrapper for text-only)
|
||||
int GemmaContext::Generate(const char* prompt_string, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
int max_output_chars, GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
// Call the internal implementation with null image_data and 0 dimensions
|
||||
return GenerateInternal(prompt_string, nullptr, 0, 0, output, max_length,
|
||||
callback, user_data);
|
||||
return GenerateInternal(prompt_string, nullptr, 0, 0, output,
|
||||
max_output_chars, callback, user_data);
|
||||
}
|
||||
|
||||
// Public GenerateMultimodal method (wrapper)
|
||||
int GemmaContext::GenerateMultimodal(const char* prompt_string,
|
||||
const void* image_data, int image_width,
|
||||
int image_height, // Added dimensions
|
||||
char* output, int max_length,
|
||||
int image_height, char* output,
|
||||
int max_output_chars,
|
||||
GemmaTokenCallback callback,
|
||||
void* user_data) {
|
||||
if (image_data == nullptr) {
|
||||
|
|
@ -283,7 +304,7 @@ int GemmaContext::GenerateMultimodal(const char* prompt_string,
|
|||
}
|
||||
|
||||
return GenerateInternal(prompt_string, image_data, image_width, image_height,
|
||||
output, max_length, callback, user_data);
|
||||
output, max_output_chars, callback, user_data);
|
||||
}
|
||||
|
||||
int GemmaContext::CountTokens(const char* text) {
|
||||
|
|
@ -320,4 +341,9 @@ int GemmaContext::CountTokens(const char* text) {
|
|||
}
|
||||
}
|
||||
|
||||
// Get the name of the currently active conversation
|
||||
const char* GemmaContext::GetCurrentConversation() {
|
||||
return active_conversation_name.c_str();
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -43,12 +43,17 @@ struct KVCache;
|
|||
|
||||
// Struct to hold data for a single conversation thread
|
||||
struct ConversationData {
|
||||
public:
|
||||
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
|
||||
ConversationData(const ConversationData& other);
|
||||
|
||||
private:
|
||||
const ModelConfig& model_config_ref_;
|
||||
size_t prefill_tbatch_size_;
|
||||
|
||||
public:
|
||||
std::unique_ptr<KVCache> kv_cache;
|
||||
size_t abs_pos = 0;
|
||||
|
||||
// Constructor to initialize kv_cache (requires KVCache definition or forward
|
||||
// declaration)
|
||||
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
|
||||
};
|
||||
|
||||
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||
|
|
@ -57,20 +62,20 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
|||
class GemmaContext {
|
||||
private:
|
||||
GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
|
||||
const ThreadingArgs& threading_args, int max_length);
|
||||
const ThreadingArgs& threading_args, int max_generated_tokens);
|
||||
|
||||
public:
|
||||
static GemmaContext* Create(const char* tokenizer_path, const char* ignored1,
|
||||
const char* weights_path, const char* ignored2,
|
||||
int max_length);
|
||||
static GemmaContext* Create(const char* tokenizer_path,
|
||||
const char* weights_path,
|
||||
int max_generated_tokens);
|
||||
|
||||
// Returns length of generated text, or -1 on error
|
||||
int Generate(const char* prompt_string, char* output, int max_length,
|
||||
int Generate(const char* prompt_string, char* output, int max_output_chars,
|
||||
GemmaTokenCallback callback, void* user_data);
|
||||
// Returns length of generated text, or -1 on error
|
||||
int GenerateMultimodal(const char* prompt_string, const void* image_data,
|
||||
int image_width, int image_height, char* output,
|
||||
int max_length, GemmaTokenCallback callback,
|
||||
int max_output_chars, GemmaTokenCallback callback,
|
||||
void* user_data);
|
||||
|
||||
// Returns number of tokens in text, or -1 on error
|
||||
|
|
@ -122,15 +127,71 @@ class GemmaContext {
|
|||
LogDebug("Setting prefill_tbatch_size to configured value");
|
||||
}
|
||||
|
||||
void SaveConversation() {
|
||||
if (!active_conversation || active_conversation_name.empty()) {
|
||||
if (!active_conversation) {
|
||||
LogDebug("SaveConversation: No active conversation to save.");
|
||||
} else { // active_conversation_name must be empty
|
||||
LogDebug(
|
||||
"SaveConversation: Active conversation name is empty. Cannot "
|
||||
"save.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
std::string log_msg = "SaveConversation: Attempting to save '";
|
||||
log_msg += active_conversation_name;
|
||||
log_msg += "' to prewarmed_cache.";
|
||||
LogDebug(log_msg.c_str());
|
||||
|
||||
// Create a deep copy of the active_conversation.
|
||||
// The ConversationData copy constructor handles the deep copy of KVCache.
|
||||
auto conversation_copy =
|
||||
std::make_shared<ConversationData>(*active_conversation);
|
||||
|
||||
// Store the deep copy in prewarmed_cache.
|
||||
// If a conversation with the same name already exists, it will be
|
||||
// overwritten. std::shared_ptr will handle the destruction of the old
|
||||
// object if it's being replaced.
|
||||
prewarmed_cache[active_conversation_name] = conversation_copy;
|
||||
|
||||
log_msg = "SaveConversation: Successfully saved '";
|
||||
log_msg += active_conversation_name;
|
||||
log_msg += "' to prewarmed_cache.";
|
||||
LogDebug(log_msg.c_str());
|
||||
}
|
||||
|
||||
// Reset the currently active conversation
|
||||
void ResetConversation() {
|
||||
if (active_conversation) {
|
||||
LogDebug("Resetting active conversation");
|
||||
std::string log_prefix = "ResetConversation ('";
|
||||
log_prefix += active_conversation_name.empty() ? "[unnamed]"
|
||||
: active_conversation_name;
|
||||
log_prefix += "'): ";
|
||||
LogDebug((log_prefix + "Attempting to reset.").c_str());
|
||||
// Attempt to restore from prewarmed_cache first, regardless of name.
|
||||
auto it = prewarmed_cache.find(active_conversation_name);
|
||||
if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
|
||||
// Found in prewarmed_cache and the cached entry is valid.
|
||||
LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
|
||||
.c_str());
|
||||
active_conversation->abs_pos = it->second->abs_pos;
|
||||
// Perform a deep copy of the KVCache from the prewarmed version.
|
||||
active_conversation->kv_cache =
|
||||
std::make_unique<KVCache>(it->second->kv_cache->Copy(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
||||
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
|
||||
.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
// If not found in prewarmed_cache or prewarmed_cache entry is invalid,
|
||||
// rewind to initial state.
|
||||
active_conversation->abs_pos = 0;
|
||||
// Replace the cache within the current ConversationData object
|
||||
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
||||
LogDebug("Active conversation reset");
|
||||
|
||||
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
||||
} else {
|
||||
LogDebug("Cannot reset conversation: active_conversation is null");
|
||||
}
|
||||
|
|
@ -160,6 +221,7 @@ class GemmaContext {
|
|||
}
|
||||
LogDebug("Switching active conversation");
|
||||
active_conversation = it->second;
|
||||
active_conversation_name = conversation_name;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -183,6 +245,12 @@ class GemmaContext {
|
|||
|
||||
LogDebug("Deleting conversation");
|
||||
conversation_cache.erase(it);
|
||||
|
||||
auto it2 = prewarmed_cache.find(name);
|
||||
if (it2 != prewarmed_cache.end()) {
|
||||
prewarmed_cache.erase(it2);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -192,13 +260,16 @@ class GemmaContext {
|
|||
return conversation_cache.count(name);
|
||||
}
|
||||
|
||||
// Get the name of the currently active conversation
|
||||
const char* GetCurrentConversation();
|
||||
|
||||
private:
|
||||
// Internal implementation shared by Generate and GenerateMultimodal
|
||||
int GenerateInternal(const char* prompt_string,
|
||||
const void* image_data, // Null for text-only generation
|
||||
int image_width, // Added dimension (0 if no image)
|
||||
int image_height, // Added dimension (0 if no image)
|
||||
char* output, int max_length,
|
||||
int image_width,
|
||||
int image_height,
|
||||
char* output, int max_output_chars,
|
||||
GemmaTokenCallback callback, void* user_data);
|
||||
|
||||
// Pointer to the currently active conversation's data
|
||||
|
|
@ -207,6 +278,8 @@ class GemmaContext {
|
|||
// Cache of all named conversations
|
||||
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
|
||||
conversation_cache;
|
||||
std::unordered_map<std::string, std::shared_ptr<ConversationData>>
|
||||
prewarmed_cache;
|
||||
|
||||
// Buffers (potentially could be moved into ConversationData if needed
|
||||
// per-conversation)
|
||||
|
|
@ -219,6 +292,8 @@ class GemmaContext {
|
|||
ThreadingArgs threading_args;
|
||||
MatMulEnv matmul_env;
|
||||
|
||||
std::string active_conversation_name;
|
||||
|
||||
// Model itself (don't move this, needs to be below the args above)
|
||||
Gemma model;
|
||||
|
||||
|
|
@ -232,7 +307,7 @@ class GemmaContext {
|
|||
// Use logging helper method to print messages into a managed callback if
|
||||
// necessary
|
||||
static void LogDebug(const char* message) {
|
||||
if (s_log_callback) {
|
||||
if (s_log_callback != nullptr) {
|
||||
s_log_callback(message, s_log_user_data);
|
||||
} else {
|
||||
#ifdef _WIN32
|
||||
|
|
|
|||
Loading…
Reference in New Issue