// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_

#include <memory>  // For std::shared_ptr, std::make_shared
#include <string>
#include <unordered_map>
#include <vector>

// Logging
#ifdef _WIN32
#include <windows.h>
#else
#include <stdio.h>
#endif
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "ops/matmul.h" // MatMulEnv
#include "hwy/base.h"
#include "hwy/highway.h"
namespace gcpp {
// Struct to hold data for a single conversation thread
struct ConversationData {
ConversationData(const ModelConfig& model_config,
const InferenceArgs& inference_args,
const Allocator& allocator);
ConversationData(const ConversationData& other);
std::unique_ptr<KVCache> kv_cache;
size_t abs_pos = 0;
};

typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
typedef void (*GemmaLogCallback)(const char* message, void* user_data);

class GemmaContext {
 private:
  GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
               const ThreadingArgs& threading_args, int max_generated_tokens);

 public:
  static GemmaContext* Create(const char* tokenizer_path,
                              const char* weights_path,
                              int max_generated_tokens);
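
  // Minimal construction sketch (illustrative; the paths and token budget
  // below are placeholder assumptions, not part of this API's contract):
  //
  //   GemmaContext* gemma_ctx = GemmaContext::Create(
  //       "/path/to/tokenizer.spm", "/path/to/weights.sbs",
  //       /*max_generated_tokens=*/2048);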

  // Returns the length of the generated text, or -1 on error.
  int Generate(const char* prompt_string, char* output, int max_output_chars,
               GemmaTokenCallback callback, void* user_data);

  // Returns the length of the generated text, or -1 on error.
  int GenerateMultimodal(const char* prompt_string, const void* image_data,
                         int image_width, int image_height, char* output,
                         int max_output_chars, GemmaTokenCallback callback,
                         void* user_data);

  // Returns the number of tokens in `text`, or -1 on error.
  int CountTokens(const char* text);
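
  // Sketch of a streaming Generate call (illustrative; `PrintToken` is a
  // hypothetical user function, and we assume the callback's bool return
  // signals whether generation should continue):
  //
  //   static bool PrintToken(const char* text, void* user_data) {
  //     fputs(text, stdout);
  //     return true;  // continue generating
  //   }
  //   char output[16384];
  //   int len = gemma_ctx->Generate("Why is the sky blue?", output,
  //                                 sizeof(output), PrintToken,
  //                                 /*user_data=*/nullptr);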

  // Register a process-wide log callback; messages go to this callback
  // instead of the default debug output.
  static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
    s_log_callback = callback;
    s_log_user_data = user_data;
  }
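
  // Sketch of registering a logger (illustrative; `MyLogSink` is a
  // hypothetical user-defined function, not part of this API):
  //
  //   static void MyLogSink(const char* message, void* user_data) {
  //     fprintf(stderr, "[gemma] %s\n", message);
  //   }
  //   GemmaContext::SetLogCallback(MyLogSink, /*user_data=*/nullptr);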

  // Set the maximum number of generated tokens.
  void SetMaxGeneratedTokens(size_t value) {
    inference_args.max_generated_tokens = value;
    LogDebug("Setting max_generated_tokens to configured value");
  }

  // Set the multiturn flag (0 = disabled, 1 = enabled).
  void SetMultiturn(int value) {
    inference_args.multiturn = value;
    LogDebug("Setting multiturn to configured value");
  }

  // Set the temperature for token sampling.
  void SetTemperature(float value) {
    inference_args.temperature = value;
    LogDebug("Setting temperature to configured value");
  }

  // Set the top_k parameter for sampling.
  void SetTopK(int value) {
    inference_args.top_k = value;
    LogDebug("Setting top_k to configured value");
  }

  // Set the deterministic-sampling flag.
  void SetDeterministic(bool value) {
    inference_args.deterministic = value;
    LogDebug("Setting deterministic flag to configured value");
  }

  // Set the prefill token-batch size.
  void SetPrefillTbatchSize(size_t value) {
    inference_args.prefill_tbatch_size = value;
    LogDebug("Setting prefill_tbatch_size to configured value");
  }

  // Save a deep copy of the active conversation into prewarmed_cache under
  // its current name, overwriting any existing entry.
  void SaveConversation() {
    if (!active_conversation || active_conversation_name.empty()) {
      if (!active_conversation) {
        LogDebug("SaveConversation: No active conversation to save.");
      } else {  // active_conversation_name must be empty
        LogDebug(
            "SaveConversation: Active conversation name is empty. Cannot "
            "save.");
      }
      return;
    }
    std::string log_msg = "SaveConversation: Attempting to save '";
    log_msg += active_conversation_name;
    log_msg += "' to prewarmed_cache.";
    LogDebug(log_msg.c_str());

    // Create a deep copy of the active conversation via the copy ctor.
    auto conversation_copy =
        std::make_shared<ConversationData>(*active_conversation);

    // Store the deep copy in prewarmed_cache. If a conversation with the
    // same name already exists, it is overwritten; std::shared_ptr destroys
    // the replaced object once its last reference goes away.
    prewarmed_cache[active_conversation_name] = conversation_copy;

    log_msg = "SaveConversation: Successfully saved '";
    log_msg += active_conversation_name;
    log_msg += "' to prewarmed_cache.";
    LogDebug(log_msg.c_str());
  }
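
  // Sketch of the prewarm/restore flow (illustrative; the conversation name
  // and prompt are placeholders):
  //
  //   gemma_ctx->CreateConversation("support");
  //   gemma_ctx->SwitchConversation("support");
  //   gemma_ctx->Generate(system_prompt, out, sizeof(out), nullptr, nullptr);
  //   gemma_ctx->SaveConversation();   // snapshot the prefilled KV cache
  //   // ... later ...
  //   gemma_ctx->ResetConversation();  // rewinds to the snapshot, not to an
  //                                    // empty cache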

  // Reset the currently active conversation: restore it from prewarmed_cache
  // if a snapshot exists under its name, otherwise rewind it to the initial
  // (empty) state.
  void ResetConversation() {
    if (active_conversation) {
      std::string log_prefix = "ResetConversation ('";
      log_prefix += active_conversation_name.empty()
                        ? "[unnamed]"
                        : active_conversation_name;
      log_prefix += "'): ";
      LogDebug((log_prefix + "Attempting to reset.").c_str());

      // Attempt to restore from prewarmed_cache first.
      auto it = prewarmed_cache.find(active_conversation_name);
      if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
        // Found a valid prewarmed entry: restore its state.
        LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
                     .c_str());
        active_conversation->abs_pos = it->second->abs_pos;
        // Perform a deep copy of the KVCache from the prewarmed version.
        active_conversation->kv_cache =
            std::make_unique<KVCache>(it->second->kv_cache->Copy());
        LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
                     .c_str());
        return;
      }

      // No valid prewarmed entry: rewind to the initial state.
      active_conversation->abs_pos = 0;
      // Replace the cache within the current ConversationData object.
      active_conversation->kv_cache = std::make_unique<KVCache>(
          model.Config(), inference_args, ctx.allocator);
      LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
    } else {
      LogDebug("Cannot reset conversation: active_conversation is null");
    }
  }

  // Create a new named conversation. Returns false if the name is taken.
  bool CreateConversation(const char* conversation_name) {
    std::string name(conversation_name);
    if (conversation_cache.count(name)) {
      LogDebug("Conversation already exists");
      return false;
    }
    LogDebug("Creating new conversation");
    conversation_cache[name] = std::make_shared<ConversationData>(
        model.Config(), inference_args, ctx.allocator);
    return true;
  }

  // Switch to a named conversation. Returns false if it does not exist.
  bool SwitchConversation(const char* conversation_name) {
    std::string name(conversation_name);
    auto it = conversation_cache.find(name);
    if (it == conversation_cache.end()) {
      LogDebug("Conversation not found");
      return false;
    }
    LogDebug("Switching active conversation");
    active_conversation = it->second;
    active_conversation_name = name;
    return true;
  }

  // Delete a named conversation. The "default" conversation and the
  // currently active conversation cannot be deleted.
  bool DeleteConversation(const char* conversation_name) {
    std::string name(conversation_name);
    auto it = conversation_cache.find(name);
    if (it == conversation_cache.end()) {
      LogDebug("Conversation not found for deletion");
      return false;
    }
    if (name == "default") {
      LogDebug("Cannot delete the default conversation");
      return false;
    }
    if (it->second == active_conversation) {
      LogDebug("Cannot delete the currently active conversation");
      return false;
    }
    LogDebug("Deleting conversation");
    conversation_cache.erase(it);
    // Also drop any prewarmed snapshot saved under the same name.
    prewarmed_cache.erase(name);
    return true;
  }

  // Check whether a named conversation exists.
  bool HasConversation(const char* conversation_name) {
    std::string name(conversation_name);
    return conversation_cache.count(name) > 0;
  }
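
  // Sketch of managing multiple conversations (illustrative; the names and
  // the existence of a pre-created "default" conversation are assumptions):
  //
  //   gemma_ctx->CreateConversation("alice");
  //   gemma_ctx->SwitchConversation("alice");
  //   // ... Generate() calls now use alice's KV cache ...
  //   gemma_ctx->SwitchConversation("default");
  //   gemma_ctx->DeleteConversation("alice");  // OK: alice is not active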

  // Get the name of the currently active conversation.
  const char* GetCurrentConversation();

 private:
  // Internal implementation shared by Generate and GenerateMultimodal.
  int GenerateInternal(const char* prompt_string,
                       const void* image_data,  // Null for text-only
                       int image_width, int image_height, char* output,
                       int max_output_chars, GemmaTokenCallback callback,
                       void* user_data);

  // Pointer to the currently active conversation's data.
  std::shared_ptr<ConversationData> active_conversation;
  // Cache of all named conversations.
  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
      conversation_cache;
  // Snapshots saved by SaveConversation, keyed by conversation name.
  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
      prewarmed_cache;

  // Buffers (could be moved into ConversationData if they are ever needed
  // per-conversation).
  std::string prompt_buffer;
  std::string result_buffer;
  std::vector<int> token_buffer;

  // Cached args (global for the context).
  InferenceArgs inference_args;
  ThreadingArgs threading_args;
  ThreadingContext ctx;
  MatMulEnv matmul_env;
  std::string active_conversation_name;
  // The model itself. Keep it declared after the args above: members are
  // constructed in declaration order, and the model depends on them.
  Gemma model;

  // Static members for logging.
  static GemmaLogCallback s_log_callback;
  static void* s_log_user_data;

  // Logging helper: forwards messages to the registered callback if one is
  // set, otherwise writes to the platform's default debug output.
  static void LogDebug(const char* message) {
    if (s_log_callback != nullptr) {
      s_log_callback(message, s_log_user_data);
    } else {
#ifdef _WIN32
      OutputDebugStringA(message);
#else
      printf("%s", message);
#endif
    }
  }
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_