From 68973130807006adebe76d3feb99f2d6ab529543 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 4 Jun 2025 01:18:20 -0700 Subject: [PATCH] 3x speedup of EmbedImagePatches - GEMM, not GEMV. Required fixes to handling of non-vector aligned A. Also move row ptrs to MatMulEnv. PiperOrigin-RevId: 767029036 --- gemma/activations.h | 17 +++++++---------- gemma/gemma-inl.h | 34 +++++++++------------------------- gemma/run.cc | 1 + ops/matmul-inl.h | 15 ++++++++------- ops/matmul.h | 6 ++++++ ops/matmul_test.cc | 11 +++++++---- paligemma/paligemma_test.cc | 7 +------ 7 files changed, 39 insertions(+), 52 deletions(-) diff --git a/gemma/activations.h b/gemma/activations.h index 7a94960..7563617 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -90,13 +90,13 @@ struct Activations { // For MatMul outputs, precompute their row pointers. // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. - q.AllocateAndAttachRowPtrs(row_ptrs); - logits.AllocateAndAttachRowPtrs(row_ptrs); - att_sums.AllocateAndAttachRowPtrs(row_ptrs); - C1.AllocateAndAttachRowPtrs(row_ptrs); - C2.AllocateAndAttachRowPtrs(row_ptrs); - ffw_out.AllocateAndAttachRowPtrs(row_ptrs); - // TODO: also init rows for image_tokens. + x.AllocateAndAttachRowPtrs(env->row_ptrs); + q.AllocateAndAttachRowPtrs(env->row_ptrs); + logits.AllocateAndAttachRowPtrs(env->row_ptrs); + att_sums.AllocateAndAttachRowPtrs(env->row_ptrs); + C1.AllocateAndAttachRowPtrs(env->row_ptrs); + C2.AllocateAndAttachRowPtrs(env->row_ptrs); + ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs); // Note that BindC on any MatMul output considerably slows down Prefill. } @@ -160,9 +160,6 @@ struct Activations { MatStorageT inv_timescale_global; MatMulEnv* env; - // Per-tensor allocations to make it likelier that asan detects bugs such as - // use after free, overrun, and dangling references. - std::vector> row_ptrs; }; } // namespace gcpp diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 5fbf85a..0c946e3 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1105,34 +1105,18 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); HWY_DASSERT(activations.x.Cols() == model_dim); - std::vector> image_patches(seq_len); - for (size_t i = 0; i < seq_len; ++i) { - image_patches[i] = hwy::AllocateAligned(patch_size); - image.GetPatch(i, image_patches[i].get()); - } // img/embedding/kernel has original shape (14, 14, 3, 1152) // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3) - // This could be done as one MatMul like: - // MatStorageT image_patches("patches", Extents2D(kSeqLen, - // kPatchSize), MatPadding::kPacked); - // [Get patches] - // CallMatMul( - // image_patches, - // weights.vit_img_embedding_kernel, - // weights.vit_img_embedding_bias.PackedScale1(), *activations.env, - // activations.x); - // However, MatMul currently requires that - // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 - // which is not the case here. We should relax that requirement on MatMul and - // then use the above. For now, we rely on MatVecAdd instead. 
- CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) { - for (size_t i = 0; i < seq_len; ++i) { - MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(), - weights.vit_img_embedding_bias.PackedScale1(), - activations.x.Row(i), activations.env->ctx.pools.Pool(0)); - } - }); + // Must be padded, see `DoDecompressA`. + MatStorageT image_patches("patches", Extents2D(seq_len, patch_size), + MatPadding::kOdd); + for (size_t i = 0; i < seq_len; ++i) { + image.GetPatch(i, image_patches.Row(i)); + } + CallMatMul(image_patches, weights.vit_img_embedding_kernel, + weights.vit_img_embedding_bias.PackedScale1(), *activations.env, + activations.x); // Add position embeddings. CallUpcastedActivation(&weights.vit_img_pos_embedding, [&](const auto* weights_t) { diff --git a/gemma/run.cc b/gemma/run.cc index 2afbecb..bacae8f 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -109,6 +109,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, config.model_dim) : Extents2D(0, 0), MatPadding::kOdd); + image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs); if (have_image) { HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA || config.wrapping == PromptWrapping::GEMMA_VLM); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 0d7b664..9be6f4c 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1106,8 +1106,8 @@ class MMPerPackage { }); } - // Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable - // type (i.e., not NUQ) so we can use pointer arithmetic. + // Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a + // seekable type (i.e., not NUQ) so we can use pointer arithmetic. template HWY_NOINLINE void DoDecompressA(const MatPtrT& A, MMParA par_a) const { const IndexRange all_M(0, A.Rows()); @@ -1122,8 +1122,9 @@ class MMPerPackage { const IndexRange& range_K) HWY_ATTR { const size_t col0 = range_K.begin(); const size_t cols = range_K.Num(); - // otherwise, padding overwrites neighbors - HWY_DASSERT(cols % NBF == 0 || cols == A.Cols()); + // Must be a vector multiple, or the last range before row padding, + // otherwise `DecompressAndZeroPad` overwrites neighbors. + HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); for (size_t row_a : range_M) { const PackedSpan from = MakeSpan(A.Row(row_a) + col0, cols); BF16* HWY_RESTRICT to = A_.Row(row_a) + col0; @@ -1169,9 +1170,9 @@ class MMPerPackage { MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; // If already BF16, maybe return a view: if constexpr (hwy::IsSame()) { - // Only if no zero-padding required. + // Only if vector multiple and padded (see `DoDecompressA`). const size_t NBF = hn::Lanes(hn::ScalableTag()); - if (HWY_LIKELY(A.Cols() % NBF == 0)) { + if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) { // Actually const, but RowPtr is also used for partial which is not. return RowPtrBF(const_cast(A.Row(0)), A.Cols(), A.Stride()); } @@ -1241,7 +1242,7 @@ class MMPerPackage { const MMArgs args_; // copy for locality const size_t pkg_idx_; - RowPtrBF A_; // points into A or pkg_A. + RowPtrBF A_; // view into A or pkg_A_, both of which are padded. const IndexRange range_np_; // From MMConfig: diff --git a/ops/matmul.h b/ops/matmul.h index c82956c..4e323d4 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -240,6 +240,7 @@ class MMStorage { // Same stride independent of the actual C.Cols() so we can pre-bind. 
partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) { // Per-package allocation so each can decompress A into its own copy. + // Must be padded, see `DoDecompressA`. parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { pkg_A_[pkg_idx].reset(new MatStorageT( "pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd)); @@ -665,6 +666,11 @@ struct MatMulEnv { MMStorage storage; MMKeys keys; std::vector per_key; + + // Pass to MatPtr::AllocateAndAttachRowPtrs. + // Per-tensor allocations to make it likelier that asan detects bugs such as + // use after free, overrun, and dangling references. + std::vector> row_ptrs; }; // Arguments to MatMul() that are independent of the A/B/C types. diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 6d3cf54..112576a 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -91,7 +91,8 @@ float MaxAbs(const MatStorageT& a) { // B is already transposed. template void AssertClose(const MatPtrT& A, const MatPtrT& B, - const MatPtrT& C_slow, const MatPtrT& C, int line) { + const MatPtrT& C_slow, const MatPtrT& C, + MatMulEnv& env, int line) { const hn::ScalableTag df; const size_t cols = A.Cols(); const size_t B_rows = B.Rows(); @@ -101,6 +102,7 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, MatPadding::kOdd); MatStorageT c_batch("c_batch", Extents2D(A.Rows(), B_rows), MatPadding::kOdd); + c_batch.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows), MatPadding::kOdd); for (size_t m = 0; m < A.Rows(); ++m) { @@ -225,8 +227,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatStorageT A(GenerateMat(A_extents, pool)); MatStorageT BT(GenerateTransposedMat(B_extents, pool)); - MatStorageT C_slow("c_slow_batch", C_extents, MatPadding::kOdd); - MatStorageT C("c_batch", C_extents, MatPadding::kOdd); + MatStorageT C_slow("C_slow", C_extents, MatPadding::kOdd); + MatStorageT C("C", C_extents, MatPadding::kOdd); + C.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT add_storage = add ? GenerateMat(Extents2D(1, cols_bc), pool) @@ -238,7 +241,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, // A few reps to get coverage of the various autotuned code paths. for (size_t rep = 0; rep < 16; ++rep) { MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C); - AssertClose(A, BT, C_slow, C, line); + AssertClose(A, BT, C_slow, C, env, line); if (per_key->autotune.Best()) break; } } diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index 06f930a..f5aebef 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -44,7 +44,6 @@ class PaliGemmaTest : public ::testing::Test { void TestQuestion(const char* question, const char* expected_substring); std::unique_ptr image_tokens_; - std::vector image_row_ptrs_; }; void PaliGemmaTest::InitVit(const std::string& path) { @@ -54,11 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) { image_tokens_ = std::make_unique( "image", Extents2D(config.vit_config.seq_len, config.model_dim), MatPadding::kPacked); - image_row_ptrs_.resize(image_tokens_->Rows()); - for (size_t r = 0; r < image_tokens_->Rows(); ++r) { - image_row_ptrs_[r] = image_tokens_->RowBytes(r); - } - image_tokens_->AttachRowPtrs(image_row_ptrs_.data()); + image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs); Image image; HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(image.ReadPPM(path));
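
Note for readers outside the gemma.cpp tree: the core of this change is replacing the per-patch MatVecAdd loop in EmbedImagePatches with a single MatMul over a row-padded A matrix, so the vector-width decompression of the trailing columns (see `DoDecompressA`) cannot overwrite the next row, and all patches are processed in one batched call. Below is a minimal standalone sketch of that idea in plain C++. It is illustrative only, not gemma.cpp code: the 16-lane width and the RoundUpTo helper are assumptions, standing in for MatPadding::kOdd, CallMatMul and DecompressAndZeroPad in the real code.

// Illustrative sketch only (not gemma.cpp code): shows why A's rows are padded
// to a lane multiple and why all patches go through one GEMM instead of
// per-patch GEMVs. kLanes and RoundUpTo are assumptions for this sketch.
#include <cstddef>
#include <cstdio>
#include <vector>

static size_t RoundUpTo(size_t n, size_t multiple) {
  return (n + multiple - 1) / multiple * multiple;
}

int main() {
  constexpr size_t kSeqLen = 256;             // number of image patches
  constexpr size_t kPatchSize = 14 * 14 * 3;  // 588 floats, not a lane multiple
  constexpr size_t kModelDim = 1152;
  constexpr size_t kLanes = 16;               // assumed SIMD lanes per vector

  // Row-padded A: the stride is rounded up to a lane multiple and the padding
  // is zero, so full-vector stores on the tail of a row cannot clobber the
  // next row, and the extra zero columns do not change any dot product.
  const size_t stride = RoundUpTo(kPatchSize, kLanes);  // 592
  std::vector<float> A(kSeqLen * stride, 0.0f);     // one patch per row
  std::vector<float> BT(kModelDim * stride, 0.0f);  // transposed embedding kernel
  std::vector<float> bias(kModelDim, 0.1f);
  std::vector<float> C(kSeqLen * kModelDim, 0.0f);

  // Fill only the first kPatchSize columns of each row (stand-in for GetPatch).
  for (size_t m = 0; m < kSeqLen; ++m)
    for (size_t k = 0; k < kPatchSize; ++k) A[m * stride + k] = 0.01f;
  for (size_t n = 0; n < kModelDim; ++n)
    for (size_t k = 0; k < kPatchSize; ++k) BT[n * stride + k] = 0.02f;

  // One GEMM over all rows of A (what EmbedImagePatches now does), rather than
  // kSeqLen separate matrix-vector products.
  for (size_t m = 0; m < kSeqLen; ++m) {
    for (size_t n = 0; n < kModelDim; ++n) {
      float sum = bias[n];
      for (size_t k = 0; k < stride; ++k) {  // padded tail contributes 0
        sum += A[m * stride + k] * BT[n * stride + k];
      }
      C[m * kModelDim + n] = sum;
    }
  }
  std::printf("C[0][0] = %f\n", C[0]);
  return 0;
}

Batching all 256 patches into one GEMM lets the matmul amortize decompression and packing of the embedding kernel and use every core for a single large product, which is where the reported 3x speedup of EmbedImagePatches comes from; the per-row GEMV path could not reuse that work across patches.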