mirror of https://github.com/google/gemma.cpp.git
Pass most runtime parameters using const RuntimeConfig&
PiperOrigin-RevId: 633572507
This commit is contained in:
parent
f1eab987d8
commit
eb0b96e0a8
|
|
@ -61,10 +61,17 @@ std::pair<std::string, int> QueryModel(
|
||||||
<< args.temperature;
|
<< args.temperature;
|
||||||
}
|
}
|
||||||
gcpp::TimingInfo timing_info;
|
gcpp::TimingInfo timing_info;
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
.max_tokens = args.max_tokens,
|
||||||
stream_token, accept_token, gen, app.verbosity, timing_info,
|
.max_generated_tokens = args.max_generated_tokens,
|
||||||
layers_output);
|
.temperature = args.temperature,
|
||||||
|
.verbosity = app.verbosity,
|
||||||
|
.gen = &gen,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token = accept_token,
|
||||||
|
};
|
||||||
|
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||||
|
timing_info, layers_output);
|
||||||
return {res, total_tokens};
|
return {res, total_tokens};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
|
#include <random>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility> // std::pair
|
#include <utility> // std::pair
|
||||||
|
|
@ -88,9 +89,17 @@ std::pair<std::string, int> QueryModel(
|
||||||
<< args.temperature;
|
<< args.temperature;
|
||||||
}
|
}
|
||||||
gcpp::TimingInfo timing_info;
|
gcpp::TimingInfo timing_info;
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
|
.max_tokens = args.max_tokens,
|
||||||
stream_token, accept_token, gen, app.verbosity, timing_info);
|
.max_generated_tokens = args.max_generated_tokens,
|
||||||
|
.temperature = args.temperature,
|
||||||
|
.verbosity = app.verbosity,
|
||||||
|
.gen = &gen,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token = accept_token,
|
||||||
|
};
|
||||||
|
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||||
|
timing_info, /*layers_output=*/nullptr);
|
||||||
if (app.verbosity >= 1) {
|
if (app.verbosity >= 1) {
|
||||||
LogSpeedStats(time_start, total_tokens);
|
LogSpeedStats(time_start, total_tokens);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
161
gemma/gemma.cc
161
gemma/gemma.cc
|
|
@ -453,12 +453,10 @@ struct GemmaInterface {
|
||||||
|
|
||||||
virtual const GemmaTokenizer* Tokenizer() const = 0;
|
virtual const GemmaTokenizer* Tokenizer() const = 0;
|
||||||
|
|
||||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
virtual void Generate(const RuntimeConfig& runtime_config,
|
||||||
float temperature, const std::vector<int>& prompt,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
size_t start_pos, KVCache& kv_cache,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
TimingInfo& timing_info,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
|
||||||
int verbosity, TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output) = 0;
|
LayersOutputT* layers_output) = 0;
|
||||||
|
|
||||||
virtual float ComputeCrossEntropy(size_t max_tokens,
|
virtual float ComputeCrossEntropy(size_t max_tokens,
|
||||||
|
|
@ -547,12 +545,10 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
|
|
||||||
const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
|
const GemmaTokenizer* Tokenizer() const override { return &tokenizer; }
|
||||||
|
|
||||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
void Generate(const RuntimeConfig& runtime_config,
|
||||||
float temperature, const std::vector<int>& prompt,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
TimingInfo& timing_info, LayersOutputT* layers_output) override;
|
||||||
std::mt19937&, int verbosity, TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output) override;
|
|
||||||
|
|
||||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
|
|
@ -1083,12 +1079,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
void GenerateImpl(GemmaImpl<TConfig>& gemma,
|
||||||
size_t max_generated_tokens, float temperature,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& pool, TimingInfo& timing_info,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
|
||||||
int verbosity, TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output) {
|
LayersOutputT* layers_output) {
|
||||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||||
Activations<TConfig, 1>& activations = *gemma.state.get();
|
Activations<TConfig, 1>& activations = *gemma.state.get();
|
||||||
|
|
@ -1099,6 +1093,8 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
|
*reinterpret_cast<WeightsT<TConfig>*>(gemma.weights_u8.get());
|
||||||
|
|
||||||
size_t prompt_size = prompt.size();
|
size_t prompt_size = prompt.size();
|
||||||
|
size_t max_tokens = runtime_config.max_tokens;
|
||||||
|
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
RangeChecks<TConfig>(max_tokens, max_generated_tokens, prompt_size);
|
||||||
if (pos >= max_tokens) {
|
if (pos >= max_tokens) {
|
||||||
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos,
|
||||||
|
|
@ -1132,13 +1128,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
||||||
prefill_activations, kv_cache, pool);
|
prefill_activations, kv_cache, pool);
|
||||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||||
if (!stream_token(batch_tokens[idx], 0.0f)) return;
|
if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
|
||||||
}
|
}
|
||||||
pos += batch_size;
|
pos += batch_size;
|
||||||
pos_offset += batch_size;
|
pos_offset += batch_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (verbosity >= 2) {
|
if (runtime_config.verbosity >= 2) {
|
||||||
const double prefill_end = hwy::platform::Now();
|
const double prefill_end = hwy::platform::Now();
|
||||||
timing_info.prefill_tok_sec =
|
timing_info.prefill_tok_sec =
|
||||||
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||||
|
|
@ -1150,7 +1146,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
|
|
||||||
size_t pos_gen_start = pos_offset;
|
size_t pos_gen_start = pos_offset;
|
||||||
int token = prompt.at(pos_offset);
|
int token = prompt.at(pos_offset);
|
||||||
stream_token(token, 0);
|
runtime_config.stream_token(token, 0);
|
||||||
for (size_t generate_pos = 0;
|
for (size_t generate_pos = 0;
|
||||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||||
++pos, ++pos_offset, ++generate_pos) {
|
++pos, ++pos_offset, ++generate_pos) {
|
||||||
|
|
@ -1169,21 +1165,22 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
activations.even_odd.data(), activations.logits.data(), pool);
|
activations.even_odd.data(), activations.logits.data(), pool);
|
||||||
// Barrier: must have all logits so we can subtract max.
|
// Barrier: must have all logits so we can subtract max.
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
|
token = SampleTopK<TConfig::kTopK>(
|
||||||
gen, temperature, accept_token);
|
activations.logits.data(), kVocabSize, *runtime_config.gen,
|
||||||
if (!stream_token(token, activations.logits[token])) {
|
runtime_config.temperature, runtime_config.accept_token);
|
||||||
|
if (!runtime_config.stream_token(token, activations.logits[token])) {
|
||||||
token = EOS_ID;
|
token = EOS_ID;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// We would take this branch if we were not doing Prefill but would
|
// We would take this branch if we were not doing Prefill but would
|
||||||
// process the tokens of the prompt one at a time.
|
// process the tokens of the prompt one at a time.
|
||||||
token = prompt.at(pos_offset + 1);
|
token = prompt.at(pos_offset + 1);
|
||||||
if (!stream_token(token, 0)) {
|
if (!runtime_config.stream_token(token, 0)) {
|
||||||
token = EOS_ID;
|
token = EOS_ID;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (token == EOS_ID) {
|
if (token == EOS_ID) {
|
||||||
if (verbosity >= 2) {
|
if (runtime_config.verbosity >= 2) {
|
||||||
const double gen_end = hwy::platform::Now();
|
const double gen_end = hwy::platform::Now();
|
||||||
timing_info.gen_tok_sec =
|
timing_info.gen_tok_sec =
|
||||||
static_cast<double>(pos_offset - pos_gen_start) /
|
static_cast<double>(pos_offset - pos_gen_start) /
|
||||||
|
|
@ -1259,41 +1256,31 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
|
|
||||||
#undef TOKEN
|
#undef TOKEN
|
||||||
|
|
||||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma,
|
||||||
size_t max_generated_tokens, float temperature,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||||
std::mt19937& gen, int verbosity, TimingInfo& timing_info,
|
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||||
LayersOutputT* layers_output) {
|
timing_info, layers_output);
|
||||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
|
||||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
|
||||||
verbosity, timing_info, layers_output);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma,
|
||||||
size_t max_generated_tokens, float temperature,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||||
std::mt19937& gen, int verbosity, TimingInfo& timing_info,
|
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||||
LayersOutputT* layers_output) {
|
timing_info, layers_output);
|
||||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
|
||||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
|
||||||
verbosity, timing_info, layers_output);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
|
void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
|
||||||
size_t max_generated_tokens, float temperature,
|
const RuntimeConfig& runtime_config,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token,
|
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
GenerateImpl(gemma, runtime_config, prompt, start_pos, kv_cache, pool,
|
||||||
int verbosity, TimingInfo& timing_info,
|
timing_info, layers_output);
|
||||||
LayersOutputT* layers_output) {
|
|
||||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
|
||||||
start_pos, kv_cache, pool, stream_token, accept_token, gen,
|
|
||||||
verbosity, timing_info, layers_output);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||||
|
|
@ -1553,41 +1540,38 @@ GemmaImpl<Config>::GemmaImpl(
|
||||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()) {}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
void GemmaImpl<ConfigGemma2B>::Generate(const RuntimeConfig& runtime_config,
|
||||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
const std::vector<int>& prompt,
|
||||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& pool,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
TimingInfo& timing_info,
|
||||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
LayersOutputT* layers_output) {
|
||||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
|
||||||
layers_output);
|
layers_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGemma7B>::Generate(
|
void GemmaImpl<ConfigGemma7B>::Generate(const RuntimeConfig& runtime_config,
|
||||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
const std::vector<int>& prompt,
|
||||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& pool,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
TimingInfo& timing_info,
|
||||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
LayersOutputT* layers_output) {
|
||||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
|
||||||
layers_output);
|
layers_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGriffin2B>::Generate(
|
void GemmaImpl<ConfigGriffin2B>::Generate(const RuntimeConfig& runtime_config,
|
||||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
const std::vector<int>& prompt,
|
||||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& pool,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
|
TimingInfo& timing_info,
|
||||||
TimingInfo& timing_info, LayersOutputT* layers_output) {
|
LayersOutputT* layers_output) {
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
|
||||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
(*this, runtime_config, prompt, start_pos, kv_cache, pool, timing_info,
|
||||||
kv_cache, pool, stream_token, accept_token, gen, verbosity, timing_info,
|
|
||||||
layers_output);
|
layers_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1653,30 +1637,15 @@ Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
|
const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
|
||||||
|
|
||||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
|
||||||
float temperature, const std::vector<int>& prompt,
|
|
||||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
const StreamFunc& stream_token,
|
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
|
||||||
int verbosity, TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output) {
|
|
||||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
|
||||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
|
||||||
start_pos, kv_cache, pool, stream_token, accept_token,
|
|
||||||
gen, verbosity, timing_info, layers_output);
|
|
||||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, std::mt19937& gen,
|
TimingInfo& timing_info,
|
||||||
TimingInfo& timing_info) {
|
LayersOutputT* layers_output) {
|
||||||
GenerateGemma(
|
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
|
gemma.impl_->Generate(runtime_config, prompt, start_pos, kv_cache, pool,
|
||||||
runtime_config.temperature, prompt, start_pos, kv_cache, pool,
|
timing_info, layers_output);
|
||||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
|
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
timing_info, /*layers_output=*/nullptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
|
|
|
||||||
|
|
@ -60,11 +60,19 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||||
const char* ParseModelTypeAndTraining(const std::string& model_flag,
|
const char* ParseModelTypeAndTraining(const std::string& model_flag,
|
||||||
Model& model, ModelTraining& training);
|
Model& model, ModelTraining& training);
|
||||||
|
|
||||||
|
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||||
|
// probability is 0.0f.
|
||||||
|
using StreamFunc = std::function<bool(int, float)>;
|
||||||
|
using AcceptFunc = std::function<bool(int)>;
|
||||||
|
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
size_t max_tokens;
|
size_t max_tokens;
|
||||||
size_t max_generated_tokens;
|
size_t max_generated_tokens;
|
||||||
float temperature;
|
float temperature;
|
||||||
int verbosity;
|
int verbosity;
|
||||||
|
std::mt19937* gen;
|
||||||
|
const StreamFunc& stream_token;
|
||||||
|
const AcceptFunc& accept_token;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct GemmaInterface;
|
struct GemmaInterface;
|
||||||
|
|
@ -97,29 +105,14 @@ KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||||
size_t conv1d_cache_size, size_t rglru_cache_size);
|
size_t conv1d_cache_size, size_t rglru_cache_size);
|
||||||
|
|
||||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
// Bundle runtime parameters as RuntimeConfig
|
||||||
// probability is 0.0f.
|
|
||||||
using StreamFunc = std::function<bool(int, float)>;
|
|
||||||
using AcceptFunc = std::function<bool(int)>;
|
|
||||||
|
|
||||||
// layers_output is optional; if set - it will be called with the activations
|
// layers_output is optional; if set - it will be called with the activations
|
||||||
// output after applying each layer.
|
// output after applying each layer.
|
||||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
|
||||||
float temperature, const std::vector<int>& prompt,
|
|
||||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
const StreamFunc& stream_token,
|
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
|
||||||
int verbosity, TimingInfo& timing_info,
|
|
||||||
LayersOutputT* layers_output = nullptr);
|
|
||||||
|
|
||||||
// Convenience function for the common case:
|
|
||||||
// - Bundle runtime parameters as RuntimeConfig
|
|
||||||
// - All tokens accepted
|
|
||||||
void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, std::mt19937& gen,
|
TimingInfo& timing_info,
|
||||||
int verbosity, TimingInfo& timing_info);
|
LayersOutputT* layers_output = nullptr);
|
||||||
|
|
||||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
const Path& compressed_weights, hwy::ThreadPool& pool);
|
const Path& compressed_weights, hwy::ThreadPool& pool);
|
||||||
|
|
|
||||||
|
|
@ -56,11 +56,18 @@ class GemmaTest : public ::testing::Test {
|
||||||
response.push_back(token);
|
response.push_back(token);
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
gcpp::GenerateGemma(
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
|
.max_tokens = 3072,
|
||||||
/*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
|
.max_generated_tokens = 2048,
|
||||||
stream_token, /*accept=*/[](int) { return true; }, gen,
|
.temperature = 1.0,
|
||||||
/*verbosity=*/0);
|
.verbosity = 0,
|
||||||
|
.gen = &gen,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token = [](int) { return true; },
|
||||||
|
};
|
||||||
|
gcpp::TimingInfo timing_info;
|
||||||
|
gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
|
||||||
|
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
|
||||||
std::string response_text;
|
std::string response_text;
|
||||||
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||||
return response_text;
|
return response_text;
|
||||||
|
|
|
||||||
14
gemma/run.cc
14
gemma/run.cc
|
|
@ -208,9 +208,17 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
||||||
}
|
}
|
||||||
|
|
||||||
TimingInfo timing_info;
|
TimingInfo timing_info;
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
args.temperature, prompt, abs_pos, kv_cache, pool,
|
.max_tokens = args.max_tokens,
|
||||||
stream_token, accept_token, gen, verbosity, timing_info);
|
.max_generated_tokens = args.max_generated_tokens,
|
||||||
|
.temperature = args.temperature,
|
||||||
|
.verbosity = verbosity,
|
||||||
|
.gen = &gen,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token = accept_token,
|
||||||
|
};
|
||||||
|
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
|
||||||
|
timing_info);
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||||
<< "\n"
|
<< "\n"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue