mirror of https://github.com/google/gemma.cpp.git
Pass most runtime parameters using const RuntimeConfig&
PiperOrigin-RevId: 633572507
This commit is contained in:
parent
f1eab987d8
commit
eb0b96e0a8
|
|
@ -61,10 +61,17 @@ std::pair<std::string, int> QueryModel(
|
|||
<< args.temperature;
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
stream_token, accept_token, gen, app.verbosity, timing_info,
|
||||
layers_output);
|
||||
gcpp::RuntimeConfig runtime_config = {
|
||||
.max_tokens = args.max_tokens,
|
||||
.max_generated_tokens = args.max_generated_tokens,
|
||||
.temperature = args.temperature,
|
||||
.verbosity = app.verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility> // std::pair
|
||||
|
|
@ -88,9 +89,17 @@ std::pair<std::string, int> QueryModel(
|
|||
<< args.temperature;
|
||||
}
|
||||
gcpp::TimingInfo timing_info;
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
||||
stream_token, accept_token, gen, app.verbosity, timing_info);
|
||||
gcpp::RuntimeConfig runtime_config = {
|
||||
.max_tokens = args.max_tokens,
|
||||
.max_generated_tokens = args.max_generated_tokens,
|
||||
.temperature = args.temperature,
|
||||
.verbosity = app.verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
timing_info, /*layers_output=*/nullptr);
|
||||
if (app.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
|
|
|
|||
161
gemma/gemma.cc
161
gemma/gemma.cc
|
|
@ -453,12 +453,10 @@ struct GemmaInterface {
|
|||
|
||||
virtual const GemmaTokenizer* Tokenizer() const = 0;
|
||||
|
||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
virtual void Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) = 0;
|
||||
|
||||
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||
|
|
@ -547,12 +545,10 @@ struct GemmaImpl : public GemmaInterface {
|
|||
|
||||
const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
|
||||
|
||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937&, int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) override;
|
||||
void Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) override;
|
||||
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
|
|
@ -1083,12 +1079,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
|||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
hwy::ThreadPool& pool, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||
|
|
@ -1099,6 +1093,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
|
||||
|
||||
size_t prompt_size = prompt.size();
|
||||
size_t max_tokens = runtime_config.max_tokens;
|
||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
||||
if (pos >= max_tokens) {
|
||||
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
||||
|
|
@ -1132,13 +1128,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
||||
prefill_activations, kv_cache, pool);
|
||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||
if (!stream_token(batch_tokens[idx], 0.0f)) return;
|
||||
if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
|
||||
}
|
||||
pos += batch_size;
|
||||
pos_offset += batch_size;
|
||||
}
|
||||
|
||||
if (verbosity >= 2) {
|
||||
if (runtime_config.verbosity >= 2) {
|
||||
const double prefill_end = hwy::platform::Now();
|
||||
timing_info.prefill_tok_sec =
|
||||
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||
|
|
@ -1150,7 +1146,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
|
||||
size_t pos_gen_start = pos_offset;
|
||||
int token = prompt.at(pos_offset);
|
||||
stream_token(token, 0);
|
||||
runtime_config.stream_token(token, 0);
|
||||
for (size_t generate_pos = 0;
|
||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||
++pos, ++pos_offset, ++generate_pos) {
|
||||
|
|
@ -1169,21 +1165,22 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
activations.even_odd.data(), activations.logits.data(), pool);
|
||||
// Barrier: must have all logits so we can subtract max.
|
||||
Softmax(activations.logits.data(), kVocabSize);
|
||||
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
|
||||
gen, temperature, accept_token);
|
||||
if (!stream_token(token, activations.logits[token])) {
|
||||
token = SampleTopK<TConfig::kTopK>(
|
||||
activations.logits.data(), kVocabSize, *runtime_config.gen,
|
||||
runtime_config.temperature, runtime_config.accept_token);
|
||||
if (!runtime_config.stream_token(token, activations.logits[token])) {
|
||||
token = EOS_ID;
|
||||
}
|
||||
} else {
|
||||
// We would take this branch if we were not doing Prefill but would
|
||||
// process the tokens of the prompt one at a time.
|
||||
token = prompt.at(pos_offset + 1);
|
||||
if (!stream_token(token, 0)) {
|
||||
if (!runtime_config.stream_token(token, 0)) {
|
||||
token = EOS_ID;
|
||||
}
|
||||
}
|
||||
if (token == EOS_ID) {
|
||||
if (verbosity >= 2) {
|
||||
if (runtime_config.verbosity >= 2) {
|
||||
const double gen_end = hwy::platform::Now();
|
||||
timing_info.gen_tok_sec =
|
||||
static_cast<double>(pos_offset - pos_gen_start) /
|
||||
|
|
@ -1259,41 +1256,31 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
|
||||
#undef TOKEN
|
||||
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, timing_info, layers_output);
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, timing_info, layers_output);
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
}
|
||||
|
||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
||||
verbosity, timing_info, layers_output);
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
}
|
||||
|
||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
|
|
@ -1553,41 +1540,38 @@ GemmaImpl<Config>::GemmaImpl(
|
|||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGriffin2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||
void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
||||
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
|
|
@ -1653,30 +1637,15 @@ Gemma::~Gemma() = default; // after GemmaInterface is defined
|
|||
|
||||
const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
|
||||
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, kv_cache, pool, stream_token, accept_token,
|
||||
gen, verbosity, timing_info, layers_output);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||
void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen,
|
||||
TimingInfo& timing_info) {
|
||||
GenerateGemma(
|
||||
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
|
||||
runtime_config.temperature, prompt, start_pos, kv_cache, pool,
|
||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
|
||||
timing_info, /*layers_output=*/nullptr);
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(runtime_config, prompt, start_pos, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
|
|
|
|||
|
|
@ -60,11 +60,19 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
|||
const char* ParseModelTypeAndTraining(const std::string& model_flag,
|
||||
Model& model, ModelTraining& training);
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
using AcceptFunc = std::function<bool(int)>;
|
||||
|
||||
struct RuntimeConfig {
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
float temperature;
|
||||
int verbosity;
|
||||
std::mt19937* gen;
|
||||
const StreamFunc& stream_token;
|
||||
const AcceptFunc& accept_token;
|
||||
};
|
||||
|
||||
struct GemmaInterface;
|
||||
|
|
@ -97,29 +105,14 @@ KVCache CreateKVCache(Model type); // convenient workaround for now
|
|||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||
size_t conv1d_cache_size, size_t rglru_cache_size);
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
using AcceptFunc = std::function<bool(int)>;
|
||||
|
||||
// Bundle runtime parameters as RuntimeConfig
|
||||
// layers_output is optional; if set - it will be called with the activations
|
||||
// output after applying each layer.
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output = nullptr);
|
||||
|
||||
// Convenience function for the common case:
|
||||
// - Bundle runtime parameters as RuntimeConfig
|
||||
// - All tokens accepted
|
||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||
void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen,
|
||||
int verbosity, TimingInfo& timing_info);
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output = nullptr);
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights, hwy::ThreadPool& pool);
|
||||
|
|
|
|||
|
|
@ -56,11 +56,18 @@ class GemmaTest : public ::testing::Test {
|
|||
response.push_back(token);
|
||||
return true;
|
||||
};
|
||||
gcpp::GenerateGemma(
|
||||
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
||||
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
stream_token, /*accept=*/[](int) { return true; }, gen,
|
||||
/*verbosity=*/0);
|
||||
gcpp::RuntimeConfig runtime_config = {
|
||||
.max_tokens = 3072,
|
||||
.max_generated_tokens = 2048,
|
||||
.temperature = 1.0,
|
||||
.verbosity = 0,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = [](int) { return true; },
|
||||
};
|
||||
gcpp::TimingInfo timing_info;
|
||||
gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
|
||||
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
|
||||
std::string response_text;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||
return response_text;
|
||||
|
|
|
|||
14
gemma/run.cc
14
gemma/run.cc
|
|
@ -208,9 +208,17 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
}
|
||||
|
||||
TimingInfo timing_info;
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, abs_pos, kv_cache, pool,
|
||||
stream_token, accept_token, gen, verbosity, timing_info);
|
||||
gcpp::RuntimeConfig runtime_config = {
|
||||
.max_tokens = args.max_tokens,
|
||||
.max_generated_tokens = args.max_generated_tokens,
|
||||
.temperature = args.temperature,
|
||||
.verbosity = verbosity,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
|
||||
timing_info);
|
||||
if (verbosity >= 2) {
|
||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||
<< "\n"
|
||||
|
|
|
|||
Loading…
Reference in New Issue