diff --git a/BUILD.bazel b/BUILD.bazel
index bca3fd8..b3c1090 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -46,8 +46,10 @@ cc_test(
     deps = [
         ":ops",
         "@googletest//:gtest_main",
+        "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:hwy_test_util",
+        "@hwy//:thread_pool",
     ],
 )
diff --git a/gemma/configs.h b/gemma/configs.h
index bedecee..5bfa518 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -28,6 +28,11 @@
 #define GEMMA_TOPK 1
 #endif  // !GEMMA_TOPK
 
+// Allow changing upper bound on threads as a compiler flag
+#ifndef GEMMA_MAX_THREADS
+#define GEMMA_MAX_THREADS 128
+#endif  // !GEMMA_MAX_THREADS
+
 #include
 
 #include
@@ -45,6 +50,7 @@ namespace gcpp {
 
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
+static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
 
 enum class LayerAttentionType {
   kGemma,
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 05d1aaf..a92d835 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -421,6 +421,10 @@ struct Activations {
   std::array ffw_out;
   std::array logits;
 
+  // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
+  // per-thread storage.
+  std::array even_odd;
+
   // Griffin layer internal activations
   static constexpr size_t kGriffinDim =
       TConfig::kGriffinLayers > 0 ? kModelDim : 0;
@@ -575,13 +579,14 @@ HWY_NOINLINE void GriffinRecurrent(
       gcpp::Activations::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr bool kAdd = true;
   const size_t batch_offset = batch_idx * kModelDim;
   const size_t pos = batch_start + batch_idx;
 
   // X / Y linear layers.
   float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
   float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
-  TwoMatVecAdd(
+  TwoMatVecAdd(
       layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
       activations.pre_att_rms_out.data() + batch_offset,
       /*add0=*/layer_weights->griffin.linear_x_biases.data(),
@@ -631,7 +636,7 @@ HWY_NOINLINE void GriffinRecurrent(
     constexpr size_t kHeadDim = kModelDim / kHeads;
     constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
     size_t head_offset = head * kHeadDim;
-    TwoOfsMatVecAddLoop(
+    TwoOfsMatVecAddLoop(
         layer_weights->griffin.gate_w, kMatrixSize * head,
         kMatrixSize * (kHeads + head), x + head_offset,
         /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
@@ -670,9 +675,10 @@ HWY_NOINLINE void GriffinRecurrent(
 
   // Final linear layer.
   float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
-  MatVecAdd(
+  MatVecAdd(
       layer_weights->griffin.linear_out_w, 0, x,
-      layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
+      layer_weights->griffin.linear_out_biases.data(),
+      activations.even_odd.data(), out_ptr, pool);
 }
 
 template
@@ -704,26 +710,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
-  auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-
-    MatVecLoop(layer_weights->qkv_einsum_w,
-               head_offset + 0 * kQKVDim * kModelDim, x, q);
-  };
-
-  auto ProjKV = [&](size_t k_offset, size_t v_offset,
-                    size_t kv_offset) HWY_ATTR {
-    float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
-    float* HWY_RESTRICT v = k + kQKVDim;
-
-    TwoOfsMatVecLoop(layer_weights->qkv_einsum_w, k_offset,
-                     v_offset, x, k, v);
-
-    Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-  };
-
-  auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
+  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
     // Calculate scores
     float* HWY_RESTRICT q =
         activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
@@ -760,20 +747,21 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         head == 0 ? activations.att_post2.data() + batch_idx * kModelDim
                   : activations.att_post1.data() + head * kBatchSize * kModelDim;
+    float* even_odd = activations.even_odd.data() + thread * kQKVDim;
     if (head == 0) {
       MatVecAddLoop(
           layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
-          layer_weights->attention_output_biases.data(), head_out);
+          layer_weights->attention_output_biases.data(), even_odd, head_out);
     } else {
       MatVecLoop(layer_weights->attn_vec_einsum_w,
                  head * kModelDim * kQKVDim, att_out,
-                 head_out);
+                 even_odd, head_out);
     }
   };
 
   if constexpr (kHeads == kKVHeads) {  // Multi-Head Attention
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
       // linear projections to QKV
       const size_t head_offset = TConfig::kInterleaveQKV
                                      ? 3 * kQKVDim * kModelDim
@@ -784,32 +772,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                                      : kQKVDim * kModelDim;
       const size_t mat_offset = kQKVDim * kModelDim;
       const size_t q_offset = head * head_offset + 0 * mat_offset;
       const size_t k_offset = head * head_offset + 1 * mat_offset;
       const size_t v_offset = head * head_offset + 2 * mat_offset;
 
-      ProjQ(head, q_offset);
+      // ProjQ
+      float* HWY_RESTRICT q =
+          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
+      MatVecLoop(
+          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
+          activations.even_odd.data() + thread * kModelDim, q);
 
-      const size_t kv_offset =
-          cache_pos * kCachePosSize + layer * kCacheLayerSize +
-          head * kQKVDim * 2;
+      // ProjKV
+      const size_t kv_offset = cache_pos * kCachePosSize +
+                               layer * kCacheLayerSize + head * kQKVDim * 2;
+      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+      float* HWY_RESTRICT v = k + kQKVDim;
+      TwoOfsMatVecLoop(layer_weights->qkv_einsum_w,
+                       k_offset, v_offset, x, k, v);
+      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
-      ProjKV(k_offset, v_offset, kv_offset);
-
-      Attn(head, head * kQKVDim * 2);
+      Attn(head, head * kQKVDim * 2, thread);
     });
   } else {  // Multi-Query Attention
     float* HWY_RESTRICT q =
         activations.q.data() + batch_idx * kHeads * kQKVDim;
-    MatVec(layer_weights->qkv_einsum_w, 0, x, q,
-           pool);
+    MatVec(layer_weights->qkv_einsum_w, 0, x,
+           activations.even_odd.data(), q, pool);
 
     float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
                              cache_pos * kCachePosSize + layer * kCacheLayerSize;
     MatVec(layer_weights->qkv_einsum_w,
-           kHeads * kQKVDim * kModelDim, x, kv, pool);
+           kHeads * kQKVDim * kModelDim, x,
+           activations.even_odd.data(), kv, pool);
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
-    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      Attn(head, 0);
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
+      Attn(head, 0, thread);
     });
   }
 
@@ -829,6 +826,7 @@ HWY_NOINLINE void FFW(Activations& activations,
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
   const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
+  float* HWY_RESTRICT even_odd = activations.even_odd.data();
 
   {
     PROFILER_ZONE("Gen.FFW.GatedGELU");
@@ -837,15 +835,15 @@ HWY_NOINLINE void FFW(Activations& activations,
     float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
     float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
 
-    // Same matrix, first and second half of rows. Could fuse into one MatVec,
-    // but separating them could help on NUMA e.g. multiple sockets.
+    // Same matrix, first and second half of rows. Could fuse into one MatVec.
     MatVecAdd(
         layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
-        layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
+        layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
+        out_mul, pool);
     // Gate, will go through the nonlinearity.
     MatVecAdd(
         layer_weights->gating_einsum_w, 0, vec,
-        layer_weights->ffw_gating_biases.data(), out, pool);
+        layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
 
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag;
@@ -858,7 +856,7 @@ HWY_NOINLINE void FFW(Activations& activations,
     PROFILER_ZONE("Gen.FFW\\GatedGELU");
     MatVecAdd(
         layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
-        layer_weights->ffw_output_biases.data(),
+        layer_weights->ffw_output_biases.data(), even_odd,
         activations.ffw_out.data() + batch_idx * kModelDim, pool);
   }
 
@@ -1110,9 +1108,9 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens,
     if (is_generating_phase) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
-      MatVec(weights.embedder_input_embedding,
-             0, final_activation,
-             activations.logits.data(), pool);
+      MatVec(
+          weights.embedder_input_embedding, 0, final_activation,
+          activations.even_odd.data(), activations.logits.data(), pool);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK(activations.logits.data(), kVocabSize,
@@ -1193,9 +1191,9 @@ float ComputeCrossEntropyImpl(GemmaImpl& gemma, size_t max_tokens,
     }
     Transformer(token, pos, weights, activations, kv_cache, pool,
                 /*layers_output=*/nullptr);
-    MatVec(weights.embedder_input_embedding, 0,
-           activations.x.data(),
-           activations.logits.data(), pool);
+    MatVec(
+        weights.embedder_input_embedding, 0, activations.x.data(),
+        activations.even_odd.data(), activations.logits.data(), pool);
     LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
     memcpy(logits.data(), activations.logits.data(),
            kVocabSize * sizeof(logits[0]));
diff --git a/gemma/ops.h b/gemma/ops.h
index da6a38e..bac98fa 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -93,15 +93,23 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
 }
 
 // Simple version without tiling nor threading.
+// even_odd is precomputed for the current thread.
 template
 HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
                               const VecT* HWY_RESTRICT vec_aligned,
                               const AddT* HWY_RESTRICT add,
+                              float* HWY_RESTRICT even_odd,
                               float* HWY_RESTRICT out) {
   PROFILER_ZONE("MatVecAddLoop");
 
   const hn::ScalableTag df;
 
+  // Sanity check: we can write without race conditions.
+  if (HWY_IS_TSAN) {
+    even_odd[0] = hwy::ConvertScalarTo(vec_aligned[0]);
+    even_odd[kInner - 1] = -even_odd[0];
+  }
+
   for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
     const size_t row_ofs = mat_ofs + idx_row * kInner;
     if constexpr (kAdd) {
@@ -113,12 +121,14 @@ HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
   }
 }
 
+// even_odd is precomputed for the current thread.
 template
 HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
                            const VecT* HWY_RESTRICT vec_aligned,
+                           float* HWY_RESTRICT even_odd,
                            float* HWY_RESTRICT out) {
-  MatVecAddLoop(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, out);
+  MatVecAddLoop(
+      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
 }
 
 // Simple version without tiling nor threading, but two offsets/outputs.
@@ -156,7 +166,7 @@ HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
                                  const VecT* HWY_RESTRICT vec_aligned,
                                  float* HWY_RESTRICT out0,
                                  float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop(
+  TwoOfsMatVecAddLoop(
       mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
       out0, out1);
 }
@@ -237,19 +247,29 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
 
 // Stores dot products of rows with `vec_aligned` + add the values from `add`
 // (if kAdd), then stores them to `out`.
-//
+// `even_odd` has kInner elements for each thread.
 template
 HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
                           const VecT* HWY_RESTRICT const vec_aligned,
                           const AddT* HWY_RESTRICT const add,
-                          float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
+                          float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
+                          hwy::ThreadPool& pool) {
   PROFILER_ZONE("MatVecAdd");
 
   const hn::ScalableTag df;
   constexpr size_t kRowsPerStrip = RowsPerStrip();
   constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
 
+  // Sanity check: each thread can write without race conditions.
+  if (HWY_IS_TSAN) {
+    pool.Run(
+        0, pool.NumWorkers(), [even_odd](uint64_t /*task*/, size_t thread) {
+          even_odd[thread * kInner] = -static_cast(thread);
+          even_odd[thread * kInner + kInner - 1] = static_cast(thread);
+        });
+  }
+
   // For each entire strip.
   pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
     PROFILER_ZONE("MatVec.lambda");
@@ -272,9 +292,10 @@ HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
 template
 HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
                        const VecT* HWY_RESTRICT const vec_aligned,
-                       float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
-  MatVecAdd(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, out, pool);
+                       float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out,
+                       hwy::ThreadPool& pool) {
+  MatVecAdd(
+      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out, pool);
 }
 
 template
@@ -427,7 +448,7 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
                             const VecT* HWY_RESTRICT vec_aligned,
                             float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
                             hwy::ThreadPool& pool) {
-  TwoMatVecAdd(
+  TwoMatVecAdd(
       mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
       out0, out1, pool);
 }
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 06ef6ef..6a26cfd 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -17,11 +17,15 @@
 #define HWY_DISABLED_TARGETS HWY_SCALAR
 #endif
 
+#include
 #include
 #include
 
+#include
+#include "compression/compress.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 
 // clang-format off
 #undef HWY_TARGET_INCLUDE
@@ -375,6 +379,7 @@ CompressedArray GenerateMat(size_t offset) {
 template
 hwy::AlignedFreeUniquePtr GenerateVec(size_t offset) {
   hwy::AlignedFreeUniquePtr vec = hwy::AllocateAligned(length);
+  HWY_ASSERT(vec);
   for (size_t idx = 0; idx < length; idx++) {
     vec[idx] = static_cast(idx + offset);
   }
@@ -388,8 +393,9 @@ hwy::AlignedFreeUniquePtr SimpleMatVecAdd(
     const hwy::AlignedFreeUniquePtr& add) {
   hwy::AlignedFreeUniquePtr uncompressed_mat =
       hwy::AllocateAligned(kOuter * kInner);
-  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
   hwy::AlignedFreeUniquePtr out = hwy::AllocateAligned(kOuter);
+  HWY_ASSERT(uncompressed_mat && out);
+  Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
   for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
     out[idx_row] = add[idx_row];
     for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
@@ -418,12 +424,15 @@ void TestMatVecAdd() {
   CompressedArray mat = GenerateMat(0);
   hwy::AlignedFreeUniquePtr vec = GenerateVec(0);
   hwy::AlignedFreeUniquePtr add = GenerateVec(0);
+  hwy::AlignedFreeUniquePtr even_odd =
+      hwy::AllocateAligned(kInner * pool.NumWorkers());
   hwy::AlignedFreeUniquePtr expected_out = SimpleMatVecAdd(mat, vec, add);
   hwy::AlignedFreeUniquePtr actual_out = hwy::AllocateAligned(kOuter);
-  MatVecAdd(mat, 0, vec.get(), add.get(),
-            actual_out.get(), pool);
+  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
+  MatVecAdd(
+      mat, 0, vec.get(), add.get(), even_odd.get(), actual_out.get(), pool);
   AssertClose(actual_out, expected_out);
 }
 
@@ -433,12 +442,15 @@ void TestMatVecAddLoop() {
   CompressedArray mat = GenerateMat(0);
   hwy::AlignedFreeUniquePtr vec = GenerateVec(0);
   hwy::AlignedFreeUniquePtr add = GenerateVec(0);
+  hwy::AlignedFreeUniquePtr even_odd =
+      hwy::AllocateAligned(kInner);
   hwy::AlignedFreeUniquePtr expected_out = SimpleMatVecAdd(mat, vec, add);
   hwy::AlignedFreeUniquePtr actual_out = hwy::AllocateAligned(kOuter);
+  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
   MatVecAddLoop(mat, 0, vec.get(), add.get(),
-                actual_out.get());
+                even_odd.get(), actual_out.get());
   AssertClose(actual_out, expected_out);
 }
 
@@ -459,6 +471,8 @@ void TestTwoMatVecAdd() {
       hwy::AllocateAligned(kOuter);
   hwy::AlignedFreeUniquePtr actual_out1 =
       hwy::AllocateAligned(kOuter);
+  HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
+             expected_out1 && actual_out1);
   TwoMatVecAdd(mat0, mat1, 0, vec.get(), add0.get(),
               add1.get(), actual_out0.get(), actual_out1.get(), pool);
 
@@ -481,6 +495,8 @@ void TestTwoOfsMatVecAddLoop() {
       hwy::AllocateAligned(kOuter);
   hwy::AlignedFreeUniquePtr actual_out1 =
       hwy::AllocateAligned(kOuter);
+  HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
+             expected_out1 && actual_out1);
   TwoOfsMatVecAddLoop(mat, 0, 0, vec.get(), add0.get(),
                       add1.get(), actual_out0.get(), actual_out1.get());
 
diff --git a/util/app.h b/util/app.h
index 6541688..6f789e6 100644
--- a/util/app.h
+++ b/util/app.h
@@ -96,8 +96,9 @@ class AppArgs : public ArgsBase {
   }
 
   static inline size_t GetSupportedThreadCount() {
-    return static_cast(std::clamp(
-        static_cast(std::thread::hardware_concurrency()) - 2, 1, 18));
+    return static_cast(
+        std::clamp(static_cast(std::thread::hardware_concurrency()) - 2, 1,
+                   HWY_MIN(static_cast(kMaxThreads), 18)));
   }
 
   Path log;  // output
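Note on the calling convention threaded through the hunks above: every MatVec* entry point now receives a float scratch buffer (even_odd) sized at kInner elements per hwy::ThreadPool worker, and each worker touches only the slice starting at thread * kInner (the single-threaded MatVecAddLoop/MatVecLoop receive one such slice). The standalone C++ snippet below is a minimal sketch of that ownership scheme only; the kInner value and thread count are illustrative placeholders rather than values from this change, and the writes simply mirror the TSAN sanity check added in MatVecAdd.

#include <cstddef>
#include <cstdint>

#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // HWY_RESTRICT
#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  constexpr size_t kInner = 2048;  // stands in for the matrix inner dimension
  hwy::ThreadPool pool(4);         // worker count is arbitrary for this sketch
  // One kInner-sized strip per worker so concurrent MatVec* lambdas never
  // share scratch: worker i owns [i * kInner, (i + 1) * kInner).
  hwy::AlignedFreeUniquePtr<float[]> even_odd =
      hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
  pool.Run(0, pool.NumWorkers(), [&](uint64_t /*task*/, size_t thread) {
    float* HWY_RESTRICT mine = even_odd.get() + thread * kInner;
    mine[0] = static_cast<float>(thread);  // same pattern as the TSAN check
    mine[kInner - 1] = -mine[0];
  });
  return 0;
}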