mirror of https://github.com/google/gemma.cpp.git
Simplify threading: remove the use of inner_pool.
We only used inner_pool in the prefill FFW function, where the rows of the
matrix-vector multiplications already provide sufficient parallelism for a
single pool.
Benchmark results on a 1600-token summarization task:
```
Prefill speed

Num threads   BEFORE       AFTER
          4    9.24 t/s    9.76 t/s
         18   31.41 t/s   31.16 t/s
         32   31.41 t/s   45.13 t/s
         64   31.03 t/s   57.85 t/s
```
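The row-level parallelism the commit message refers to can be pictured with a small standalone sketch. This is not the gemma.cpp implementation; `MatVecRows`, the plain `float` buffers, and the loop shape are illustrative assumptions. The point is that each output row of a matrix-vector product is an independent dot product, so a single `hwy::ThreadPool` already has far more tasks than threads, and a nested inner pool only adds coordination overhead.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Illustrative only: distribute the rows of a matrix-vector product over the
// single remaining pool. With thousands of rows per layer, this alone keeps
// every core busy during prefill's FFW.
void MatVecRows(const std::vector<float>& mat, const std::vector<float>& vec,
                std::vector<float>& out, size_t rows, size_t cols,
                hwy::ThreadPool& pool) {
  pool.Run(0, rows, [&](uint64_t r, size_t /*thread*/) {
    float sum = 0.0f;
    for (size_t c = 0; c < cols; ++c) sum += mat[r * cols + c] * vec[c];
    out[r] = sum;
  });
}
```

The hunks below therefore just thread the one `pool` through to `FFW`/`MatVec` and delete `inner_pool` everywhere it was created or passed.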
commit 27117cc39f
parent 1d18c5a129
```diff
@@ -27,8 +27,8 @@ class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
 
 std::pair<std::string, int> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input, gcpp::LayersOutputT* layers_output) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
+    gcpp::LayersOutputT* layers_output) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
 
@@ -55,8 +55,7 @@ std::pair<std::string, int> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity,
-                layers_output);
+                stream_token, accept_token, gen, app.verbosity, layers_output);
   return {res, total_tokens};
 }
 
@@ -92,7 +91,6 @@ int main(int argc, char** argv) {
   gcpp::LayersOutputT* layers_output =
       log_layers_output ? &json_logger.layers_output_log_f : nullptr;
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -112,7 +110,7 @@ int main(int argc, char** argv) {
     return EXIT_FAILURE;
   }
   const auto [answer, token_count] = QueryModel(
-      model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
+      model, args, app, kv_cache, pool, prompt, layers_output);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
 
   if (log_layers_output) {
```
```diff
@@ -58,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
 
 std::pair<std::string, int> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
 
@@ -90,7 +89,7 @@ std::pair<std::string, int> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity);
+                stream_token, accept_token, gen, app.verbosity);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -131,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
 
 int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const std::string& golden_path) {
+                     hwy::ThreadPool& pool, const std::string& golden_path) {
   const std::vector<std::pair<std::string, std::string>> queries_answers =
       load_goldens(golden_path);
   int correct_answers = 0;
@@ -140,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   const double time_start = hwy::platform::Now();
   for (const auto& [question, expected_answer] : queries_answers) {
     const auto [answer, token_count] =
-        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
+        QueryModel(model, args, app, kv_cache, pool, question);
     total_tokens += token_count;
     if (answer.find(expected_answer) != std::string::npos) {
       correct_answers++;
@@ -164,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const gcpp::Path& text) {
+                     hwy::ThreadPool& pool, const gcpp::Path& text) {
   std::string prompt("Here is some text to summarize:\n");
   prompt.append(ReadFile(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
   const auto [answer, token_count] =
-      QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
+      QueryModel(model, args, app, kv_cache, pool, prompt);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   LogSpeedStats(time_start, token_count);
   return EXIT_SUCCESS;
@@ -179,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                           gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-                          hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                          const gcpp::Path& text, size_t batch_tokens) {
+                          hwy::ThreadPool& pool, const gcpp::Path& text,
+                          size_t batch_tokens) {
   std::string input = ReadFile(text);
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -197,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     auto kv_cache = CreateKVCache(model_type);
     float entropy =
         ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
-                            inner_pool, app.verbosity);
+                            app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
@@ -211,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
 
 int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                       gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                      hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                      const gcpp::Path& json_file, size_t max_questions) {
+                      hwy::ThreadPool& pool, const gcpp::Path& json_file,
+                      size_t max_questions) {
   std::ifstream trivia_file(json_file.path);
   if (!trivia_file) {
     std::cout << "Could not load file: " << json_file.path << "\n"
@@ -225,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     const auto [answer, token_count] = QueryModel(
-        model, args, app, kv_cache, inner_pool, pool, data["question"]);
+        model, args, app, kv_cache, pool, data["question"]);
     std::cout << answer << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
@@ -263,7 +260,6 @@ int main(int argc, char** argv) {
     HWY_ABORT("\nInvalid inference args: %s", error);
   }
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -280,17 +276,16 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
-    return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
-                            golden_path);
+    return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
-    return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkSummary(model, args, app, kv_cache, pool,
                             benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.path.empty()) {
     return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
-                                 inner_pool, pool, benchmark_args.cross_entropy,
+                                 pool, benchmark_args.cross_entropy,
                                  benchmark_args.batch_tokens);
   } else if (!benchmark_args.trivia_qa.path.empty()) {
-    return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
                              benchmark_args.trivia_qa,
                              benchmark_args.max_questions);
   }
```
gemma/gemma.cc (125 changed lines)
```diff
@@ -440,15 +440,13 @@ struct GemmaInterface {
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
                         size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                        const StreamFunc& stream_token,
+                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity, LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
                                     const std::vector<int>& prompt,
                                     KVCache& kv_cache, hwy::ThreadPool& pool,
-                                    hwy::ThreadPool& inner_pool,
                                     int verbosity) = 0;
 };
 
@@ -535,13 +533,12 @@ struct GemmaImpl : public GemmaInterface {
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
                 size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity,
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937&, int verbosity,
                 LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
-                            hwy::ThreadPool& inner_pool,
                             int verbosity) override;
 
   GemmaTokenizerImpl tokenizer;
@@ -880,8 +877,7 @@ template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
-                          KVCache& kv_cache, hwy::ThreadPool& pool,
-                          hwy::ThreadPool& inner_pool) {
+                          KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@@ -924,19 +920,17 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     }
 
     // TODO: sink the loop into these functions, i.e. make them MatMul.
-    pool.Run(
-        0, num_tokens,
-        [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
           AddFrom(activations.att_post2.data() + token_idx * kModelDim,
                   activations.x.data() + token_idx * kModelDim, kModelDim);
           RMSNorm(activations.x.data() + token_idx * kModelDim,
                   layer_weights->pre_ffw_norm_scale.data(),
                   activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
                   kModelDim);
-          FFW<kBatchSize>(activations, token_idx, layer_weights, inner_pool);
+          FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
           AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
                   activations.x.data() + token_idx * kModelDim, kModelDim);
-        });
+    }
   }  // foreach layer
 
   pool.Run(
```
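The Prefill hunk above is the heart of the change. Before, the per-token work ran in parallel on the outer pool while each `FFW` call received `inner_pool`, which every call site constructed with zero worker threads; after, the token loop is serial and `FFW` gets the full pool for its matrix-vector work. A schematic of where the parallelism now lives, using a hypothetical `FfwSketch` as a stand-in for the real FFW (nothing here is gemma.cpp code):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical stand-in for FFW: its heavy lifting is row-parallel
// matrix-vector work, so it can use as many threads as the pool offers.
void FfwSketch(std::vector<float>& x, size_t token_idx, size_t model_dim,
               hwy::ThreadPool& pool) {
  pool.Run(0, model_dim, [&](uint64_t row, size_t /*thread*/) {
    x[token_idx * model_dim + row] += 1.0f;  // placeholder per-row work
  });
}

// After this commit: serial over tokens, parallel over rows. Parallelism is
// now bounded by the thousands of rows in each matmul rather than by the
// prefill batch size, which is why throughput keeps scaling past 32 threads.
void PrefillLayerSketch(std::vector<float>& x, size_t num_tokens,
                        size_t model_dim, hwy::ThreadPool& pool) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    FfwSketch(x, token_idx, model_dim, pool);
  }
}
```

The remaining gemma/gemma.cc hunks below simply propagate the narrower signatures.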
```diff
@@ -950,8 +944,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 template <typename WeightArrayT, class TConfig>
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
-                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                 LayersOutputT* layers_output) {
+                 hwy::ThreadPool& pool, LayersOutputT* layers_output) {
   if (layers_output != nullptr) {
     float token_f = token;
     (*layers_output)(pos, "Tokens", &token_f, 1);
@@ -1033,8 +1026,7 @@ template <class TConfig>
 void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                  const StreamFunc& stream_token,
+                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
                   int verbosity, LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
@@ -1077,7 +1069,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
       const int* batch_tokens = prompt.data() + pos_offset;
       Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
-                                 prefill_activations, kv_cache, pool, inner_pool);
+                                 prefill_activations, kv_cache, pool);
       for (size_t idx = 0; idx < batch_size; ++idx) {
         if (!stream_token(batch_tokens[idx], 0.0f)) return;
       }
@@ -1105,7 +1097,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     const bool is_generating_phase = pos_offset >= prompt_size - 1;
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 layers_output);
     float* final_activation = activations.x.data();
     // The condition below is always true if we are doing Prefill above.
@@ -1171,8 +1163,7 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
 template <class TConfig>
 float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                               const std::vector<int>& prompt, KVCache& kv_cache,
-                              hwy::ThreadPool& pool,
-                              hwy::ThreadPool& inner_pool, int verbosity) {
+                              hwy::ThreadPool& pool, int verbosity) {
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
@@ -1196,7 +1187,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
              total_entropy / std::log(2.0) / (pos + 1));
     }
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 /*layers_output=*/nullptr);
     MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
                                   activations.x.data(),
@@ -1215,62 +1206,59 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
                        size_t max_generated_tokens, float temperature,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       hwy::ThreadPool& inner_pool,
                        const StreamFunc& stream_token,
                        const AcceptFunc& accept_token, std::mt19937& gen,
                        int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
                                    size_t max_tokens,
                                    const std::vector<int>& prompt,
                                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                                   hwy::ThreadPool& inner_pool, int verbosity) {
+                                   int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -1507,12 +1495,12 @@ template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
@@ -1520,50 +1508,49 @@ template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
-   layers_output);
+   kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output);
 }
 
 template <>
 void GemmaImpl<ConfigGriffin2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
 template <>
 float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -1607,13 +1594,13 @@ const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity, layers_output);
+                        start_pos, kv_cache, pool, stream_token, accept_token,
+                        gen, verbosity, layers_output);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
@@ -1621,10 +1608,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
                    const StreamFunc& stream_token, std::mt19937& gen) {
-  hwy::ThreadPool inner_pool(0);
   GenerateGemma(
       gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
+      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
      /*layers_output=*/nullptr);
 }
@@ -1637,11 +1623,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity) {
+                          hwy::ThreadPool& pool, int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   const float result = gemma.impl_->ComputeCrossEntropy(
-      max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      max_tokens, prompt, kv_cache, pool, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
   return result;
 }
```
```diff
@@ -104,13 +104,12 @@ using AcceptFunc = std::function<bool(int)>;
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output = nullptr);
 
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
-// - No ThreadPool within ThreadPool (inner_pool = dummy)
 // - All tokens accepted
 void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
@@ -122,8 +121,7 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity);
+                          hwy::ThreadPool& pool, int verbosity);
 
 constexpr int EOS_ID = 1;
 
```
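For callers of the public header, the only visible change is dropping the `inner_pool` argument. A hedged caller-side sketch of the new `GenerateGemma` signature follows; the include path, pool size, token limits, and the no-op stream/accept callbacks are illustrative assumptions rather than part of this diff.

```cpp
#include <random>
#include <vector>

#include "gemma/gemma.h"  // assumed include path for the public header
#include "hwy/contrib/thread_pool/thread_pool.h"

// Sketch of a caller after this change: one ThreadPool, no inner_pool.
void GenerateSketch(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
                    const std::vector<int>& prompt) {
  hwy::ThreadPool pool(/*num_threads=*/8);  // illustrative size
  std::mt19937 gen(/*seed=*/42);
  gcpp::GenerateGemma(
      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
      /*temperature=*/1.0f, prompt, /*start_pos=*/0, kv_cache, pool,
      /*stream_token=*/[](int /*token*/, float /*prob*/) { return true; },
      /*accept_token=*/[](int /*token*/) { return true; }, gen,
      /*verbosity=*/0);  // layers_output defaults to nullptr
}
```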
```diff
@@ -36,7 +36,6 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
         tokenizer("./tokenizer.spm"),
         pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        inner_pool(0),
         model_type(gcpp::Model::GEMMA_2B),
         model(tokenizer, weights, model_type, pool) {
     kv_cache = CreateKVCache(model_type);
@@ -60,8 +59,8 @@ class GemmaTest : public ::testing::Test {
     gcpp::GenerateGemma(
         model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
         /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        inner_pool, stream_token,
-        /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
+        stream_token, /*accept=*/[](int) { return true; }, gen,
+        /*verbosity=*/0);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;
@@ -71,8 +70,7 @@ class GemmaTest : public ::testing::Test {
     std::vector<int> prompt;
     HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
     return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
-                                     kv_cache, pool, inner_pool,
-                                     /*verbosity=*/0) /
+                                     kv_cache, pool, /*verbosity=*/0) /
            prompt_string.size();
   }
 
@@ -89,7 +87,6 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  hwy::ThreadPool inner_pool;
   gcpp::Model model_type = {};
   gcpp::Gemma model;
 };
```
gemma/run.cc (11 changed lines)
```diff
@@ -94,9 +94,8 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
 
 void ReplGemma(gcpp::Gemma& model, ModelTraining training,
                gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;   // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -209,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 
     const double time_start = hwy::platform::Now();
     GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  args.temperature, prompt, abs_pos, kv_cache, pool,
                   stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
@@ -229,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -271,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
-      app.verbosity,
+      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
      /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
```