diff --git a/gemma/bindings/GemmaInterop.cs b/gemma/bindings/GemmaInterop.cs
index 73eea7d..0fb3ee8 100644
--- a/gemma/bindings/GemmaInterop.cs
+++ b/gemma/bindings/GemmaInterop.cs
@@ -1,3 +1,18 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 using System;
 using System.Diagnostics;
 using System.Runtime.InteropServices;
@@ -35,7 +50,7 @@ namespace GemmaCpp
             [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
             [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
             [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
-            int maxLength);
+            int maxGeneratedTokens);
 
         [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
         private static extern void GemmaDestroy(IntPtr context);
@@ -56,7 +71,7 @@ namespace GemmaCpp
             IntPtr context,
             [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
             [Out] byte[] output,
-            int maxLength,
+            int maxOutputChars,
             GemmaTokenCallback callback,
             IntPtr userData);
 
@@ -68,7 +83,7 @@ namespace GemmaCpp
             int image_width,   // Added dimension
             int image_height,  // Added dimension
             [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,  // Output should be StringBuilder for multimodal
-            int maxLength,
+            int maxOutputChars,
             GemmaTokenCallback callback,
             IntPtr userData);
 
@@ -120,6 +135,13 @@ namespace GemmaCpp
             IntPtr context,
             [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
 
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
+        [return: MarshalAs(UnmanagedType.LPUTF8Str)]  // Marshal the const char* return value as a string
+        private static extern string GemmaGetCurrentConversation(IntPtr context);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
+        private static extern void GemmaSaveConversation(IntPtr context);
+
         // Native callback delegate type
         [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
         private delegate void GemmaLogCallback(
@@ -135,9 +157,9 @@ namespace GemmaCpp
         private GCHandle _logCallbackHandle;
         private bool _loggingEnabled = false;
 
-        public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
+        public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
        {
-            _context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
+            _context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
             if (_context == IntPtr.Zero)
             {
                 throw new GemmaException("Failed to create Gemma context");
@@ -281,6 +303,31 @@ namespace GemmaCpp
             return result;
         }
 
+        public string GetCurrentConversation()
+        {
+            if (_disposed)
+                throw new ObjectDisposedException(nameof(Gemma));
+
+            if (_context == IntPtr.Zero)
+                throw new GemmaException("Gemma context is invalid");
+
+            string currentConversation = GemmaGetCurrentConversation(_context);  // Call P/Invoke method
+            Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
+            return currentConversation;
+        }
+
+        public void SaveConversation()
+        {
+            if (_disposed)
+                throw new ObjectDisposedException(nameof(Gemma));
+
+            if (_context == IntPtr.Zero)
+                throw new GemmaException("Gemma context is invalid");
+
+            GemmaSaveConversation(_context);
+            Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
+        }
+
         public int CountTokens(string prompt)
         {
             if (_disposed)
@@ -292,12 +339,12 @@ namespace GemmaCpp
             return count;
         }
 
-        public string Generate(string prompt, int maxLength = 4096)
+        public string Generate(string prompt, int maxOutputChars = 4096)
         {
-            return Generate(prompt, null, maxLength);
+            return Generate(prompt, null, maxOutputChars);
         }
 
-        public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
+        public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
         {
             if (_disposed)
                 throw new ObjectDisposedException(nameof(Gemma));
@@ -305,7 +352,7 @@ namespace GemmaCpp
             if (_context == IntPtr.Zero)
                 throw new GemmaException("Gemma context is invalid");
 
-            var outputBuffer = new byte[maxLength * 4];  // Allow for worst case UTF-8 size
+            var outputBuffer = new byte[maxOutputChars * 4];  // Allow for worst case UTF-8 size
             GemmaTokenCallback nativeCallback = null;
 
             // Track token count for debugging
@@ -327,7 +374,7 @@ namespace GemmaCpp
 
             try
             {
-                int length = GemmaGenerate(_context, prompt, outputBuffer, maxLength,
+                int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
                     nativeCallback, IntPtr.Zero);
 
                 if (length < 0)
@@ -346,13 +393,13 @@ namespace GemmaCpp
             }
         }
 
-        public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxLength = 4096)
+        public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096)
         {
             // Pass width and height to the overloaded method
-            return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxLength);
+            return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
         }
 
-        public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxLength = 4096)
+        public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
         {
             if (_disposed)
                 throw new ObjectDisposedException(nameof(Gemma));
@@ -369,7 +416,7 @@ namespace GemmaCpp
             if (imageData.Length < imageWidth * imageHeight * 3)
                 throw new ArgumentException("Image data array is too small for the specified dimensions");
 
-            var output = new StringBuilder(maxLength);
+            var output = new StringBuilder(maxOutputChars);
             GemmaTokenCallback nativeCallback = null;
 
             if (callback != null)
@@ -386,7 +433,7 @@ namespace GemmaCpp
                 IntPtr imagePtr = imageHandle.AddrOfPinnedObject();
 
                 // Pass image dimensions to the native call
-                int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxLength,
+                int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
                     nativeCallback, IntPtr.Zero);
 
                 if (length < 0)
diff --git a/gemma/bindings/c_api.cc b/gemma/bindings/c_api.cc
index e5efbc4..cba2ffb 100644
--- a/gemma/bindings/c_api.cc
+++ b/gemma/bindings/c_api.cc
@@ -22,37 +22,38 @@ extern "C" {
 
 GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
-                                    const char* model_type,
                                     const char* weights_path,
-                                    const char* weight_type, int max_length) {
+                                    int max_generated_tokens) {
   try {
-    GemmaContext* ctx = GemmaContext::Create(
-        tokenizer_path, model_type, weights_path, weight_type, max_length);
+    GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path,
+                                             max_generated_tokens);
     return ctx;
   } catch (...) {
     return nullptr;
   }
 }
 
-GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; }
+GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
+  delete ctx;
+}
 
 GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
