diff --git a/debug_prompt.cc b/debug_prompt.cc
index dd962af..84a709d 100644
--- a/debug_prompt.cc
+++ b/debug_prompt.cc
@@ -27,8 +27,8 @@ class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
 
 std::pair<std::string, size_t> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input, gcpp::LayersOutputT* layers_output) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
+    gcpp::LayersOutputT* layers_output) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -55,8 +55,7 @@ std::pair<std::string, size_t> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity,
-                layers_output);
+                stream_token, accept_token, gen, app.verbosity, layers_output);
   return {res, total_tokens};
 }
@@ -92,7 +91,6 @@ int main(int argc, char** argv) {
   gcpp::LayersOutputT* layers_output =
       log_layers_output ? &json_logger.layers_output_log_f : nullptr;
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -112,7 +110,7 @@ int main(int argc, char** argv) {
     return EXIT_FAILURE;
   }
   const auto [answer, token_count] = QueryModel(
-      model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
+      model, args, app, kv_cache, pool, prompt, layers_output);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
 
   if (log_layers_output) {
diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc
index b5fb375..fdcf1b2 100644
--- a/gemma/benchmark.cc
+++ b/gemma/benchmark.cc
@@ -58,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
 
 std::pair<std::string, size_t> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -90,7 +89,7 @@ std::pair<std::string, size_t> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity);
+                stream_token, accept_token, gen, app.verbosity);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -131,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
 
 int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const std::string& golden_path) {
+                     hwy::ThreadPool& pool, const std::string& golden_path) {
   const std::vector<std::pair<std::string, std::string>> queries_answers =
       load_goldens(golden_path);
   int correct_answers = 0;
@@ -140,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   const double time_start = hwy::platform::Now();
   for (const auto& [question, expected_answer] : queries_answers) {
     const auto [answer, token_count] =
-        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
+        QueryModel(model, args, app, kv_cache, pool, question);
     total_tokens += token_count;
     if (answer.find(expected_answer) != std::string::npos) {
       correct_answers++;
@@ -164,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                     gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const gcpp::Path& text) {
+                     hwy::ThreadPool& pool, const gcpp::Path& text) {
   std::string prompt("Here is some text to summarize:\n");
   prompt.append(ReadFile(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
   const auto [answer, token_count] =
-      QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
+      QueryModel(model, args, app, kv_cache, pool, prompt);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   LogSpeedStats(time_start, token_count);
   return EXIT_SUCCESS;
@@ -179,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                           gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-                          hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                          const gcpp::Path& text, size_t batch_tokens) {
+                          hwy::ThreadPool& pool, const gcpp::Path& text,
+                          size_t batch_tokens) {
   std::string input = ReadFile(text);
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -197,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     auto kv_cache = CreateKVCache(model_type);
     float entropy =
         ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
-                            inner_pool, app.verbosity);
+                            app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
@@ -211,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
 
 int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                       gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                      hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                      const gcpp::Path& json_file, size_t max_questions) {
+                      hwy::ThreadPool& pool, const gcpp::Path& json_file,
+                      size_t max_questions) {
   std::ifstream trivia_file(json_file.path);
   if (!trivia_file) {
     std::cout << "Could not load file: " << json_file.path << "\n"
@@ -225,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     const auto [answer, token_count] = QueryModel(
-        model, args, app, kv_cache, inner_pool, pool, data["question"]);
+        model, args, app, kv_cache, pool, data["question"]);
     std::cout << answer << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
@@ -263,7 +260,6 @@ int main(int argc, char** argv) {
     HWY_ABORT("\nInvalid inference args: %s", error);
   }
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -280,17 +276,16 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
-    return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
-                            golden_path);
+    return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
-    return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkSummary(model, args, app, kv_cache, pool,
                             benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.path.empty()) {
     return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
-                                 inner_pool, pool, benchmark_args.cross_entropy,
+                                 pool, benchmark_args.cross_entropy,
                                  benchmark_args.batch_tokens);
   } else if (!benchmark_args.trivia_qa.path.empty()) {
-    return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
                              benchmark_args.trivia_qa,
                              benchmark_args.max_questions);
   }
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 58078a9..434576d 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -440,15 +440,13 @@ struct GemmaInterface {
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
                         size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                        const StreamFunc& stream_token,
+                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity, LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
                                     const std::vector<int>& prompt,
                                     KVCache& kv_cache, hwy::ThreadPool& pool,
-                                    hwy::ThreadPool& inner_pool,
                                     int verbosity) = 0;
 };
@@ -535,13 +533,12 @@ struct GemmaImpl : public GemmaInterface {
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
                 size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity,
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937&, int verbosity,
                 LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
-                            hwy::ThreadPool& inner_pool,
                             int verbosity) override;
 
   GemmaTokenizerImpl tokenizer;
@@ -880,8 +877,7 @@ template <class TConfig, size_t kBatchSize, typename WeightArrayT>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
-                          KVCache& kv_cache, hwy::ThreadPool& pool,
-                          hwy::ThreadPool& inner_pool) {
+                          KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@@ -924,19 +920,17 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     }
 
     // TODO: sink the loop into these functions, i.e. make them MatMul.
-    pool.Run(
-        0, num_tokens,
-        [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
-          AddFrom(activations.att_post2.data() + token_idx * kModelDim,
-                  activations.x.data() + token_idx * kModelDim, kModelDim);
-          RMSNorm(activations.x.data() + token_idx * kModelDim,
-                  layer_weights->pre_ffw_norm_scale.data(),
-                  activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
-                  kModelDim);
-          FFW(activations, token_idx, layer_weights, inner_pool);
-          AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
-                  activations.x.data() + token_idx * kModelDim, kModelDim);
-        });
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+      AddFrom(activations.att_post2.data() + token_idx * kModelDim,
+              activations.x.data() + token_idx * kModelDim, kModelDim);
+      RMSNorm(activations.x.data() + token_idx * kModelDim,
+              layer_weights->pre_ffw_norm_scale.data(),
+              activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
+              kModelDim);
+      FFW(activations, token_idx, layer_weights, pool);
+      AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
+              activations.x.data() + token_idx * kModelDim, kModelDim);
+    }
   }  // foreach layer
 
   pool.Run(
@@ -950,8 +944,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 template <class TConfig, typename WeightArrayT>
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
-                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                 LayersOutputT* layers_output) {
+                 hwy::ThreadPool& pool, LayersOutputT* layers_output) {
   if (layers_output != nullptr) {
     float token_f = token;
     (*layers_output)(pos, "Tokens", &token_f, 1);
   }
@@ -1033,8 +1026,7 @@ template <class TConfig>
 void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                  const StreamFunc& stream_token,
+                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
                   int verbosity, LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
@@ -1077,7 +1069,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
       const int* batch_tokens = prompt.data() + pos_offset;
       Prefill(batch_tokens, batch_size, pos, weights,
-              prefill_activations, kv_cache, pool, inner_pool);
+              prefill_activations, kv_cache, pool);
       for (size_t idx = 0; idx < batch_size; ++idx) {
         if (!stream_token(batch_tokens[idx], 0.0f)) return;
       }
@@ -1105,7 +1097,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     const bool is_generating_phase = pos_offset >= prompt_size - 1;
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 layers_output);
     float* final_activation = activations.x.data();
     // The condition below is always true if we are doing Prefill above.
@@ -1171,8 +1163,7 @@ void LogTopK(GemmaImpl<TConfig>& gemma, float* logits, float* dist, size_t len,
 
 template <class TConfig>
 float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                               const std::vector<int>& prompt, KVCache& kv_cache,
-                              hwy::ThreadPool& pool,
-                              hwy::ThreadPool& inner_pool, int verbosity) {
+                              hwy::ThreadPool& pool, int verbosity) {
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations<TConfig, 1>& activations = *gemma.state.get();
@@ -1196,7 +1187,7 @@ float ComputeCrossEntropyImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
              total_entropy / std::log(2.0) / (pos + 1));
     }
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 /*layers_output=*/nullptr);
     MatVec<kVocabSize, kModelDim>(weights.embedder_input_embedding, 0,
                                   activations.x.data(),
@@ -1215,62 +1206,59 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
                        size_t max_generated_tokens, float temperature,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       hwy::ThreadPool& inner_pool,
                        const StreamFunc& stream_token,
                        const AcceptFunc& accept_token, std::mt19937& gen,
                        int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
                                    size_t max_tokens,
                                    const std::vector<int>& prompt,
                                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                                   hwy::ThreadPool& inner_pool, int verbosity) {
+                                   int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -1507,12 +1495,12 @@ template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
@@ -1520,50 +1508,49 @@ template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
-   layers_output);
+   kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output);
 }
 
 template <>
 void GemmaImpl<ConfigGriffin2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
 template <>
 float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -1607,13 +1594,13 @@ const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
 
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity, layers_output);
+                        start_pos, kv_cache, pool, stream_token, accept_token,
+                        gen, verbosity, layers_output);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
@@ -1621,10 +1608,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
                    const StreamFunc& stream_token, std::mt19937& gen) {
-  hwy::ThreadPool inner_pool(0);
   GenerateGemma(
       gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
+      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
      /*layers_output=*/nullptr);
 }
 
@@ -1637,11 +1623,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity) {
+                          hwy::ThreadPool& pool, int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   const float result = gemma.impl_->ComputeCrossEntropy(
-      max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      max_tokens, prompt, kv_cache, pool, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
   return result;
 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 60d7843..e1689d9 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -104,13 +104,12 @@ using AcceptFunc = std::function<bool(int)>;
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output = nullptr);
 
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
-// - No ThreadPool within ThreadPool (inner_pool = dummy)
 // - All tokens accepted
 void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
@@ -122,8 +121,7 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity);
+                          hwy::ThreadPool& pool, int verbosity);
 
 constexpr int EOS_ID = 1;
diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc
index a842a9c..258352b 100644
--- a/gemma/gemma_test.cc
+++ b/gemma/gemma_test.cc
@@ -36,7 +36,6 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
        tokenizer("./tokenizer.spm"),
        pool(std::min<size_t>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        inner_pool(0),
        model_type(gcpp::Model::GEMMA_2B),
        model(tokenizer, weights, model_type, pool) {
     kv_cache = CreateKVCache(model_type);
@@ -60,8 +59,8 @@ class GemmaTest : public ::testing::Test {
     gcpp::GenerateGemma(
         model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
         /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        inner_pool, stream_token,
-        /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
+        stream_token, /*accept=*/[](int) { return true; }, gen,
+        /*verbosity=*/0);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;
@@ -71,8 +70,7 @@ class GemmaTest : public ::testing::Test {
     std::vector<int> prompt;
     HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
     return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
-                                     kv_cache, pool, inner_pool,
-                                     /*verbosity=*/0) /
+                                     kv_cache, pool, /*verbosity=*/0) /
            prompt_string.size();
   }
@@ -89,7 +87,6 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  hwy::ThreadPool inner_pool;
   gcpp::Model model_type = {};
   gcpp::Gemma model;
 };
diff --git a/gemma/run.cc b/gemma/run.cc
index aa5f9bb..942bff8 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -94,9 +94,8 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
 
 void ReplGemma(gcpp::Gemma& model, ModelTraining training,
                gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;      // absolute token index over all turns
   int current_pos = 0;     // token index within the current turn
@@ -209,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 
     const double time_start = hwy::platform::Now();
     GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  args.temperature, prompt, abs_pos, kv_cache, pool,
                   stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
@@ -229,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -271,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
-      app.verbosity,
+      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
      /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
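
The one hunk that changes scheduling rather than just signatures is in Prefill
(gemma/gemma.cc): the outer pool.Run over prefill tokens becomes a plain serial
loop, and FFW receives the main pool instead of inner_pool. Note that every
caller constructed the inner pool as hwy::ThreadPool inner_pool(0), i.e. with
zero worker threads, so FFW previously ran inline on each outer worker; the
comment deleted from gemma.h ("No ThreadPool within ThreadPool") names the
constraint, since hwy::ThreadPool::Run cannot be nested on the same pool. After
the change, the token loop is serial and all threads cooperate inside each FFW
call. A toy reduction of the before/after pattern (Work() is a hypothetical
stand-in for FFW; this is a sketch, not gemma.cpp code):

  #include <cstdint>
  #include <vector>

  #include "hwy/contrib/thread_pool/thread_pool.h"

  // Stand-in for FFW: parallelizes within a single token on the given pool.
  void Work(std::vector<float>& x, size_t t, size_t dim,
            hwy::ThreadPool& pool) {
    pool.Run(0, dim,
             [&](uint64_t j, size_t /*thread*/) { x[t * dim + j] += 1.0f; });
  }

  void Demo() {
    constexpr size_t kTokens = 4, kDim = 16;
    std::vector<float> x(kTokens * kDim, 0.0f);
    hwy::ThreadPool pool(8);
    hwy::ThreadPool inner_pool(0);  // zero workers: Run() executes inline

    // Before: parallel over tokens; nested work must go to a *different*
    // pool because Run is not re-entrant, and with zero workers each
    // token's Work() ran single-threaded.
    pool.Run(0, kTokens, [&](uint64_t t, size_t /*thread*/) {
      Work(x, t, kDim, inner_pool);
    });

    // After: serial over tokens; the whole pool cooperates in each Work().
    for (size_t t = 0; t < kTokens; ++t) Work(x, t, kDim, pool);
  }

The trade is token-level parallelism for more workers inside each per-token
FFW, and it lets every caller drop the second pool it previously had to
construct just to satisfy the nesting rule.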
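Post-change usage (a minimal sketch, not part of the patch): callers now build
exactly one hwy::ThreadPool and thread it through. The snippet below is
assembled from the gemma/gemma.h signatures as they appear in this diff; the
file paths, thread count, token limits, and lambda bodies are illustrative
assumptions, not code from the repository.

  #include <random>
  #include <vector>

  #include "gemma/gemma.h"  // gcpp::Gemma, GenerateGemma, CreateKVCache
  #include "hwy/contrib/thread_pool/thread_pool.h"

  int RunOnce() {
    hwy::ThreadPool pool(/*num_threads=*/8);  // single pool; no inner_pool
    gcpp::Gemma model(gcpp::Path("tokenizer.spm"), gcpp::Path("2b-it-mqa.sbs"),
                      gcpp::Model::GEMMA_2B, pool);
    gcpp::KVCache kv_cache = gcpp::CreateKVCache(gcpp::Model::GEMMA_2B);

    std::vector<int> prompt;
    if (!model.Tokenizer()->Encode("Hello, world.", &prompt)) return 1;

    std::mt19937 gen(42);
    // StreamFunc receives (token, score) and returns false to stop early;
    // AcceptFunc filters candidate tokens, as in the call sites above.
    gcpp::GenerateGemma(
        model, /*max_tokens=*/128, /*max_generated_tokens=*/64,
        /*temperature=*/1.0f, prompt, /*start_pos=*/0, kv_cache, pool,
        /*stream_token=*/[](int /*token*/, float /*score*/) { return true; },
        /*accept_token=*/[](int) { return true; }, gen, /*verbosity=*/0);
    return 0;
  }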