diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 6d80741..28de0c8 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -402,10 +402,11 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
+  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQKVDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
       att;  // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim>
       att_out;  // attention output
@@ -710,10 +711,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
-  auto Attn = [&](uint64_t head, size_t head_offset, size_t thread) HWY_ATTR {
+  auto Attn = [&](float* q, uint64_t head, size_t head_offset,
+                  size_t thread) HWY_ATTR {
     // Calculate scores
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
     float* HWY_RESTRICT head_att =
         activations.att.data() + head * TConfig::kSeqLen +
         batch_idx * kHeads * kQKVDim;
@@ -745,34 +745,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
+    static_assert(TConfig::kInterleaveQKV);
+
+    float* HWY_RESTRICT qkv =
+        activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+    MatVec<kHeads * kQKVDim * 3, kModelDim>(
+        layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
+        pool);
+
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      // linear projections to QKV
-      const size_t head_offset = TConfig::kInterleaveQKV
-                                     ? 3 * kQKVDim * kModelDim
-                                     : kQKVDim * kModelDim;
-      const size_t mat_offset =
-          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
-      const size_t q_offset = head * head_offset + 0 * mat_offset;
-      const size_t k_offset = head * head_offset + 1 * mat_offset;
-      const size_t v_offset = head * head_offset + 2 * mat_offset;
-
-      // ProjQ
-      float* HWY_RESTRICT q =
-          activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-      MatVecLoop<kQKVDim, kModelDim>(
-          layer_weights->qkv_einsum_w, q_offset + 0 * kQKVDim * kModelDim, x,
-          activations.even_odd.data() + thread * kModelDim, q);
-
-      // ProjKV
+      float* HWY_RESTRICT q = qkv + head * kQKVDim * 3;
       const size_t kv_offset = cache_pos * kCachePosSize +
                                layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
-      float* HWY_RESTRICT v = k + kQKVDim;
-      TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w,
-                                           k_offset, v_offset, x, k, v);
-      Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
-      Attn(head, head * kQKVDim * 2, thread);
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      Attn(q, head, head * kQKVDim * 2, thread);
     });
   } else {
     // Multi-Query Attention
@@ -790,7 +779,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
     pool.Run(0, kHeads, [&](const uint64_t head, size_t thread) HWY_ATTR {
-      Attn(head, 0, thread);
+      Attn(q + head * kQKVDim, head, 0, thread);
     });
   }
 
diff --git a/gemma/ops.h b/gemma/ops.h
index bac98fa..e40e72a 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -92,45 +92,6 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
 
-// Simple version without tiling nor threading.
-// even_odd is precomputed for the current thread.
-template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
-          typename VecT, typename AddT>
-HWY_INLINE void MatVecAddLoop(const ArrayT& mat, const size_t mat_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              const AddT* HWY_RESTRICT add,
-                              float* HWY_RESTRICT even_odd,
-                              float* HWY_RESTRICT out) {
-  PROFILER_ZONE("MatVecAddLoop");
-  const hn::ScalableTag<float> df;
-
-  // Sanity check: we can write without race conditions.
-  if (HWY_IS_TSAN) {
-    even_odd[0] = hwy::ConvertScalarTo<float>(vec_aligned[0]);
-    even_odd[kInner - 1] = -even_odd[0];
-  }
-
-  for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
-    const size_t row_ofs = mat_ofs + idx_row * kInner;
-    if constexpr (kAdd) {
-      out[idx_row] = hwy::ConvertScalarTo<float>(add[idx_row]) +
-                     Dot(df, mat, row_ofs, vec_aligned, kInner);
-    } else {
-      out[idx_row] = Dot(df, mat, row_ofs, vec_aligned, kInner);
-    }
-  }
-}
-
-// even_odd is precomputed for the current thread.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void MatVecLoop(const ArrayT& mat, const size_t mat_ofs,
-                           const VecT* HWY_RESTRICT vec_aligned,
-                           float* HWY_RESTRICT even_odd,
-                           float* HWY_RESTRICT out) {
-  MatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
-      mat, mat_ofs, vec_aligned, /*add=*/nullptr, even_odd, out);
-}
-
 // Simple version without tiling nor threading, but two offsets/outputs.
 template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
           typename VecT, typename AddT>
@@ -159,18 +120,6 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }
 
-// Simple version without tiling nor threading, but two offsets/outputs.
-template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
-HWY_INLINE void TwoOfsMatVecLoop(const ArrayT& mat, const size_t mat_ofs0,
-                                 const size_t mat_ofs1,
-                                 const VecT* HWY_RESTRICT vec_aligned,
-                                 float* HWY_RESTRICT out0,
-                                 float* HWY_RESTRICT out1) {
-  TwoOfsMatVecAddLoop</*kAdd=*/false, kOuter, kInner>(
-      mat, mat_ofs0, mat_ofs1, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
-      out0, out1);
-}
-
 namespace detail {
 
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 6a26cfd..973d598 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -436,24 +436,6 @@ void TestMatVecAdd() {
   AssertClose(actual_out, expected_out);
 }
 
-void TestMatVecAddLoop() {
-  constexpr size_t kOuter = 128 * 3;
-  constexpr size_t kInner = 128 * 5;
-  CompressedArray<float, kOuter * kInner> mat = GenerateMat<kOuter, kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
-  hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner);
-  hwy::AlignedFreeUniquePtr<float[]> expected_out =
-      SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
-  hwy::AlignedFreeUniquePtr<float[]> actual_out =
-      hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAddLoop</*kAdd=*/true, kOuter, kInner>(mat, 0, vec.get(), add.get(),
-                                               even_odd.get(), actual_out.get());
-  AssertClose(actual_out, expected_out);
-}
-
 void TestTwoMatVecAdd() {
   hwy::ThreadPool pool(0);
   constexpr size_t kOuter = 128 * 3;
@@ -537,7 +519,6 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
-HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
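Why the fusion is valid: with TConfig::kInterleaveQKV, qkv_einsum_w stores each head's weights as three consecutive kQKVDim x kModelDim blocks [Wq; Wk; Wv], so the per-head ProjQ (MatVecLoop) and ProjKV (TwoOfsMatVecLoop) calls removed above were just row ranges of one big matrix. The patch therefore issues a single threaded MatVec over all kHeads * 3 * kQKVDim rows, after which each head's k,v pair sits contiguously behind its q and can be memcpy'd straight into the KV cache. Below is a minimal standalone sketch of that layout, not gemma.cpp code: MatVecRef and the toy dimensions are hypothetical stand-ins for the Highway-accelerated MatVec and the Gemma config constants.

// Sketch only: plain C++ stand-in for the fused interleaved-QKV projection.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

constexpr std::size_t kHeads = 4, kQKVDim = 8, kModelDim = 32;

// Reference row-major mat-vec: out[r] = dot(row r of mat, vec).
void MatVecRef(const float* mat, const float* vec, std::size_t rows,
               std::size_t cols, float* out) {
  for (std::size_t r = 0; r < rows; ++r) {
    float sum = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) sum += mat[r * cols + c] * vec[c];
    out[r] = sum;
  }
}

int main() {
  // Interleaved layout: per head, 3 * kQKVDim consecutive rows (Wq, Wk, Wv).
  std::vector<float> w(kHeads * 3 * kQKVDim * kModelDim);
  for (std::size_t i = 0; i < w.size(); ++i) w[i] = 0.001f * (i % 97);
  std::vector<float> x(kModelDim, 1.0f);

  // One fused projection replaces per-head MatVecLoop (q) and
  // TwoOfsMatVecLoop (k, v).
  std::vector<float> qkv(kHeads * 3 * kQKVDim);
  MatVecRef(w.data(), x.data(), kHeads * 3 * kQKVDim, kModelDim, qkv.data());

  std::vector<float> kv_cache(kHeads * 2 * kQKVDim);  // stand-in KV cache
  std::vector<float> q_ref(kQKVDim);
  for (std::size_t head = 0; head < kHeads; ++head) {
    const float* q = qkv.data() + head * 3 * kQKVDim;  // as in the patch
    // k and v follow q contiguously, so one memcpy fills the cache entry.
    std::memcpy(kv_cache.data() + head * 2 * kQKVDim, q + kQKVDim,
                2 * kQKVDim * sizeof(float));
    // Check against the removed per-head path: this head's q rows start at
    // the old interleaved head_offset, head * 3 * kQKVDim * kModelDim.
    MatVecRef(w.data() + head * 3 * kQKVDim * kModelDim, x.data(), kQKVDim,
              kModelDim, q_ref.data());
    assert(std::equal(q_ref.begin(), q_ref.end(), q));
  }
  return 0;
}

The reference check mirrors the offset arithmetic the patch deletes: head * 3 * kQKVDim * kModelDim is exactly the head_offset the old code computed for the interleaved case, which is why the single MatVec at offset 0 reproduces all per-head projections.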