-                            int max_length, GemmaTokenCallback callback,
+                            int max_output_chars, GemmaTokenCallback callback,
                             void* user_data) {
   if (!ctx) return -1;
-  return ctx->Generate(prompt, output, max_length, callback, user_data);
+  return ctx->Generate(prompt, output, max_output_chars, callback, user_data);
 }
 
 GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
                                       const void* image_data, int image_width,
                                       int image_height, char* output,
-                                      int max_length,
+                                      int max_output_chars,
                                       GemmaTokenCallback callback,
                                       void* user_data) {
   if (!ctx) return -1;
   return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
-                                 output, max_length, callback, user_data);
+                                 output, max_output_chars, callback, user_data);
 }
 
 GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
@@ -125,4 +126,14 @@ GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
   if (!ctx || !conversation_name) return 0;
   return ctx->HasConversation(conversation_name) ? 1 : 0;
 }
+
+GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
+  if (!ctx) return nullptr;
+  return ctx->GetCurrentConversation();
+}
+
+GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
+  if (!ctx) return;
+  ctx->SaveConversation();
+}
 }
diff --git a/gemma/bindings/c_api.h b/gemma/bindings/c_api.h
index 98e14f2..6d369b8 100644
--- a/gemma/bindings/c_api.h
+++ b/gemma/bindings/c_api.h
@@ -42,18 +42,16 @@ typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
 typedef void (*GemmaLogCallback)(const char* message, void* user_data);
 
 GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
-                                    const char* model_type,
                                     const char* weights_path,
-                                    const char* weight_type, int max_length);
+                                    int max_generated_tokens);
 
 GEMMA_API void GemmaDestroy(GemmaContext* ctx);
 
 GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
-                            int max_length, GemmaTokenCallback callback,
+                            int max_output_chars, GemmaTokenCallback callback,
                             void* user_data);
 
 GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
-                                      const void* image_data,  // Renamed param
-                                      int image_width,   // Added dimension
-                                      int image_height,  // Added dimension
-                                      char* output, int max_length,
+                                      const void* image_data, int image_width,
+                                      int image_height, char* output,
+                                      int max_output_chars,
                                       GemmaTokenCallback callback,
                                       void* user_data);
 
@@ -67,17 +65,19 @@ GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
 GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
 GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
 GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
-GEMMA_API void GemmaResetConversation(GemmaContext* ctx);  // Renamed
+GEMMA_API void GemmaResetConversation(GemmaContext* ctx);
 
 // Conversation management functions (renamed)
-GEMMA_API int GemmaCreateConversation(
-    GemmaContext* ctx, const char* conversation_name);  // Renamed
-GEMMA_API int GemmaSwitchConversation(
-    GemmaContext* ctx, const char* conversation_name);  // Renamed
-GEMMA_API int GemmaDeleteConversation(
-    GemmaContext* ctx, const char* conversation_name);  // Renamed
+GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
+                                      const char* conversation_name);
+GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
+                                      const char* conversation_name);
+GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
+                                      const char* conversation_name);
 GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
