diff --git a/BUILD.bazel b/BUILD.bazel
index bca3fd8..b3c1090 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -46,8 +46,10 @@ cc_test(
     deps = [
         ":ops",
         "@googletest//:gtest_main",
+        "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:hwy_test_util",
+        "@hwy//:thread_pool",
     ],
 )
diff --git a/debug_prompt.cc b/debug_prompt.cc
index dd962af..e844311 100644
--- a/debug_prompt.cc
+++ b/debug_prompt.cc
@@ -7,6 +7,7 @@
 #include "nlohmann/json.hpp"
 #include "util/app.h"
 #include "util/args.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 
 using json = nlohmann::json;
 
@@ -27,8 +28,8 @@ class PromptArgs : public gcpp::ArgsBase<PromptArgs> {
 std::pair<std::string, size_t> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input, gcpp::LayersOutputT* layers_output) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
+    gcpp::LayersOutputT* layers_output) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -55,8 +56,7 @@ std::pair<std::string, size_t> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity,
-                layers_output);
+                stream_token, accept_token, gen, app.verbosity, layers_output);
   return {res, total_tokens};
 }
@@ -92,7 +92,6 @@ int main(int argc, char** argv) {
   gcpp::LayersOutputT* layers_output =
       log_layers_output ? &json_logger.layers_output_log_f : nullptr;
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -112,7 +111,7 @@ int main(int argc, char** argv) {
     return EXIT_FAILURE;
   }
   const auto [answer, token_count] = QueryModel(
-      model, args, app, kv_cache, inner_pool, pool, prompt, layers_output);
+      model, args, app, kv_cache, pool, prompt, layers_output);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
 
   if (log_layers_output) {
diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc
index b5fb375..fdcf1b2 100644
--- a/gemma/benchmark.cc
+++ b/gemma/benchmark.cc
@@ -58,8 +58,7 @@ void LogSpeedStats(const double time_start, size_t total_tokens) {
 
 std::pair<std::string, size_t> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-    gcpp::KVCache& kv_cache, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-    const std::string& input) {
+    gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -90,7 +89,7 @@ std::pair<std::string, size_t> QueryModel(
   }
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, /*abs_pos=*/0, kv_cache, pool,
-                inner_pool, stream_token, accept_token, gen, app.verbosity);
+                stream_token, accept_token, gen, app.verbosity);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -131,8 +130,7 @@ std::string ReadFile(const gcpp::Path& path) {
 
 int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const std::string& golden_path) {
+                     hwy::ThreadPool& pool, const std::string& golden_path) {
   const std::vector<std::pair<std::string, std::string>> queries_answers =
       load_goldens(golden_path);
   int correct_answers = 0;
@@ -140,7 +138,7 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   const double time_start = hwy::platform::Now();
   for (const auto& [question, expected_answer] : queries_answers) {
     const auto [answer, token_count] =
-        QueryModel(model, args, app, kv_cache, inner_pool, pool, question);
+        QueryModel(model, args, app, kv_cache, pool, question);
     total_tokens += token_count;
     if (answer.find(expected_answer) != std::string::npos) {
       correct_answers++;
@@ -164,14 +162,13 @@ int BenchmarkGoldens(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                      gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                     hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                     const gcpp::Path& text) {
+                     hwy::ThreadPool& pool, const gcpp::Path& text) {
   std::string prompt("Here is some text to summarize:\n");
   prompt.append(ReadFile(text));
   prompt.append("\nSummarize this text.\n");
   const double time_start = hwy::platform::Now();
   const auto [answer, token_count] =
-      QueryModel(model, args, app, kv_cache, inner_pool, pool, prompt);
+      QueryModel(model, args, app, kv_cache, pool, prompt);
   std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
   LogSpeedStats(time_start, token_count);
   return EXIT_SUCCESS;
@@ -179,8 +176,8 @@ int BenchmarkSummary(gcpp::Gemma& model, gcpp::InferenceArgs& args,
 
 int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                           gcpp::InferenceArgs& args, gcpp::AppArgs& app,
-                          hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                          const gcpp::Path& text, size_t batch_tokens) {
+                          hwy::ThreadPool& pool, const gcpp::Path& text,
+                          size_t batch_tokens) {
   std::string input = ReadFile(text);
   std::vector<int> prompt;
   HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
@@ -197,7 +194,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     auto kv_cache = CreateKVCache(model_type);
     float entropy =
         ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
-                            inner_pool, app.verbosity);
+                            app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
@@ -211,8 +208,8 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
 
 int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
                       gcpp::AppArgs& app, gcpp::KVCache& kv_cache,
-                      hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool,
-                      const gcpp::Path& json_file, size_t max_questions) {
+                      hwy::ThreadPool& pool, const gcpp::Path& json_file,
+                      size_t max_questions) {
   std::ifstream trivia_file(json_file.path);
   if (!trivia_file) {
     std::cout << "Could not load file: " << json_file.path << "\n"
@@ -225,7 +222,7 @@ int BenchmarkTriviaQA(gcpp::Gemma& model, gcpp::InferenceArgs& args,
   while (std::getline(trivia_file, line)) {
     json data = json::parse(line);
     const auto [answer, token_count] = QueryModel(
-        model, args, app, kv_cache, inner_pool, pool, data["question"]);
+        model, args, app, kv_cache, pool, data["question"]);
     std::cout << answer << "\n";
     bool correct = false;
     for (const std::string expected : data["answer"]["aliases"]) {
@@ -263,7 +260,6 @@ int main(int argc, char** argv) {
     HWY_ABORT("\nInvalid inference args: %s", error);
   }
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -280,17 +276,16 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
-    return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
-                            golden_path);
+    return BenchmarkGoldens(model, args, app, kv_cache, pool, golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
-    return BenchmarkSummary(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkSummary(model, args, app, kv_cache, pool,
                             benchmark_args.summarize_text);
   } else if (!benchmark_args.cross_entropy.path.empty()) {
     return BenchmarkCrossEntropy(model, loader.ModelType(), args, app,
-                                 inner_pool, pool, benchmark_args.cross_entropy,
+                                 pool, benchmark_args.cross_entropy,
                                  benchmark_args.batch_tokens);
   } else if (!benchmark_args.trivia_qa.path.empty()) {
-    return BenchmarkTriviaQA(model, args, app, kv_cache, inner_pool, pool,
+    return BenchmarkTriviaQA(model, args, app, kv_cache, pool,
                              benchmark_args.trivia_qa,
                              benchmark_args.max_questions);
   }
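Note: the same mechanical change repeats at every call site above, so for orientation, here is a minimal sketch (not part of the patch; model/args/app/kv_cache setup elided) of the single-pool calling convention the tools now use:

    // Sketch: callers construct exactly one pool and pass it straight through.
    hwy::ThreadPool pool(app.num_threads);
    const auto [answer, token_count] =
        QueryModel(model, args, app, kv_cache, pool, "question text");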
diff --git a/gemma/compress_weights.cc b/gemma/compress_weights.cc
index a4d414c..d9260b0 100644
--- a/gemma/compress_weights.cc
+++ b/gemma/compress_weights.cc
@@ -70,7 +70,7 @@ struct Args : public ArgsBase<Args> {
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(weights, "weights", Path(),
-            "Path name of model weights (.sbs) file.\n"
+            "Path to model weights (.bin) file.\n"
            "    Required argument.");
     visitor(model_type_str, "model", std::string(),
             "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
             "2b-pt = 2B parameters, pretrained\n    "
@@ -80,7 +80,7 @@ struct Args : public ArgsBase<Args> {
             "gr2b-pt = griffin 2B parameters, pretrained\n    "
             "    Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
-            "Path name where compressed weights file will be written.\n"
+            "Path name where compressed weights (.sbs) file will be written.\n"
             "    Required argument.");
     visitor(num_threads, "num_threads", kDefaultNumThreads,  // see ChooseNumThreads
diff --git a/gemma/configs.h b/gemma/configs.h
index bedecee..5bfa518 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -28,6 +28,11 @@
 #define GEMMA_TOPK 1
 #endif  // !GEMMA_TOPK
 
+// Allow changing upper bound on threads as a compiler flag
+#ifndef GEMMA_MAX_THREADS
+#define GEMMA_MAX_THREADS 128
+#endif  // !GEMMA_MAX_THREADS
+
 #include <stddef.h>
 #include <stdint.h>
 
@@ -45,6 +50,7 @@ namespace gcpp {
 
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
+static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
 
 enum class LayerAttentionType {
   kGemma,
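Note: a hedged sketch of how the new bound is meant to be consumed. The -D override spelling follows the comment in configs.h; the clamped pool construction below is illustrative, not code from this patch:

    // Build with e.g. -DGEMMA_MAX_THREADS=64 to change the bound.
    #include <thread>
    #include "gemma/configs.h"                        // gcpp::kMaxThreads
    #include "hwy/base.h"                             // HWY_MIN
    #include "hwy/contrib/thread_pool/thread_pool.h"

    size_t BoundedWorkerCount() {
      const size_t hw = static_cast<size_t>(std::thread::hardware_concurrency());
      return HWY_MIN(hw == 0 ? 1 : hw, gcpp::kMaxThreads);
    }
    // Usage: hwy::ThreadPool pool(BoundedWorkerCount());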
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 58078a9..a92d835 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -399,9 +399,9 @@ struct Activations {
   static constexpr size_t kQKVDim = TConfig::kQKVDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
-      TConfig::kGemmaLayers * kKVHeads * kQKVDim;
-  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
+      TConfig::kGemmaLayers * kCacheLayerSize;
 
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
@@ -421,6 +421,10 @@ struct Activations {
   std::array<float, kBatchSize * kModelDim> ffw_out;
   std::array<float, kVocabSize> logits;
 
+  // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
+  // per-thread storage.
+  std::array<float, kModelDim * kMaxThreads> even_odd;
+
   // Griffin layer internal activations
   static constexpr size_t kGriffinDim =
       TConfig::kGriffinLayers > 0 ? kModelDim : 0;
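Note: two invariants introduced above, spelled out as a standalone sketch derived from the new constants (values illustrative, not from any config): K and V for a head are now adjacent in one buffer, and the layer stride accounts for both.

    // Offsets implied by the merged-cache constants.
    constexpr size_t kQKVDim = 256, kKVHeads = 8, kGemmaLayers = 18;
    constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;  // K then V per head
    constexpr size_t kCachePosSize = kGemmaLayers * kCacheLayerSize;

    constexpr size_t KOffset(size_t pos, size_t layer, size_t head) {
      return pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim * 2;
    }
    constexpr size_t VOffset(size_t pos, size_t layer, size_t head) {
      return KOffset(pos, layer, head) + kQKVDim;  // V directly follows its K
    }
    static_assert(VOffset(0, 0, 0) == kQKVDim, "V is the second half of the pair");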
@@ -440,15 +444,13 @@ struct GemmaInterface {
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
                         size_t start_pos, KVCache& kv_cache,
-                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                        const StreamFunc& stream_token,
+                        hwy::ThreadPool& pool, const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity, LayersOutputT* layers_output) = 0;
 
   virtual float ComputeCrossEntropy(size_t max_tokens,
                                     const std::vector<int>& prompt,
                                     KVCache& kv_cache, hwy::ThreadPool& pool,
-                                    hwy::ThreadPool& inner_pool,
                                     int verbosity) = 0;
 };
@@ -535,13 +537,12 @@ struct GemmaImpl : public GemmaInterface {
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
                 size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity,
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937&, int verbosity,
                 LayersOutputT* layers_output) override;
 
   float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
                             KVCache& kv_cache, hwy::ThreadPool& pool,
-                            hwy::ThreadPool& inner_pool,
                             int verbosity) override;
 
   GemmaTokenizerImpl tokenizer;
@@ -578,13 +579,14 @@ HWY_NOINLINE void GriffinRecurrent(
       gcpp::Activations::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr bool kAdd = true;
   const size_t batch_offset = batch_idx * kModelDim;
   const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
   float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
   float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
-  TwoMatVecAdd(
+  TwoMatVecAdd(
       layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
       activations.pre_att_rms_out.data() + batch_offset,
       /*add0=*/layer_weights->griffin.linear_x_biases.data(),
@@ -634,7 +636,7 @@ HWY_NOINLINE void GriffinRecurrent(
     constexpr size_t kHeadDim = kModelDim / kHeads;
     constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
     size_t head_offset = head * kHeadDim;
-    TwoOfsMatVecAddLoop(
+    TwoOfsMatVecAddLoop(
         layer_weights->griffin.gate_w, kMatrixSize * head,
         kMatrixSize * (kHeads + head), x + head_offset,
         /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
@@ -673,9 +675,10 @@ HWY_NOINLINE void GriffinRecurrent(
   // Final linear layer.
   float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd(
+  MatVecAdd(
       layer_weights->griffin.linear_out_w, 0, x,
-      layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
+      layer_weights->griffin.linear_out_biases.data(),
+      activations.even_odd.data(), out_ptr, pool);
 }
 
 template
@@ -707,26 +710,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
-  auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-
-    MatVecLoop(layer_weights->qkv_einsum_w,
-               head_offset + 0 * kQKVDim * kModelDim, x, q);
-  };
-
-  auto ProjKV = [&](size_t k_offset, size_t v_offset,
-                    size_t kv_offset) HWY_ATTR {
-    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
-    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
-
-    TwoOfsMatVecLoop(layer_weights->qkv_einsum_w, k_offset,
-                     v_offset, x, k, v);
-
-    Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-  };
-
-  auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
+  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
     // Calculate scores
     float* HWY_RESTRICT q =
         activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
@@ -741,7 +725,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
@@ -754,7 +738,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
     // linear projection from kQKVDim back to kModelDim, sum projections
@@ -763,20 +747,21 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         head == 0
             ? activations.att_post2.data() + batch_idx * kModelDim
             : activations.att_post1.data() + head * kBatchSize * kModelDim;
+    float* even_odd = activations.even_odd.data() + thread * kQKVDim;
     if (head == 0) {
       MatVecAddLoop(
           layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
-          layer_weights->attention_output_biases.data(), head_out);
+          layer_weights->attention_output_biases.data(), even_odd, head_out);
     } else {
       MatVecLoop(layer_weights->attn_vec_einsum_w,
                  head * kModelDim * kQKVDim, att_out,
-                 head_out);
+                 even_odd, head_out);
     }
   };
 
   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
       // linear projections to QKV
       const size_t head_offset = TConfig::kInterleaveQKV
                                      ? 3 * kQKVDim * kModelDim
@@ -787,28 +772,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       const size_t k_offset = head * head_offset + 1 * mat_offset;
       const size_t v_offset = head * head_offset + 2 * mat_offset;
 
-      ProjQ(head, q_offset);
+      // ProjQ
+      float* HWY_RESTRICT q =
+          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
+      MatVecLoop(
+          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
+          activations.even_odd.data() + thread * kModelDim, q);
 
-      const size_t kv_offset =
-          cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
+      // ProjKV
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head * kQKVDim * 2;
+      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+      float* HWY_RESTRICT v = k + kQKVDim;
+      TwoOfsMatVecLoop(layer_weights->qkv_einsum_w,
+                       k_offset, v_offset, x, k, v);
+      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
-      ProjKV(k_offset, v_offset, kv_offset);
-
-      Attn(head, head * kQKVDim);
+      Attn(head, head * kQKVDim * 2, thread);
     });
   } else {
     // Multi-Query Attention
-    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
-    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
-    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset =
-        cache_pos * kCachePosSize + layer * kCacheLayerSize;
+    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    MatVec(layer_weights->qkv_einsum_w, 0, x,
+           activations.even_odd.data(), q, pool);
 
-    ProjKV(k_offset, v_offset, kv_offset);
+    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
+                             cache_pos * kCachePosSize +
+                             layer * kCacheLayerSize;
+    MatVec(layer_weights->qkv_einsum_w,
+           kHeads * kQKVDim * kModelDim, x,
+           activations.even_odd.data(), kv, pool);
 
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      ProjQ(head, head * kQKVDim * kModelDim);
-      Attn(head, 0);
+    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
+      Attn(head, 0, thread);
     });
   }
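Note: the even_odd discipline used by the rewritten Attention, reduced to its core. Each pool worker indexes a disjoint slice of one shared buffer, which is why the per-head MatVecLoop calls can run under pool.Run without synchronization. A standalone sketch (names and the body of the lambda are illustrative):

    #include <cstdint>
    #include <vector>
    #include "hwy/contrib/thread_pool/thread_pool.h"

    // Sketch: one shared buffer, one disjoint kSlice-sized piece per worker.
    void ParallelWithScratch(hwy::ThreadPool& pool, size_t kSlice, size_t tasks) {
      std::vector<float> even_odd(kSlice * pool.NumWorkers());
      pool.Run(0, tasks, [&](uint64_t task, size_t thread) {
        float* scratch = even_odd.data() + thread * kSlice;
        scratch[0] = static_cast<float>(task);  // stand-in for dequantize + dot
      });
    }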
@@ -828,6 +826,7 @@ HWY_NOINLINE void FFW(Activations& activations,
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
   const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
+  float* HWY_RESTRICT even_odd = activations.even_odd.data();
 
   {
     PROFILER_ZONE("Gen.FFW.GatedGELU");
@@ -836,15 +835,15 @@ HWY_NOINLINE void FFW(Activations& activations,
     float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
     float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
 
-    // Same matrix, first and second half of rows. Could fuse into one MatVec,
-    // but separating them could help on NUMA e.g. multiple sockets.
+    // Same matrix, first and second half of rows. Could fuse into one MatVec.
     MatVecAdd(
         layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
-        layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
+        layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
+        out_mul, pool);
 
     // Gate, will go through the nonlinearity.
     MatVecAdd(
         layer_weights->gating_einsum_w, 0, vec,
-        layer_weights->ffw_gating_biases.data(), out, pool);
+        layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
 
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
@@ -857,7 +856,7 @@ HWY_NOINLINE void FFW(Activations& activations,
   PROFILER_ZONE("Gen.FFW\\GatedGELU");
   MatVecAdd(
       layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-      layer_weights->ffw_output_biases.data(),
+      layer_weights->ffw_output_biases.data(), even_odd,
       activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
@@ -880,8 +879,7 @@ template
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const WeightArrayT& weights,
                           Activations& activations,
-                          KVCache& kv_cache, hwy::ThreadPool& pool,
-                          hwy::ThreadPool& inner_pool) {
+                          KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@@ -924,19 +922,17 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     }
 
     // TODO: sink the loop into these functions, i.e. make them MatMul.
-    pool.Run(
-        0, num_tokens,
-        [&](const uint64_t token_idx, size_t thread_id) HWY_ATTR {
-          AddFrom(activations.att_post2.data() + token_idx * kModelDim,
-                  activations.x.data() + token_idx * kModelDim, kModelDim);
-          RMSNorm(activations.x.data() + token_idx * kModelDim,
-                  layer_weights->pre_ffw_norm_scale.data(),
-                  activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
-                  kModelDim);
-          FFW(activations, token_idx, layer_weights, inner_pool);
-          AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
-                  activations.x.data() + token_idx * kModelDim, kModelDim);
-        });
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+      AddFrom(activations.att_post2.data() + token_idx * kModelDim,
+              activations.x.data() + token_idx * kModelDim, kModelDim);
+      RMSNorm(activations.x.data() + token_idx * kModelDim,
+              layer_weights->pre_ffw_norm_scale.data(),
+              activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
+              kModelDim);
+      FFW(activations, token_idx, layer_weights, pool);
+      AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
+              activations.x.data() + token_idx * kModelDim, kModelDim);
+    }
   }  // foreach layer
 
   pool.Run(
       0, num_tokens,
@@ -950,8 +946,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 template
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations& activations, KVCache& kv_cache,
-                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                 LayersOutputT* layers_output) {
+                 hwy::ThreadPool& pool, LayersOutputT* layers_output) {
   if (layers_output != nullptr) {
     float token_f = token;
     (*layers_output)(pos, "Tokens", &token_f, 1);
@@ -1033,8 +1028,7 @@ template
 void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                  const StreamFunc& stream_token,
+                  hwy::ThreadPool& pool, const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
                   int verbosity, LayersOutputT* layers_output) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
@@ -1077,7 +1071,7 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
       HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
       const int* batch_tokens = prompt.data() + pos_offset;
       Prefill(batch_tokens, batch_size, pos, weights,
-              prefill_activations, kv_cache, pool, inner_pool);
+              prefill_activations, kv_cache, pool);
       for (size_t idx = 0; idx < batch_size; ++idx) {
         if (!stream_token(batch_tokens[idx], 0.0f)) return;
       }
@@ -1105,7 +1099,7 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     const bool is_generating_phase = pos_offset >= prompt_size - 1;
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 layers_output);
     float* final_activation = activations.x.data();
     // The condition below is always true if we are doing Prefill above.
@@ -1114,9 +1108,9 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
     if (is_generating_phase) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
-      MatVec(weights.embedder_input_embedding,
-             0, final_activation,
-             activations.logits.data(), pool);
+      MatVec(
+          weights.embedder_input_embedding, 0, final_activation,
+          activations.even_odd.data(), activations.logits.data(), pool);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK(activations.logits.data(), kVocabSize,
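Note on the Prefill change above: the parallelism moves from the token loop into FFW itself. Each FFW call now fans its MatVecAdd row-strips out across `pool`, so the outer token loop is serial (running both levels on the same pool would contend with itself, which is what the old inner_pool existed to avoid). Schematically, as a sketch rather than the patch text:

    // Before: pool.Run over tokens, serial matvecs inside each task.
    // After:  serial token loop; each FFW call parallelizes rows on `pool`.
    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
      FFW(activations, token_idx, layer_weights, pool);  // uses pool internally
    }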
@@ -1171,8 +1165,7 @@ void LogTopK(GemmaImpl& gemma, float* logits, float* dist, size_t len,
 template
 float ComputeCrossEntropyImpl(GemmaImpl& gemma, size_t max_tokens,
                               const std::vector<int>& prompt, KVCache& kv_cache,
-                              hwy::ThreadPool& pool,
-                              hwy::ThreadPool& inner_pool, int verbosity) {
+                              hwy::ThreadPool& pool, int verbosity) {
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
   Activations& activations = *gemma.state.get();
@@ -1196,11 +1189,11 @@ float ComputeCrossEntropyImpl(GemmaImpl& gemma, size_t max_tokens,
       printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
              total_entropy / std::log(2.0) / (pos + 1));
     }
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+    Transformer(token, pos, weights, activations, kv_cache, pool,
                 /*layers_output=*/nullptr);
-    MatVec(weights.embedder_input_embedding, 0,
-           activations.x.data(),
-           activations.logits.data(), pool);
+    MatVec(
+        weights.embedder_input_embedding, 0, activations.x.data(),
+        activations.even_odd.data(), activations.logits.data(), pool);
     LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
     memcpy(logits.data(), activations.logits.data(),
            kVocabSize * sizeof(logits[0]));
@@ -1215,62 +1208,59 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 KVCache& kv_cache, hwy::ThreadPool& pool,
-                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937& gen,
-                int verbosity, LayersOutputT* layers_output) {
+                const StreamFunc& stream_token, const AcceptFunc& accept_token,
+                std::mt19937& gen, int verbosity,
+                LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
                        size_t max_generated_tokens, float temperature,
                        const std::vector<int>& prompt, size_t start_pos,
                        KVCache& kv_cache, hwy::ThreadPool& pool,
-                       hwy::ThreadPool& inner_pool,
                        const StreamFunc& stream_token,
                        const AcceptFunc& accept_token, std::mt19937& gen,
                        int verbosity, LayersOutputT* layers_output) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, kv_cache, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity, layers_output);
+               start_pos, kv_cache, pool, stream_token, accept_token, gen,
+               verbosity, layers_output);
 }
 
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
-                            hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                            int verbosity) {
+                            hwy::ThreadPool& pool, int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
                                    size_t max_tokens,
                                    const std::vector<int>& prompt,
                                    KVCache& kv_cache, hwy::ThreadPool& pool,
-                                   hwy::ThreadPool& inner_pool, int verbosity) {
+                                   int verbosity) {
   return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
-                                 inner_pool, verbosity);
+                                 verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -1477,9 +1467,8 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size) {
   KVCache kv_cache = {};
   if (size_cache_pos != 0) {
-    kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-    kv_cache.value_cache =
-        hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+    kv_cache.kv_cache =
+        hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
   }
   if (conv1d_cache_size != 0) {
     kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
   }
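Note: a worked check of the merged allocation, as a sketch with illustrative numbers (and assuming, as in the old layout, that size_cache_pos counts K-only floats per position): the single buffer holds exactly what the two old buffers held, just contiguously per (pos, layer, head) pair.

    // Sketch: sizing of the merged buffer (values illustrative).
    constexpr size_t kSeqLen = 4096;
    constexpr size_t kSizeCachePos = 18 * 1 * 256;          // layers*kv_heads*qkv
    constexpr size_t kOld = 2 * (kSeqLen * kSizeCachePos);  // key + value buffers
    constexpr size_t kNew = kSeqLen * kSizeCachePos * 2;    // single kv buffer
    static_assert(kOld == kNew, "same total storage, one allocation");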
@@ -1507,12 +1496,12 @@ template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
@@ -1520,50 +1509,49 @@ template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
-   layers_output);
+   kv_cache, pool, stream_token, accept_token, gen, verbosity, layers_output);
 }
 
 template <>
 void GemmaImpl<ConfigGriffin2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
     const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-    const StreamFunc& stream_token, const AcceptFunc& accept_token,
-    std::mt19937& gen, int verbosity, LayersOutputT* layers_output) {
+    hwy::ThreadPool& pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity,
+    LayersOutputT* layers_output) {
   HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity,
+   kv_cache, pool, stream_token, accept_token, gen, verbosity,
    layers_output);
 }
 
 template <>
 float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropy7B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 template <>
 float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
-    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+    hwy::ThreadPool& pool, int verbosity) {
   return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
-      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      *this, max_tokens, prompt, kv_cache, pool, verbosity);
 }
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -1607,13 +1595,13 @@ const GemmaTokenizer* Gemma::Tokenizer() const { return impl_->Tokenizer(); }
 
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, kv_cache, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity, layers_output);
+                        start_pos, kv_cache, pool, stream_token, accept_token,
+                        gen, verbosity, layers_output);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
@@ -1621,10 +1609,9 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
                    KVCache& kv_cache, hwy::ThreadPool& pool,
                    const StreamFunc& stream_token, std::mt19937& gen) {
-  hwy::ThreadPool inner_pool(0);
   GenerateGemma(
       gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens,
-      runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool,
+      runtime_config.temperature, prompt, start_pos, kv_cache, pool,
      stream_token, [](int) { return true; }, gen, runtime_config.verbosity,
       /*layers_output=*/nullptr);
 }
@@ -1637,11 +1624,10 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity) {
+                          hwy::ThreadPool& pool, int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   const float result = gemma.impl_->ComputeCrossEntropy(
-      max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+      max_tokens, prompt, kv_cache, pool, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
   return result;
 }
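Note: pulling the new public surface together, a minimal end-to-end sketch against the updated gemma.h declarations. The function wrapper, seed, and token limits are placeholders, not code from this repository:

    // Sketch: one generation call with the single-pool API (names per gemma.h).
    void RunOnce(const gcpp::Path& tokenizer_path, const gcpp::Path& weights_path,
                 const std::vector<int>& prompt_tokens,
                 const gcpp::StreamFunc& stream_token, hwy::ThreadPool& pool) {
      gcpp::Gemma model(tokenizer_path, weights_path, gcpp::Model::GEMMA_2B, pool);
      gcpp::KVCache kv_cache = gcpp::CreateKVCache(gcpp::Model::GEMMA_2B);
      std::mt19937 gen(/*seed=*/42);
      gcpp::GenerateGemma(model, /*max_tokens=*/256, /*max_generated_tokens=*/128,
                          /*temperature=*/1.0f, prompt_tokens, /*start_pos=*/0,
                          kv_cache, pool, stream_token,
                          /*accept_token=*/[](int) { return true; }, gen,
                          /*verbosity=*/0, /*layers_output=*/nullptr);
    }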
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 60d7843..822f75b 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -44,9 +44,7 @@ constexpr bool kSystemPrompt = false;
 
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
-  hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
+      kv_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]>
       conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]>
@@ -104,13 +102,12 @@ using AcceptFunc = std::function<bool(int)>;
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
                    size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
-                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                   const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity, LayersOutputT* layers_output = nullptr);
 
 // Convenience function for the common case:
 // - Bundle runtime parameters as RuntimeConfig
-// - No ThreadPool within ThreadPool (inner_pool = dummy)
 // - All tokens accepted
 void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
                    const std::vector<int>& prompt, size_t start_pos,
@@ -122,8 +119,7 @@ void CompressWeights(gcpp::Model model, const Path& weights,
 
 float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                          int verbosity);
+                          hwy::ThreadPool& pool, int verbosity);
 
 constexpr int EOS_ID = 1;
diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc
index a842a9c..258352b 100644
--- a/gemma/gemma_test.cc
+++ b/gemma/gemma_test.cc
@@ -36,7 +36,6 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
         tokenizer("./tokenizer.spm"),
         pool(std::min<size_t>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        inner_pool(0),
         model_type(gcpp::Model::GEMMA_2B),
         model(tokenizer, weights, model_type, pool) {
     kv_cache = CreateKVCache(model_type);
@@ -60,8 +59,8 @@ class GemmaTest : public ::testing::Test {
     gcpp::GenerateGemma(
         model, /*max_tokens=*/3072, /*max_generated_tokens=*/2048,
         /*temperature=*/1.0, prompt, /*start_pos=*/0, kv_cache, pool,
-        inner_pool, stream_token,
-        /*accept=*/[](int) { return true; }, gen, /*verbosity=*/0);
+        stream_token, /*accept=*/[](int) { return true; }, gen,
+        /*verbosity=*/0);
     std::string response_text;
     HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
     return response_text;
@@ -71,8 +70,7 @@ class GemmaTest : public ::testing::Test {
     std::vector<int> prompt;
     HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
     return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
-                                     kv_cache, pool, inner_pool,
-                                     /*verbosity=*/0) /
+                                     kv_cache, pool, /*verbosity=*/0) /
            prompt_string.size();
   }
 
@@ -89,7 +87,6 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  hwy::ThreadPool inner_pool;
   gcpp::Model model_type = {};
   gcpp::Gemma model;
 };
diff --git a/gemma/ops.h b/gemma/ops.h
index 1988bd8..c0cceec 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -129,11 +129,13 @@ HWY_INLINE void ToEvenOddF32(
 }
 
 // Simple version without tiling nor threading.
+// even_odd is precomputed for the current thread.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
 HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
                               const VecT* HWY_RESTRICT vec_aligned,
                               const AddT* HWY_RESTRICT add,
+                              float* HWY_RESTRICT even_odd,
                               float* HWY_RESTRICT out) {
   PROFILER_ZONE("MatVecAddLoop");
   const hn::ScalableTag<float> df;
@@ -149,7 +151,6 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
   }
 }
 
-
 #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
 template <bool kAdd, size_t kOuter, size_t kInner, typename VecT, typename AddT>
@@ -157,33 +158,39 @@ HWY_INLINE void MatVecAddLoop(
     const CompressedArray<hwy::bfloat16_t, kOuter * kInner>& mat,
     const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned,
-    const AddT* HWY_RESTRICT add,
+    const AddT* HWY_RESTRICT add, float* HWY_RESTRICT even_odd,
     float* HWY_RESTRICT out) {
   PROFILER_ZONE("MatVecAddLoop");
+  // Sanity check: we can write without race conditions.
+  if (HWY_IS_TSAN) {
+    even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
+    even_odd[kInner - 1] = -even_odd[0];
+  }
+
   const hn::ScalableTag<float> df;
-
-  const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
-  ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
-
+  ToEvenOddF32(vec_aligned, kInner, even_odd);
   for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
     const size_t row_ofs = mat_ofs + idx_row * kInner;
     if constexpr (kAdd) {
       out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
-                     Dot(df, mat, row_ofs, vec_dequant.get(), kInner);
+                     Dot(df, mat, row_ofs, even_odd, kInner);
     } else {
-      out[idx_row] = Dot(df, mat, row_ofs, vec_dequant.get(), kInner);
+      out[idx_row] = Dot(df, mat, row_ofs, even_odd, kInner);
     }
   }
 }
 #endif
 
+// even_odd is precomputed for the current thread.
 template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
 HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
                            const VecT* HWY_RESTRICT vec_aligned,
+                           float* HWY_RESTRICT even_odd,
                            float* HWY_RESTRICT out) {
-  MatVecAddLoop(
-      mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), out);
+  MatVecAddLoop(
+      mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
+      out);
 }
 
 // Simple version without tiling nor threading, but two offsets/outputs.
@@ -221,7 +228,7 @@ HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
                                  const VecT* HWY_RESTRICT vec_aligned,
                                  float* HWY_RESTRICT out0,
                                  float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop(
+  TwoOfsMatVecAddLoop(
       mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
      out0, out1);
 }
@@ -307,11 +314,21 @@ template
   const hn::ScalableTag<float> df;
   constexpr size_t kRowsPerStrip = RowsPerStrip();
   constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
 
+  // Sanity check: each thread can write without race conditions.
+  if (HWY_IS_TSAN) {
+    pool.Run(
+        0, pool.NumWorkers(), [even_odd](uint64_t /*task*/, size_t thread) {
+          even_odd[thread * kInner] = -static_cast<float>(thread);
+          even_odd[thread * kInner + kInner - 1] = static_cast<float>(thread);
+        });
+  }
+
   // For each entire strip.
   pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
     PROFILER_ZONE("MatVec.lambda");
@@ -340,6 +357,7 @@ template
 HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
                           const VecT* HWY_RESTRICT const vec_aligned,
                           const AddT* HWY_RESTRICT const add,
+                          float* HWY_RESTRICT const even_odd,
                           float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
 #if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
   if constexpr (CompressTraits::kSupportsEvenOdd &&
                 hwy::IsSameEither()) {
-    const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
-    ToEvenOddF32(vec_aligned, kInner, vec_dequant.get());
+    ToEvenOddF32(vec_aligned, kInner, even_odd);
     detail::MatVecAddInner(
-        mat, mat_ofs, vec_dequant.get(), add, out, pool);
+        mat, mat_ofs, even_odd, add, even_odd, out, pool);
     return;
   }
 #endif
   detail::MatVecAddInner(
-      mat, mat_ofs, vec_aligned, add, out, pool);
+      mat, mat_ofs, vec_aligned, add, even_odd, out, pool);
 }
 
 template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
 HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
                        const VecT* HWY_RESTRICT const vec_aligned,
-                       float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
-  MatVecAdd(
-      mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), out,
-      pool);
+                       float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
+                       hwy::ThreadPool& pool) {
+  MatVecAdd(
+      mat, mat_ofs, vec_aligned, /*add=*/static_cast<VecT*>(nullptr), even_odd,
+      out, pool);
 }
 
@@ -523,7 +541,7 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
                             const VecT* HWY_RESTRICT vec_aligned,
                             float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
                             hwy::ThreadPool& pool) {
-  TwoMatVecAdd(
+  TwoMatVecAdd(
       mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
       out0, out1, pool);
 }
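Note: the scratch-ownership rule that the ops_test changes below exercise, stated as a sketch (the allocation shapes mirror the tests; kInner and pool are assumed to come from the surrounding test fixture):

    // Scratch sizing contract for the new even_odd parameters:
    //   pooled MatVec/MatVecAdd: kInner floats per pool worker;
    //   single-threaded *Loop variants: kInner floats total.
    hwy::AlignedFreeUniquePtr<float[]> even_odd_pooled =
        hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
    hwy::AlignedFreeUniquePtr<float[]> even_odd_loop =
        hwy::AllocateAligned<float>(kInner);
    HWY_ASSERT(even_odd_pooled && even_odd_loop);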
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 06ef6ef..6a26cfd 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -17,11 +17,15 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif
 
+#include
 #include
 #include
+#include
+#include "compression/compress.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 
 // clang-format off
 #undef HWY_TARGET_INCLUDE
@@ -375,6 +379,7 @@ CompressedArray GenerateMat(size_t offset) {
 template <size_t length>
 hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
   hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
+  HWY_ASSERT(vec);
   for (size_t idx = 0; idx < length; idx++) {
     vec[idx] = static_cast<float>(idx + offset);
   }
@@ -388,8 +393,9 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
     const hwy::AlignedFreeUniquePtr<float[]>& add) {
   hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
       hwy::AllocateAligned<float>(kOuter * kInner);
-  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
   hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(uncompressed_mat && out);
+  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
   for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
     out[idx_row] = add[idx_row];
     for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
@@ -418,12 +424,15 @@ void TestMatVecAdd() {
   CompressedArray mat = GenerateMat(0);
   hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec(0);
   hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec(0);
+  hwy::AlignedFreeUniquePtr<float[]> even_odd =
+      hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
   hwy::AlignedFreeUniquePtr<float[]> expected_out = SimpleMatVecAdd(mat, vec, add);
   hwy::AlignedFreeUniquePtr<float[]> actual_out =
       hwy::AllocateAligned<float>(kOuter);
-  MatVecAdd(mat, 0, vec.get(), add.get(),
-            actual_out.get(), pool);
+  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
+  MatVecAdd(
+      mat, 0, vec.get(), add.get(), even_odd.get(), actual_out.get(), pool);
   AssertClose(actual_out, expected_out);
 }
 
@@ -433,12 +442,15 @@ void TestMatVecAddLoop() {
   CompressedArray mat = GenerateMat(0);
   hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec(0);
   hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec(0);
+  hwy::AlignedFreeUniquePtr<float[]> even_odd =
+      hwy::AllocateAligned<float>(kInner);
   hwy::AlignedFreeUniquePtr<float[]> expected_out = SimpleMatVecAdd(mat, vec, add);
   hwy::AlignedFreeUniquePtr<float[]> actual_out =
       hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
   MatVecAddLoop(mat, 0, vec.get(), add.get(),
-                actual_out.get());
+                even_odd.get(), actual_out.get());
   AssertClose(actual_out, expected_out);
 }
 
@@ -459,6 +471,8 @@ void TestTwoMatVecAdd() {
   hwy::AlignedFreeUniquePtr<float[]> actual_out0 =
       hwy::AllocateAligned<float>(kOuter);
   hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
       hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
+             expected_out1 && actual_out1);
   TwoMatVecAdd(mat0, mat1, 0, vec.get(), add0.get(), add1.get(),
                actual_out0.get(), actual_out1.get(), pool);
@@ -481,6 +495,8 @@ void TestTwoOfsMatVecAddLoop() {
   hwy::AlignedFreeUniquePtr<float[]> actual_out0 =
       hwy::AllocateAligned<float>(kOuter);
   hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
       hwy::AllocateAligned<float>(kOuter);
+  HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
+             expected_out1 && actual_out1);
   TwoOfsMatVecAddLoop(mat, 0, 0, vec.get(), add0.get(),
                       add1.get(), actual_out0.get(), actual_out1.get());
diff --git a/gemma/run.cc b/gemma/run.cc
index aa5f9bb..942bff8 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -94,9 +94,8 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
 
 void ReplGemma(gcpp::Gemma& model, ModelTraining training,
                gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;  // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -209,7 +208,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     const double time_start = hwy::platform::Now();
     GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                  args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
+                  args.temperature, prompt, abs_pos, kv_cache, pool,
                   stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
@@ -229,7 +228,6 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
 
-  hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
   if (app.num_threads > 10) {
@@ -271,8 +269,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inner_pool, inference,
-      app.verbosity,
+      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
      /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
diff --git a/util/app.h b/util/app.h
index 6541688..6f789e6 100644
--- a/util/app.h
+++ b/util/app.h
@@ -96,8 +96,9 @@ class AppArgs : public ArgsBase<AppArgs> {
   }
 
   static inline size_t GetSupportedThreadCount() {
-    return static_cast<size_t>(std::clamp(
-        static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+    return static_cast<size_t>(
+        std::clamp(static_cast<int>(std::thread::hardware_concurrency()) - 2, 1,
+                   HWY_MIN(static_cast<int>(kMaxThreads), 18)));
   }
 
   Path log;  // output
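Note: a worked example of the new default-thread-count clamp, with illustrative machine sizes:

    // GetSupportedThreadCount(), worked through:
    //   32 hardware threads, default GEMMA_MAX_THREADS=128:
    //     clamp(32 - 2, 1, HWY_MIN(128, 18)) = clamp(30, 1, 18) = 18
    //   8 hardware threads, built with -DGEMMA_MAX_THREADS=4:
    //     clamp(8 - 2, 1, HWY_MIN(4, 18))    = clamp(6, 1, 4)   = 4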