3x speedup of EmbedImagePatches - GEMM, not GEMV.

Required fixes to handling of non-vector aligned A.
Also move row ptrs to MatMulEnv.

PiperOrigin-RevId: 767029036
This commit is contained in:
Jan Wassenberg 2025-06-04 01:18:20 -07:00 committed by Copybara-Service
parent 9f74a1a098
commit 6897313080
7 changed files with 39 additions and 52 deletions

View File

@ -90,13 +90,13 @@ struct Activations {
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
q.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs);
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
// TODO: also init rows for image_tokens.
x.AllocateAndAttachRowPtrs(env->row_ptrs);
q.AllocateAndAttachRowPtrs(env->row_ptrs);
logits.AllocateAndAttachRowPtrs(env->row_ptrs);
att_sums.AllocateAndAttachRowPtrs(env->row_ptrs);
C1.AllocateAndAttachRowPtrs(env->row_ptrs);
C2.AllocateAndAttachRowPtrs(env->row_ptrs);
ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs);
// Note that BindC on any MatMul output considerably slows down Prefill.
}
@ -160,9 +160,6 @@ struct Activations {
MatStorageT<float> inv_timescale_global;
MatMulEnv* env;
// Per-tensor allocations to make it likelier that asan detects bugs such as
// use after free, overrun, and dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
} // namespace gcpp

View File

@ -1105,34 +1105,18 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
HWY_DASSERT(activations.x.Cols() == model_dim);
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
for (size_t i = 0; i < seq_len; ++i) {
image_patches[i] = hwy::AllocateAligned<float>(patch_size);
image.GetPatch(i, image_patches[i].get());
}
// img/embedding/kernel has original shape (14, 14, 3, 1152)
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3)
// This could be done as one MatMul like:
// MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
// kPatchSize), MatPadding::kPacked);
// [Get patches]
// CallMatMul(
// image_patches,
// weights.vit_img_embedding_kernel,
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
// activations.x);
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and
// then use the above. For now, we rely on MatVecAdd instead.
CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) {
for (size_t i = 0; i < seq_len; ++i) {
MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(),
weights.vit_img_embedding_bias.PackedScale1(),
activations.x.Row(i), activations.env->ctx.pools.Pool(0));
}
});
// Must be padded, see `DoDecompressA`.
MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
MatPadding::kOdd);
for (size_t i = 0; i < seq_len; ++i) {
image.GetPatch(i, image_patches.Row(i));
}
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
activations.x);
// Add position embeddings.
CallUpcastedActivation(&weights.vit_img_pos_embedding,
[&](const auto* weights_t) {

View File

@ -109,6 +109,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
config.model_dim)
: Extents2D(0, 0),
MatPadding::kOdd);
image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs);
if (have_image) {
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
config.wrapping == PromptWrapping::GEMMA_VLM);

View File

@ -1106,8 +1106,8 @@ class MMPerPackage {
});
}
// Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable
// type (i.e., not NUQ) so we can use pointer arithmetic.
// Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a
// seekable type (i.e., not NUQ) so we can use pointer arithmetic.
template <typename TA>
HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
const IndexRange all_M(0, A.Rows());
@ -1122,8 +1122,9 @@ class MMPerPackage {
const IndexRange& range_K) HWY_ATTR {
const size_t col0 = range_K.begin();
const size_t cols = range_K.Num();
// otherwise, padding overwrites neighbors
HWY_DASSERT(cols % NBF == 0 || cols == A.Cols());
// Must be a vector multiple, or the last range before row padding,
// otherwise `DecompressAndZeroPad` overwrites neighbors.
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
for (size_t row_a : range_M) {
const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
@ -1169,9 +1170,9 @@ class MMPerPackage {
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) {
// Only if no zero-padding required.
// Only if vector multiple and padded (see `DoDecompressA`).
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.Cols() % NBF == 0)) {
if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
// Actually const, but RowPtr is also used for partial which is not.
return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
}
@ -1241,7 +1242,7 @@ class MMPerPackage {
const MMArgs args_; // copy for locality
const size_t pkg_idx_;
RowPtrBF A_; // points into A or pkg_A.
RowPtrBF A_; // view into A or pkg_A_, both of which are padded.
const IndexRange range_np_;
// From MMConfig:

View File

@ -240,6 +240,7 @@ class MMStorage {
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
// Per-package allocation so each can decompress A into its own copy.
// Must be padded, see `DoDecompressA`.
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
@ -665,6 +666,11 @@ struct MatMulEnv {
MMStorage storage;
MMKeys keys;
std::vector<MMPerKey> per_key;
// Pass to MatPtr::AllocateAndAttachRowPtrs.
// Per-tensor allocations to make it likelier that asan detects bugs such as
// use after free, overrun, and dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
// Arguments to MatMul() that are independent of the A/B/C types.

View File

@ -91,7 +91,8 @@ float MaxAbs(const MatStorageT<float>& a) {
// B is already transposed.
template <typename TA, typename TB, typename TC>
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C, int line) {
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
MatMulEnv& env, int line) {
const hn::ScalableTag<float> df;
const size_t cols = A.Cols();
const size_t B_rows = B.Rows();
@ -101,6 +102,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MatPadding::kOdd);
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
MatPadding::kOdd);
c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
MatPadding::kOdd);
for (size_t m = 0; m < A.Rows(); ++m) {
@ -225,8 +227,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStorageT<TA> A(GenerateMat<TA>(A_extents, pool));
MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, pool));
MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> C_slow("C_slow", C_extents, MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, MatPadding::kOdd);
C.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
@ -238,7 +241,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
// A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
AssertClose(A, BT, C_slow, C, line);
AssertClose(A, BT, C_slow, C, env, line);
if (per_key->autotune.Best()) break;
}
}

View File

@ -44,7 +44,6 @@ class PaliGemmaTest : public ::testing::Test {
void TestQuestion(const char* question, const char* expected_substring);
std::unique_ptr<ImageTokens> image_tokens_;
std::vector<uint8_t*> image_row_ptrs_;
};
void PaliGemmaTest::InitVit(const std::string& path) {
@ -54,11 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_row_ptrs_.resize(image_tokens_->Rows());
for (size_t r = 0; r < image_tokens_->Rows(); ++r) {
image_row_ptrs_[r] = image_tokens_->RowBytes(r);
}
image_tokens_->AttachRowPtrs(image_row_ptrs_.data());
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
Image image;
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));