-                                   const char* conversation_name);  // Renamed
+                                   const char* conversation_name);
+GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
+GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);
 
 #ifdef __cplusplus
 }
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index f6242d2..38ca070 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -44,23 +44,36 @@ namespace gcpp {
 
 // ConversationData constructor implementation
 ConversationData::ConversationData(const ModelConfig& model_config,
                                    size_t prefill_tbatch_size)
-    : kv_cache(std::make_unique<KVCache>(
+    : model_config_ref_(model_config),
+      prefill_tbatch_size_(prefill_tbatch_size),
+      kv_cache(std::make_unique<KVCache>(
           KVCache::Create(model_config, prefill_tbatch_size))),
       abs_pos(0) {}
 
+// ConversationData copy constructor implementation
+ConversationData::ConversationData(const ConversationData& other)
+    : model_config_ref_(other.model_config_ref_),
+      prefill_tbatch_size_(other.prefill_tbatch_size_),
+      kv_cache(nullptr),
+      abs_pos(other.abs_pos) {
+  if (other.kv_cache) {
+    kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
+        other.model_config_ref_, other.prefill_tbatch_size_));
+  }
+}
+
 // Initialize static members
 GemmaLogCallback GemmaContext::s_log_callback = nullptr;
 void* GemmaContext::s_log_user_data = nullptr;
 
 GemmaContext* GemmaContext::Create(const char* tokenizer_path,
-                                   const char* ignored1,
                                    const char* weights_path,
-                                   const char* ignored2, int max_length) {
+                                   int max_generated_tokens) {
   std::stringstream ss;
   ss << "Creating GemmaContext with tokenizer_path: "
      << (tokenizer_path ? tokenizer_path : "null")
     << ", weights_path: " << (weights_path ? weights_path : "null")
-     << ", max_length: " << max_length;
+     << ", max_generated_tokens: " << max_generated_tokens;
   LogDebug(ss.str().c_str());
 
   ThreadingArgs threading_args;
@@ -73,27 +86,30 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path,
   LogDebug("Initializing inference args");
   InferenceArgs inference_args;
   inference_args.Init();
-  inference_args.max_generated_tokens = max_length;
+  inference_args.max_generated_tokens = max_generated_tokens;
   inference_args.temperature = 0.7f;
   inference_args.top_k = 1;
   inference_args.deterministic = false;
 
   ss.str("");
-  ss << "Inference args initialized with max_tokens: " << max_length
+  ss << "Inference args initialized with max_tokens: " << max_generated_tokens
     << ", temperature: " << inference_args.temperature
     << ", top_k: " << inference_args.top_k << ", deterministic: "
     << (inference_args.deterministic ? "true" : "false");
"true" : "false"); LogDebug(ss.str().c_str()); - return new GemmaContext(loader, inference_args, threading_args, max_length); + return new GemmaContext(loader, inference_args, threading_args, + max_generated_tokens); } GemmaContext::GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args, - const ThreadingArgs& threading_args, int max_length) + const ThreadingArgs& threading_args, + int max_generated_tokens) : inference_args(inference_args), threading_args(threading_args), matmul_env(MakeMatMulEnv(threading_args)), + active_conversation_name("default"), model(loader, matmul_env) { std::stringstream ss; @@ -114,7 +130,8 @@ GemmaContext::GemmaContext(const LoaderArgs& loader, int GemmaContext::GenerateInternal(const char* prompt_string, const void* image_data, int image_width, int image_height, char* output, - int max_length, GemmaTokenCallback callback, + int max_output_chars, + GemmaTokenCallback callback, void* user_data) { PROFILER_ZONE("Gen.Internal"); size_t tokens_generated_this_turn = 0; // differentiates prefill from reply @@ -224,8 +241,12 @@ int GemmaContext::GenerateInternal(const char* prompt_string, return -1; } + // Create a span from the prompt vector - Generate() expects a hwy::Span, + // which has a different memory footprint to that of a std::vector. + hwy::Span prompt_span(prompt.data(), prompt.size()); + // Pass the KVCache object by reference from the active conversation - model.Generate(runtime_config, prompt, active_conversation->abs_pos, + model.Generate(runtime_config, prompt_span, active_conversation->abs_pos, prefix_end, *(active_conversation->kv_cache), timing_info); // prepare for next turn @@ -251,26 +272,26 @@ int GemmaContext::GenerateInternal(const char* prompt_string, } // Copy result buffer to output C-string (ensure null termination) - strncpy(output, result_buffer.c_str(), max_length - 1); - output[max_length - 1] = '\0'; // Explicit null termination + strncpy(output, result_buffer.c_str(), max_output_chars - 1); + output[max_output_chars - 1] = '\0'; - return static_cast(strlen(output)); // Return length of the C-string + return static_cast(strlen(output)); } // Public Generate method (wrapper for text-only) int GemmaContext::Generate(const char* prompt_string, char* output, - int max_length, GemmaTokenCallback callback, + int max_output_chars, GemmaTokenCallback callback, void* user_data) { // Call the internal implementation with null image_data and 0 dimensions - return GenerateInternal(prompt_string, nullptr, 0, 0, output, max_length, - callback, user_data); + return GenerateInternal(prompt_string, nullptr, 0, 0, output, + max_output_chars, callback, user_data); } // Public GenerateMultimodal method (wrapper) int GemmaContext::GenerateMultimodal(const char* prompt_string, const void* image_data, int image_width, - int image_height, // Added dimensions - char* output, int max_length, + int image_height, char* output, + int max_output_chars, GemmaTokenCallback callback, void* user_data) { if (image_data == nullptr) { @@ -283,7 +304,7 @@ int GemmaContext::GenerateMultimodal(const char* prompt_string, } return GenerateInternal(prompt_string, image_data, image_width, image_height, - output, max_length, callback, user_data); + output, max_output_chars, callback, user_data); } int GemmaContext::CountTokens(const char* text) { @@ -320,4 +341,9 @@ int GemmaContext::CountTokens(const char* text) { } } +// Get the name of the currently active conversation +const char* GemmaContext::GetCurrentConversation() { + return 
+}
+
 }  // namespace gcpp
diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h
index 6202f2a..ba44c1b 100644
--- a/gemma/bindings/context.h
+++ b/gemma/bindings/context.h
@@ -43,12 +43,17 @@ struct KVCache;
 
 // Struct to hold data for a single conversation thread
 struct ConversationData {
+ public:
+  ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
+  ConversationData(const ConversationData& other);
+
+ private:
+  const ModelConfig& model_config_ref_;
+  size_t prefill_tbatch_size_;
+
+ public:
   std::unique_ptr<KVCache> kv_cache;
   size_t abs_pos = 0;
-
-  // Constructor to initialize kv_cache (requires KVCache definition or forward
-  // declaration)
-  ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
 };
 
 typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
@@ -57,20 +62,20 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data);
 class GemmaContext {
  private:
   GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
-               const ThreadingArgs& threading_args, int max_length);
+               const ThreadingArgs& threading_args, int max_generated_tokens);
 
  public:
