Simplify threading: remove the use of inner_pool.

We only used inner_pool in the prefill FFW function, where the rows of the
matrix-vector multiplications already provide enough parallelism to keep a
single pool busy.
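
To illustrate the idea (a minimal sketch, not the actual gemma.cpp code;
MatVecRows and its dimensions are hypothetical): each matrix-vector product
exposes one independent task per output row, so a single pool stays saturated
without nesting a second pool inside it.

```cpp
#include <cstddef>
#include <cstdint>

#include "hwy/contrib/thread_pool/thread_pool.h"

// y = W * x with W stored row-major (kRows x kCols). Each output row is an
// independent dot product, so one pool.Run over rows is enough parallelism;
// an inner pool per row would only add scheduling overhead.
template <size_t kRows, size_t kCols>
void MatVecRows(const float* w, const float* x, float* y,
                hwy::ThreadPool& pool) {
  pool.Run(0, kRows, [&](uint64_t row, size_t /*thread*/) {
    float dot = 0.0f;
    for (size_t col = 0; col < kCols; ++col) {
      dot += w[row * kCols + col] * x[col];
    }
    y[row] = dot;
  });
}
```

Prefill previously parallelized over tokens with the outer pool and over rows
with inner_pool; it now walks the tokens serially and gives the whole pool to
the row-parallel FFW (see the Prefill hunk below).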

Benchmark results on a 1600-token summarization task:

```
              Prefill speed
Num threads   BEFORE      AFTER
 4             9.24 t/s    9.76 t/s
18            31.41 t/s   31.16 t/s
32            31.41 t/s   45.13 t/s
64            31.03 t/s   57.85 t/s
```

Up to 18 threads the two versions are on par; beyond that, BEFORE plateaus at
roughly 31 t/s while AFTER keeps scaling, reaching 57.85 t/s at 64 threads.

Author: Zoltan Szabadka
Date:   2024-04-29 16:07:30 +00:00
Parent: 1d18c5a129
Commit: 27117cc39f

6 changed files with 92 additions and 122 deletions

```diff
@@ -27,8 +27,8 @@ class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
 std::pair<std::string, int> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input, gcpp::LayersOutputT* layers_output) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
+    gcpp::LayersOutputT* layers_output) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -55,8 +55,7 @@ std::pair<std::string, int> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity,
-                layers_output);
+                stream_token, accept_token, gen, app.verbosity, layers_output);
   return {res, total_tokens};
 }
@@ -92,7 +91,6 @@ int main(int argc, char** argv) {
   gcpp::LayersOutputT* layers_output =
       log_layers_output ? &json_logger.layers_output_log_f : nullptr;
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -112,7 +110,7 @@ int main(int argc, char** argv) {
     return EXIT_FAILURE;
   }
   const auto [answer, token_count] = QueryModel(
-      model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
+      model, args, app, kv_cache, pool, prompt, layers_output);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   if (log_layers_output) {
```

```diff
@@ -58,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
 std::pair<std::string, int> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -90,7 +89,7 @@ std::pair<std::string, int> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity);
+                stream_token, accept_token, gen, app.verbosity);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -131,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
 
 int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const std::string& golden_path) {
+                     hwy::ThreadPool& pool, const std::string& golden_path) {
   const std::vector<std::pair<std::string, std::string>> queries_answers =
       load_goldens(golden_path);
   int correct_answers = 0;
@@ -140,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   const double time_start = hwy::platform::Now();
   for (const auto& [question, expected_answer] : queries_answers) {
     const auto [answer, token_count] =
-        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
+        QueryModel(model, args, app, kv_cache, pool, question);
     total_tokens += token_count;
     if (answer.find(expected_answer) != std::string::npos) {
       correct_answers++;
@@ -164,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const gcpp::Path& text) {
+                     hwy::ThreadPool& pool, const gcpp::Path& text) {
   std::string prompt("Here is some text to summarize:\n");
   prompt.append(ReadFile(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
   const auto [answer, token_count] =
-      QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
+      QueryModel(model, args, app, kv_cache, pool, prompt);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   LogSpeedStats(time_start, token_count);
   return EXIT_SUCCESS;
@@ -179,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                           gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-                          hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                          const gcpp::Path& text, size_t batch_tokens) {
+                          hwy::ThreadPool& pool, const gcpp::Path& text,
+                          size_t batch_tokens) {
   std::string input = ReadFile(text);
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -197,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     auto kv_cache = CreateKVCache(model_type);
     float entropy =
         ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
-                            inner_pool, app.verbosity);
+                            app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
@@ -211,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
 
 int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                       gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                      hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                      const gcpp::Path& json_file, size_t max_questions) {
+                      hwy::ThreadPool& pool, const gcpp::Path& json_file,
+                      size_t max_questions) {
   std::ifstream trivia_file(json_file.path);
   if (!trivia_file) {
     std::cout << "Could not load file: " << json_file.path << "\n"
@@ -225,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     const auto [answer, token_count] = QueryModel(
-        model, args, app, kv_cache, inner_pool, pool, data["question"]);
+        model, args, app, kv_cache, pool, data["question"]);
     std::cout << answer << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
@@ -263,7 +260,6 @@ int main(int argc, char** argv) {
     HWY_ABORT("\nInvalid inference args: %s", error);
   }
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -280,17 +276,16 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
        benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
-    return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
-                            golden_path);
+    return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
-    return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkSummary(model, args, app, kv_cache, pool,
                             benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.path.empty()) {
     return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
-                                 inner_pool, pool, benchmark_args.cross_entropy,
+                                 pool, benchmark_args.cross_entropy,
                                  benchmark_args.batch_tokens);
   } else if (!benchmark_args.trivia_qa.path.empty()) {
-    return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
                              benchmark_args.trivia_qa,
                              benchmark_args.max_questions);
   }
```

```diff
@@ -440,15 +440,13 @@ struct GemmaInterface {
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
                         size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                        const StreamFunc& stream_token,
+                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity, LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
                                     const std::vector<int>& prompt,
                                     KVCache& kv_cache, hwy::ThreadPool& pool,
-                                    hwy::ThreadPool& inner_pool,
                                     int verbosity) = 0;
 };
 
@@ -535,13 +533,12 @@ struct GemmaImpl : public GemmaInterface {
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
                 size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity,
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937&, int verbosity,
                 LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
-                            hwy::ThreadPool& inner_pool,
                             int verbosity) override;
 
   GemmaTokenizerImpl tokenizer;
@@ -880,8 +877,7 @@ template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
-                          KVCache& kv_cache, hwy::ThreadPool& pool,
-                          hwy::ThreadPool& inner_pool) {
+                          KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@@ -924,19 +920,17 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     }
 
     // TODO: sink the loop into these functions, i.e. make them MatMul.
-    pool.Run(
-        0, num_tokens,
-        [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
-          AddFrom(activations.att_post2.data() + token_idx * kModelDim,
-                  activations.x.data() + token_idx * kModelDim, kModelDim);
-          RMSNorm(activations.x.data() + token_idx * kModelDim,
-                  layer_weights->pre_ffw_norm_scale.data(),
-                  activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
-                  kModelDim);
-          FFW<kBatchSize>(activations, token_idx, layer_weights, inner_pool);
-          AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
-                  activations.x.data() + token_idx * kModelDim, kModelDim);
-        });
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+      AddFrom(activations.att_post2.data() + token_idx * kModelDim,
+              activations.x.data() + token_idx * kModelDim, kModelDim);
+      RMSNorm(activations.x.data() + token_idx * kModelDim,
+              layer_weights->pre_ffw_norm_scale.data(),
+              activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
+              kModelDim);
+      FFW<kBatchSize>(activations, token_idx, layer_weights, pool);
+      AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
+              activations.x.data() + token_idx * kModelDim, kModelDim);
+    }
   }  // foreach layer
 
   pool.Run(
@@ -950,8 +944,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 template <typename WeightArrayT, class TConfig>
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
-                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                 LayersOutputT* layers_output) {
+                 hwy::ThreadPool& pool, LayersOutputT* layers_output) {
   if (layers_output != nullptr) {
     float token_f = token;
     (*layers_output)(pos, "Tokens", &token_f, 1);
@@ -1033,8 +1026,7 @@ template <class TConfig>
 void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                  const StreamFunc& stream_token,
+                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
                   int verbosity, LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
@@ -1077,7 +1069,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
       const int* batch_tokens = prompt.data() + pos_offset;
       Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
-                                 prefill_activations, kv_cache, pool, inner_pool);
+                                 prefill_activations, kv_cache, pool);
       for (size_t idx = 0; idx < batch_size; ++idx) {
         if (!stream_token(batch_tokens[idx], 0.0f)) return;
       }
@@ -1105,7 +1097,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     const bool is_generating_phase = pos_offset >= prompt_size - 1;
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 layers_output);
     float* final_activation = activations.x.data();
     // The condition below is always true if we are doing Prefill above.
@@ -1171,8 +1163,7 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
 template <class TConfig>
 float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                               const std::vector<int>& prompt, KVCache& kv_cache,
-                              hwy::ThreadPool& pool,
-                              hwy::ThreadPool& inner_pool, int verbosity) {
+                              hwy::ThreadPool& pool, int verbosity) {
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
@@ -1196,7 +1187,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
              total_entropy / std::log(2.0) / (pos + 1));
     }
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 /*layers_output=*/nullptr);
     MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
                                   activations.x.data(),
@@ -1215,62 +1206,59 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
                        size_t max_generated_tokens, float temperature,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       hwy::ThreadPool& inner_pool,
                        const StreamFunc& stream_token,
                        const AcceptFunc& accept_token, std::mt19937& gen,
                        int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
                                    size_t max_tokens,
                                    const std::vector<int>& prompt,
                                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                                   hwy::ThreadPool& inner_pool, int verbosity) {
+                                   int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -1507,12 +1495,12 @@ template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
@@ -1520,50 +1508,49 @@ template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
-   layers_output);
+   kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output);
 }
 
 template <>
 void GemmaImpl<ConfigGriffin2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
 template <>
 float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
  return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -1607,13 +1594,13 @@ const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity, layers_output);
+                        start_pos, kv_cache, pool, stream_token, accept_token,
+                        gen, verbosity, layers_output);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
@@ -1621,10 +1608,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
                    const StreamFunc& stream_token, std::mt19937& gen) {
-  hwy::ThreadPool inner_pool(0);
   GenerateGemma(
       gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
+      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
       /*layers_output=*/nullptr);
 }
@@ -1637,11 +1623,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity) {
+                          hwy::ThreadPool& pool, int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   const float result = gemma.impl_->ComputeCrossEntropy(
-      max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      max_tokens, prompt, kv_cache, pool, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
   return result;
 }
```

```diff
@@ -104,13 +104,12 @@ using AcceptFunc = std::function<bool(int)>;
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output = nullptr);
 
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
-// - No ThreadPool within ThreadPool (inner_pool = dummy)
 // - All tokens accepted
 void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
@@ -122,8 +121,7 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity);
+                          hwy::ThreadPool& pool, int verbosity);
 
 constexpr int EOS_ID = 1;
```

```diff
@@ -36,7 +36,6 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
        tokenizer("./tokenizer.spm"),
         pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        inner_pool(0),
         model_type(gcpp::Model::GEMMA_2B),
         model(tokenizer, weights, model_type, pool) {
     kv_cache = CreateKVCache(model_type);
@@ -60,8 +59,8 @@ class GemmaTest : public ::testing::Test {
     gcpp::GenerateGemma(
         model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
         /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        inner_pool, stream_token,
-        /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
+        stream_token, /*accept=*/[](int) { return true; }, gen,
+        /*verbosity=*/0);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;
@@ -71,8 +70,7 @@ class GemmaTest : public ::testing::Test {
     std::vector<int> prompt;
     HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
     return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
-                                     kv_cache, pool, inner_pool,
-                                     /*verbosity=*/0) /
+                                     kv_cache, pool, /*verbosity=*/0) /
            prompt_string.size();
   }
 
@@ -89,7 +87,6 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  hwy::ThreadPool inner_pool;
   gcpp::Model model_type = {};
   gcpp::Gemma model;
 };
```

```diff
@@ -94,9 +94,8 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
 
 void ReplGemma(gcpp::Gemma& model, ModelTraining training,
                gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;   // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -209,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 
     const double time_start = hwy::platform::Now();
     GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  args.temperature, prompt, abs_pos, kv_cache, pool,
                   stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
@@ -229,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -271,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
-      app.verbosity,
+      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
```