mirror of https://github.com/google/gemma.cpp.git
Add an additional QueryModel() overload to GemmaEnv.
Use args only in GemmaEnv constructor, store everything else in RuntimeConfig. Add runtime option to turn off thread spinning.

PiperOrigin-RevId: 670467320
This commit is contained in:
parent f6abbab3a4
commit a8e08778d4
@@ -129,7 +129,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
     KVCache kv_cache = KVCache::Create(
-        env.Info().model, env.MutableInferenceArgs().prefill_tbatch_size);
+        env.GetModel()->Info().model, env.MutableConfig().prefill_tbatch_size);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
@@ -186,7 +186,9 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.Empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" +
-        gcpp::ModelString(env.Info().model, env.Info().training) + ".txt";
+        gcpp::ModelString(env.GetModel()->Info().model,
+                          env.GetModel()->Info().training) +
+        ".txt";
     return BenchmarkGoldens(env, golden_path);
   } else if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);
@@ -58,32 +58,27 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
                    const AppArgs& app)
-    : loader_(loader),
-      inference_args_(inference),
-      app_(app),
-      pools_(app_.max_clusters, app_.num_threads) {
-  AbortIfInvalidArgs(inference_args_);
-
-  if (const char* err = loader_.Validate()) {
-    loader_.Help();
+    : pools_(app.max_clusters, app.num_threads, app.pin) {
+  InferenceArgs mutable_inference = inference;
+  AbortIfInvalidArgs(mutable_inference);
+  LoaderArgs mutable_loader = loader;
+  if (const char* err = mutable_loader.Validate()) {
+    mutable_loader.Help();
     fprintf(stderr, "Skipping model load because: %s\n", err);
   } else {
     fprintf(stderr, "Loading model...\n");
-    model_ = AllocateGemma(loader_, pools_);
+    model_ = AllocateGemma(mutable_loader, pools_);
 
     // Only allocate one for starters because GenerateBatch might not be called.
     kv_caches_.resize(1);
     kv_caches_[0] =
         KVCache::Create(model_->Info().model, inference.prefill_tbatch_size);
   }
-
-  InitGenerator(inference_args_, gen_);
+  InitGenerator(inference, gen_);
 
   runtime_config_ = {
-      .max_tokens = inference_args_.max_tokens,
-      .max_generated_tokens = inference_args_.max_generated_tokens,
-      .temperature = inference_args_.temperature,
-      .verbosity = app_.verbosity,
+      .max_tokens = inference.max_tokens,
+      .max_generated_tokens = inference.max_generated_tokens,
+      .temperature = inference.temperature,
+      .verbosity = app.verbosity,
       .gen = &gen_,
   };
 }
@@ -115,20 +110,30 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
     res += StringFromTokens(std::vector<int>{token});
     return true;
   };
-  if (app_.verbosity >= 2) {
-    std::cout << "Max tokens: " << inference_args_.max_tokens
+  if (runtime_config_.verbosity >= 2) {
+    std::cout << "Max tokens: " << runtime_config_.max_tokens
               << "\tmax generated tokens: "
-              << inference_args_.max_generated_tokens
-              << "\ttemperature: " << inference_args_.temperature << "\n";
+              << runtime_config_.max_generated_tokens
+              << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
-  gcpp::TimingInfo timing_info { .verbosity = app_.verbosity };
+  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
   runtime_config_.batch_stream_token = batch_stream_token;
   model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                    timing_info);
   return {res, total_tokens};
 }
 
-std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
+void GemmaEnv::QueryModel(
+    const std::vector<int>& tokens, const StreamFunc& stream_token) {
+  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+  const StreamFunc previous_stream_token = runtime_config_.stream_token;
+  runtime_config_.stream_token = stream_token;
+  model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
+                   timing_info);
+  runtime_config_.stream_token = previous_stream_token;
+}
+
+std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     const QueriesPromptTokens& queries_prompt) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
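Illustration, not part of this commit's diff: a minimal sketch of calling the new streaming overload, assuming a GemmaEnv named env whose model loaded successfully and assuming Tokenize() and StringFromTokens() are accessible GemmaEnv helpers. The callback receives each token plus its probability and returns true to keep generating.

  // Hypothetical caller code; env is a GemmaEnv with a loaded model.
  std::vector<int> tokens = env.Tokenize("Why is the sky blue?");
  std::string output;
  env.QueryModel(tokens, [&](int token, float /*prob*/) {
    output += env.StringFromTokens(std::vector<int>{token});  // detokenize this piece
    return true;  // returning false would stop generation early
  });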
@@ -144,12 +149,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
     res[query_index].second += 1;
     return true;
   };
-  if (app_.verbosity >= 2) {
+  if (runtime_config_.verbosity >= 2) {
     fprintf(stderr,
             "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
-            inference_args_.max_tokens, inference_args_.max_generated_tokens,
-            inference_args_.temperature, inference_args_.prefill_tbatch_size,
-            inference_args_.decode_qbatch_size);
+            runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
+            runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
+            runtime_config_.decode_qbatch_size);
   }
 
   // Ensure we have one KVCache per query.
@@ -159,13 +164,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
   for (size_t i = 1; i < num_queries; ++i) {
     if (kv_caches_[i].seq_len == 0) {
       kv_caches_[i] = KVCache::Create(model_->Info().model,
-                                      inference_args_.prefill_tbatch_size);
+                                      runtime_config_.prefill_tbatch_size);
     }
   }
 
-  gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
+  gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
   runtime_config_.batch_stream_token = batch_stream_token;
-  inference_args_.CopyTo(runtime_config_);
   std::vector<size_t> queries_pos(num_queries, 0);
   model_->GenerateBatch(runtime_config_, queries_prompt,
                         QueriesPos(queries_pos.data(), num_queries),
@@ -174,7 +178,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
 }
 
 std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt = WrapAndTokenize(model_->Tokenizer(), Info(),
+  const std::vector<int> prompt =
+      WrapAndTokenize(model_->Tokenizer(), model_->Info(),
                       /*pos=*/0, input);
   return QueryModel(prompt);
 }
@@ -194,7 +199,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
     prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
   QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
-  return BatchQueryModel2(prompt_span);
+  return BatchQueryModel(prompt_span);
 }
 
 float GemmaEnv::CrossEntropy(const std::string& input) {
@@ -41,10 +41,10 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);
 
-  size_t MaxTokens() const { return inference_args_.max_tokens; }
+  size_t MaxTokens() const { return runtime_config_.max_tokens; }
   // Sets the maximum number of output tokens to generate.
-  void SetMaxGeneratedTokens(size_t max_tokens) {
-    inference_args_.max_generated_tokens = max_tokens;
+  void SetMaxGeneratedTokens(size_t max_generated_tokens) {
+    runtime_config_.max_generated_tokens = max_generated_tokens;
   }
 
   std::vector<int> Tokenize(const std::string& input) const {
@@ -68,13 +68,17 @@ class GemmaEnv {
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
-  std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
+  std::vector<std::pair<std::string, size_t>> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt);
   // Adds turn structure to input, tokenizes and calls the above overload.
   std::pair<std::string, size_t> QueryModel(std::string& input);
   std::vector<std::pair<std::string, size_t>> BatchQueryModel(
       const std::vector<std::string>& inputs);
 
+  // Runs inference on the given input and calls the callback for each token.
+  void QueryModel(const std::vector<int>& tokens,
+                  const StreamFunc& stream_token);
+
   // Runs inference on the given input and returns the cross entropy, a measure
   // of how well the model predicts the correct output. It is the average
   // number of bits per token.
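Illustration, not part of this commit's diff: the string-based batch overload is unchanged here; next to the renamed span-based one it can be driven like this, assuming env is a GemmaEnv with a loaded model:

  // Hypothetical batch query over plain strings.
  std::vector<std::string> inputs = {"What is 2 + 2?", "Name a primary color."};
  for (const auto& [text, num_tokens] : env.BatchQueryModel(inputs)) {
    printf("%s (%zu tokens)\n", text.c_str(), num_tokens);
  }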
@@ -83,20 +87,12 @@
   // Returns nullptr if the model failed to load.
   Gemma* GetModel() const { return model_.get(); }
 
-  int Verbosity() const { return app_.verbosity; }
+  int Verbosity() const { return runtime_config_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
-  const ModelInfo& Info() const { return loader_.Info(); }
-  InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return kv_caches_[0]; }
 
  private:
-  // Arguments to the model loader: file locations, etc.
-  LoaderArgs loader_;
-  // Arguments to the inference function: max tokens, etc.
-  InferenceArgs inference_args_;
-  // Controls overall behavior of the app.
-  AppArgs app_;
   // Thread pool for running inference.
   PerClusterPools pools_;
   // Random number generator.
@@ -105,6 +101,7 @@ class GemmaEnv {
   std::unique_ptr<Gemma> model_;
   // KV caches, same number as query batch.
   std::vector<KVCache> kv_caches_;
+  // Runtime config for inference.
   RuntimeConfig runtime_config_;
 };
 
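Illustration, not part of this commit's diff: after this change, per-run settings that used to be reached through MutableInferenceArgs() are tuned through MutableConfig(), for example:

  // Hypothetical tuning of an existing GemmaEnv named env.
  env.SetMaxGeneratedTokens(256);              // writes runtime_config_.max_generated_tokens
  env.MutableConfig().temperature = 0.7f;      // sampling temperature
  env.MutableConfig().decode_qbatch_size = 4;  // queries per decode batch
  env.MutableConfig().use_spinning = false;    // new flag, defined later in this commit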
@@ -83,7 +83,7 @@ class GemmaTest : public ::testing::Test {
       prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
     }
     QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (auto [response, n] : s_env->BatchQueryModel2(prompts)) {
+    for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
      replies.push_back(response);
     }
   }
@@ -116,7 +116,7 @@ class GemmaTest : public ::testing::Test {
 };
 
 TEST_F(GemmaTest, GeographyBatched) {
-  s_env->MutableInferenceArgs().decode_qbatch_size = 3;
+  s_env->MutableConfig().decode_qbatch_size = 3;
   // 6 are enough to test batching and the loop.
   static const char* kQA[][2] = {
       {"What is the capital of Australia?", "Canberra"},
@@ -104,7 +104,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       "Do not include any justifications or explanations. Reply only with a "
       "letter.";
   const std::vector<int> prompt =
-      WrapAndTokenize(env.GetModel()->Tokenizer(), env.Info(),
+      WrapAndTokenize(env.GetModel()->Tokenizer(), env.GetModel()->Info(),
                       /*pos=*/0, prompt_string);
   const size_t prompt_size = prompt.size();
 
@@ -98,26 +98,26 @@ struct GenerateBatchT {
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
                      TimingInfo& timing_info) {
-  pools_.StartSpinning();
+  if (runtime_config.use_spinning) pools_.StartSpinning();
 
   CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
                                          runtime_config, prompt, pos, kv_cache,
                                          pools_, timing_info);
 
-  pools_.StopSpinning();
+  if (runtime_config.use_spinning) pools_.StopSpinning();
 }
 
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
                           const QueriesPromptTokens& queries_prompt,
                           const QueriesPos& queries_pos,
                           const KVCaches& kv_caches, TimingInfo& timing_info) {
-  pools_.StartSpinning();
+  if (runtime_config.use_spinning) pools_.StartSpinning();
 
   CallForModelAndWeight<GenerateBatchT>(
       info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt,
       queries_pos, kv_caches, pools_, timing_info);
 
-  pools_.StopSpinning();
+  if (runtime_config.use_spinning) pools_.StopSpinning();
 }
 
 template <typename TConfig>
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
+#include "util/allocator.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
@@ -74,7 +75,10 @@ using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
 using ActivationsObserverFunc =
     std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
 
+// RuntimeConfig holds configuration for a single generation run.
 struct RuntimeConfig {
+  // If not empty, batch_stream_token is called for each token in the batch,
+  // instead of stream_token.
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
     if (batch_stream_token) {
       return batch_stream_token(query_idx, pos, token, prob);
@@ -82,6 +86,7 @@ struct RuntimeConfig {
     return stream_token(token, prob);
   }
 
+  // Limits on the number of tokens generated.
   size_t max_tokens;
   size_t max_generated_tokens;
 
@@ -91,15 +96,24 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;
 
-  float temperature;
-  int verbosity;
-  std::mt19937* gen;
+  float temperature;  // Temperature for sampling.
+  int verbosity;      // Controls verbosity of printed messages.
+  std::mt19937* gen;  // Random number generator used for sampling.
+
+  // Functions operating on the generated tokens.
   StreamFunc stream_token;
   BatchStreamFunc batch_stream_token;
   AcceptFunc accept_token;  // if empty, accepts all tokens.
   SampleFunc sample_func;   // if empty, uses SampleTopK.
+
+  // Observer callbacks for intermediate data.
   LayersOutputFunc layers_output;  // if not empty, called after each layer.
-  ActivationsObserverFunc activations_observer;  // if set, called per-layer
+  ActivationsObserverFunc activations_observer;  // if set, called per-layer.
+
+  // Whether to use thread spinning to reduce barrier synchronization latency.
+  bool use_spinning = true;
+
+  // End-of-sequence token.
   int eos_id = EOS_ID;
 };
 
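Illustration, not part of this commit's diff: a sketch of filling a RuntimeConfig directly, including the new use_spinning flag. Assumptions: a loaded gcpp::Gemma, an already tokenized prompt, and a matching KVCache; the helper name is hypothetical.

  #include <random>
  #include <vector>
  #include "gemma/gemma.h"

  // Generate with thread spinning disabled, e.g. on a shared or oversubscribed machine.
  void GenerateWithoutSpinning(gcpp::Gemma& model, const std::vector<int>& tokens,
                               gcpp::KVCache& kv_cache) {
    std::mt19937 gen(42);
    gcpp::RuntimeConfig config = {
        .max_tokens = 2048,
        .max_generated_tokens = 256,
        .temperature = 1.0f,
        .verbosity = 0,
        .gen = &gen,
        .stream_token = [](int /*token*/, float /*prob*/) { return true; },
        .use_spinning = false,  // leave worker threads blocking instead of spinning
    };
    gcpp::TimingInfo timing = {.verbosity = 0};
    gcpp::PromptTokens prompt(tokens.data(), tokens.size());
    model.Generate(config, prompt, /*pos=*/0, kv_cache, timing);
  }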
@@ -53,6 +53,7 @@ static inline const char* CompiledConfig() {
 class AppArgs : public ArgsBase<AppArgs> {
  public:
   AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  AppArgs() { Init(); };
 
   int verbosity;
 
@@ -88,6 +89,13 @@ class AppArgs : public ArgsBase<AppArgs> {
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
+             const std::string& model) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+    model_type_str = model;
+  };
 
   // Returns error string or nullptr if OK.
   const char* Validate() {
@@ -168,6 +176,7 @@ static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); };
 
   size_t max_tokens;
   size_t max_generated_tokens;
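Illustration, not part of this commit's diff: with the new convenience constructors (the path-based LoaderArgs constructor plus the default AppArgs and InferenceArgs constructors above), the arg structs can be built programmatically instead of parsed from argc/argv. Paths and the model name below are placeholders.

  // Hypothetical setup without command-line parsing.
  gcpp::LoaderArgs loader("/path/to/tokenizer.spm", "/path/to/weights.sbs", "2b-it");
  gcpp::InferenceArgs inference;  // defaults via Init()
  gcpp::AppArgs app;              // defaults via Init()
  gcpp::GemmaEnv env(loader, inference, app);
  if (env.GetModel() == nullptr) {
    // model failed to load
  }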