-  static GemmaContext* Create(const char* tokenizer_path, const char* ignored1,
-                              const char* weights_path, const char* ignored2,
-                              int max_length);
+  static GemmaContext* Create(const char* tokenizer_path,
+                              const char* weights_path,
+                              int max_generated_tokens);
 
   // Returns length of generated text, or -1 on error
-  int Generate(const char* prompt_string, char* output, int max_length,
+  int Generate(const char* prompt_string, char* output, int max_output_chars,
                GemmaTokenCallback callback, void* user_data);
 
   // Returns length of generated text, or -1 on error
   int GenerateMultimodal(const char* prompt_string, const void* image_data,
                          int image_width, int image_height, char* output,
-                         int max_length, GemmaTokenCallback callback,
+                         int max_output_chars, GemmaTokenCallback callback,
                          void* user_data);
 
   // Returns number of tokens in text, or -1 on error
@@ -122,15 +127,71 @@ class GemmaContext {
     LogDebug("Setting prefill_tbatch_size to configured value");
   }
+  void SaveConversation() {
+    if (!active_conversation || active_conversation_name.empty()) {
+      if (!active_conversation) {
+        LogDebug("SaveConversation: No active conversation to save.");
+      } else {  // active_conversation_name must be empty
+        LogDebug(
+            "SaveConversation: Active conversation name is empty. Cannot "
+            "save.");
+      }
+      return;
+    }
+    std::string log_msg = "SaveConversation: Attempting to save '";
+    log_msg += active_conversation_name;
+    log_msg += "' to prewarmed_cache.";
+    LogDebug(log_msg.c_str());
+
+    // Create a deep copy of the active_conversation.
+    // The ConversationData copy constructor handles the deep copy of KVCache.
+    auto conversation_copy =
+        std::make_shared<ConversationData>(*active_conversation);
+
+    // Store the deep copy in prewarmed_cache.
+    // If a conversation with the same name already exists, it will be
+    // overwritten. std::shared_ptr will handle the destruction of the old
+    // object if it's being replaced.
+    prewarmed_cache[active_conversation_name] = conversation_copy;
+
+    log_msg = "SaveConversation: Successfully saved '";
+    log_msg += active_conversation_name;
+    log_msg += "' to prewarmed_cache.";
+    LogDebug(log_msg.c_str());
+  }
+
   // Reset the currently active conversation
   void ResetConversation() {
     if (active_conversation) {
-      LogDebug("Resetting active conversation");
+      std::string log_prefix = "ResetConversation ('";
+      log_prefix += active_conversation_name.empty() ? "[unnamed]"
+                                                     : active_conversation_name;
+      log_prefix += "'): ";
+      LogDebug((log_prefix + "Attempting to reset.").c_str());
+      // Attempt to restore from prewarmed_cache first, regardless of name.
+      auto it = prewarmed_cache.find(active_conversation_name);
+      if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
+        // Found in prewarmed_cache and the cached entry is valid.
+        LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
+                     .c_str());
+        active_conversation->abs_pos = it->second->abs_pos;
+        // Perform a deep copy of the KVCache from the prewarmed version.
+        active_conversation->kv_cache =
+            std::make_unique<KVCache>(it->second->kv_cache->Copy(
+                model.GetModelConfig(), inference_args.prefill_tbatch_size));
+        LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
+                     .c_str());
+        return;
+      }
+
+      // If not found in prewarmed_cache or prewarmed_cache entry is invalid,
+      // rewind to initial state.
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
       active_conversation->kv_cache =
           std::make_unique<KVCache>(KVCache::Create(
              model.GetModelConfig(), inference_args.prefill_tbatch_size));
-      LogDebug("Active conversation reset");
+
+      LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
     } else {
       LogDebug("Cannot reset conversation: active_conversation is null");
     }
@@ -160,6 +221,7 @@ class GemmaContext {
     }
     LogDebug("Switching active conversation");
     active_conversation = it->second;
+    active_conversation_name = conversation_name;
     return true;
   }
 
@@ -183,6 +245,12 @@ class GemmaContext {
 
     LogDebug("Deleting conversation");
     conversation_cache.erase(it);
+
+    auto it2 = prewarmed_cache.find(name);
+    if (it2 != prewarmed_cache.end()) {
+      prewarmed_cache.erase(it2);
+    }
+
     return true;
   }
 
@@ -192,13 +260,16 @@ class GemmaContext {
     return conversation_cache.count(name);
   }
 
+  // Get the name of the currently active conversation
+  const char* GetCurrentConversation();
+
  private:
   // Internal implementation shared by Generate and GenerateMultimodal
   int GenerateInternal(const char* prompt_string,
                        const void* image_data,  // Null for text-only generation
-                       int image_width,   // Added dimension (0 if no image)
-                       int image_height,  // Added dimension (0 if no image)
-                       char* output, int max_length,
+                       int image_width,
+                       int image_height,
+                       char* output, int max_output_chars,
                        GemmaTokenCallback callback, void* user_data);
 
   // Pointer to the currently active conversation's data
@@ -207,6 +278,8 @@ class GemmaContext {
   // Cache of all named conversations
   std::unordered_map<std::string, std::shared_ptr<ConversationData>>
       conversation_cache;
+  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
+      prewarmed_cache;
 
   // Buffers (potentially could be moved into ConversationData if needed
   // per-conversation)
@@ -219,6 +292,8 @@ class GemmaContext {
   ThreadingArgs threading_args;
   MatMulEnv matmul_env;
 
+  std::string active_conversation_name;
+
   // Model itself (don't move this, needs to be below the args above)
   Gemma model;
 
@@ -232,7 +307,7 @@ class GemmaContext {
   // Use logging helper method to print messages into a managed callback if
   // necessary
   static void LogDebug(const char* message) {
-    if (s_log_callback) {
+    if (s_log_callback != nullptr) {
       s_log_callback(message, s_log_user_data);
     } else {
 #ifdef _WIN32