mirror of https://github.com/google/gemma.cpp.git
3x speedup of EmbedImagePatches - GEMM, not GEMV.
Required fixes to handling of non-vector aligned A. Also move row ptrs to MatMulEnv.
PiperOrigin-RevId: 767029036
Commit 6897313080 (parent 9f74a1a098)
@@ -90,13 +90,13 @@ struct Activations {
     // For MatMul outputs, precompute their row pointers.
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
-    q.AllocateAndAttachRowPtrs(row_ptrs);
-    logits.AllocateAndAttachRowPtrs(row_ptrs);
-    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
-    C1.AllocateAndAttachRowPtrs(row_ptrs);
-    C2.AllocateAndAttachRowPtrs(row_ptrs);
-    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
-    // TODO: also init rows for image_tokens.
+    x.AllocateAndAttachRowPtrs(env->row_ptrs);
+    q.AllocateAndAttachRowPtrs(env->row_ptrs);
+    logits.AllocateAndAttachRowPtrs(env->row_ptrs);
+    att_sums.AllocateAndAttachRowPtrs(env->row_ptrs);
+    C1.AllocateAndAttachRowPtrs(env->row_ptrs);
+    C2.AllocateAndAttachRowPtrs(env->row_ptrs);
+    ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs);

     // Note that BindC on any MatMul output considerably slows down Prefill.
   }
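Note: a minimal, self-contained sketch of the mechanism this hunk relies on, i.e. precomputing one pointer per row and keeping the allocation alive in a shared pool that now lives in MatMulEnv. The types and signatures below are simplified stand-ins (std::unique_ptr instead of hwy::AlignedFreeUniquePtr, a toy Mat instead of MatStorageT), not the real gemma.cpp declarations.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <vector>

struct Mat {                      // stand-in for MatStorageT
  size_t rows, cols, stride;      // stride >= cols when rows are padded
  std::vector<float> data;
  uint8_t** row_ptrs = nullptr;   // attached view; storage owned by the pool

  float* Row(size_t r) { return data.data() + r * stride; }

  // Allocate one pointer per row in the shared pool and attach it, so later
  // MatMul calls do not have to rebuild row pointers.
  void AllocateAndAttachRowPtrs(std::vector<std::unique_ptr<uint8_t*[]>>& pool) {
    pool.push_back(std::make_unique<uint8_t*[]>(rows));
    row_ptrs = pool.back().get();
    for (size_t r = 0; r < rows; ++r) {
      row_ptrs[r] = reinterpret_cast<uint8_t*>(Row(r));
    }
  }
};

struct Env {                      // stand-in for MatMulEnv
  std::vector<std::unique_ptr<uint8_t*[]>> row_ptrs;  // one array per tensor
};

int main() {
  Env env;
  Mat x{4, 6, 8, std::vector<float>(4 * 8)};
  x.AllocateAndAttachRowPtrs(env.row_ptrs);  // as the hunk does via env->row_ptrs
  std::printf("row 2 starts at %p\n", static_cast<void*>(x.row_ptrs[2]));
}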
@@ -160,9 +160,6 @@ struct Activations {
   MatStorageT<float> inv_timescale_global;

   MatMulEnv* env;
-  // Per-tensor allocations to make it likelier that asan detects bugs such as
-  // use after free, overrun, and dangling references.
-  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };

 }  // namespace gcpp
@@ -1105,34 +1105,18 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
   HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
   HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
   HWY_DASSERT(activations.x.Cols() == model_dim);
-  std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
-  for (size_t i = 0; i < seq_len; ++i) {
-    image_patches[i] = hwy::AllocateAligned<float>(patch_size);
-    image.GetPatch(i, image_patches[i].get());
-  }
   // img/embedding/kernel has original shape (14, 14, 3, 1152)
   // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
   // image_patches is (256, 14 * 14 * 3)
-  // This could be done as one MatMul like:
-  // MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
-  // kPatchSize), MatPadding::kPacked);
-  // [Get patches]
-  // CallMatMul(
-  // image_patches,
-  // weights.vit_img_embedding_kernel,
-  // weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
-  // activations.x);
-  // However, MatMul currently requires that
-  // A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
-  // which is not the case here. We should relax that requirement on MatMul and
-  // then use the above. For now, we rely on MatVecAdd instead.
-  CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) {
-    for (size_t i = 0; i < seq_len; ++i) {
-      MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(),
-                weights.vit_img_embedding_bias.PackedScale1(),
-                activations.x.Row(i), activations.env->ctx.pools.Pool(0));
-    }
-  });
+  // Must be padded, see `DoDecompressA`.
+  MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
+                                   MatPadding::kOdd);
+  for (size_t i = 0; i < seq_len; ++i) {
+    image.GetPatch(i, image_patches.Row(i));
+  }
+  CallMatMul(image_patches, weights.vit_img_embedding_kernel,
+             weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
+             activations.x);
   // Add position embeddings.
   CallUpcastedActivation(&weights.vit_img_pos_embedding,
                          [&](const auto* weights_t) {
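To make the old alignment limitation concrete: patch_size is 14 * 14 * 3 = 588 floats per row, which is generally not a multiple of the BF16 vector length (32 lanes is used below purely as an example; the real value depends on the target). That is why the single-MatMul path was previously commented out, and why image_patches is now allocated with MatPadding::kOdd so that DoDecompressA can zero-pad each row's tail without touching the next row. A tiny standalone illustration of the arithmetic:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t patch_size = 14 * 14 * 3;  // 588 columns per image_patches row
  const size_t lanes = 32;                // example: BF16 lanes with AVX-512
  // 588 is not a whole number of vectors, so the old requirement
  // A.cols % (2 * Lanes) == 0 ruled out a single MatMul here.
  std::printf("588 %% 32 = %zu\n", patch_size % lanes);
  // With row padding, the tail of each row can be zeroed up to the next
  // vector boundary without overwriting the following row.
  std::printf("next vector boundary: %zu\n",
              (patch_size + lanes - 1) / lanes * lanes);
  return 0;
}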
@@ -109,6 +109,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                 config.model_dim)
                 : Extents2D(0, 0),
             MatPadding::kOdd);
+  image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs);
   if (have_image) {
     HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
                config.wrapping == PromptWrapping::GEMMA_VLM);
@@ -1106,8 +1106,8 @@ class MMPerPackage {
     });
   }

-  // Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable
-  // type (i.e., not NUQ) so we can use pointer arithmetic.
+  // Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a
+  // seekable type (i.e., not NUQ) so we can use pointer arithmetic.
   template <typename TA>
   HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
     const IndexRange all_M(0, A.Rows());
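For orientation, a self-contained sketch of the per-row operation DoDecompressA performs: copy (decompress) `cols` values and zero the remainder of the row up to the padded width, which is only safe because the row allocation extends past `cols`. Plain float copies stand in for Highway's DecompressAndZeroPad; the function name is illustrative, not from the source.

#include <algorithm>
#include <cstddef>

// to[0, cols) receives the decompressed values; to[cols, padded_cols) is
// zeroed so kernels may read whole vectors past `cols` within the same row.
void DecompressRowAndZeroPad(const float* from, size_t cols, float* to,
                             size_t padded_cols) {
  std::copy(from, from + cols, to);
  std::fill(to + cols, to + padded_cols, 0.0f);
}

int main() {
  const float src[5] = {1, 2, 3, 4, 5};
  float dst[8];
  DecompressRowAndZeroPad(src, 5, dst, 8);  // dst = {1,2,3,4,5,0,0,0}
}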
@@ -1122,8 +1122,9 @@ class MMPerPackage {
                       const IndexRange& range_K) HWY_ATTR {
        const size_t col0 = range_K.begin();
        const size_t cols = range_K.Num();
-        // otherwise, padding overwrites neighbors
-        HWY_DASSERT(cols % NBF == 0 || cols == A.Cols());
+        // Must be a vector multiple, or the last range before row padding,
+        // otherwise `DecompressAndZeroPad` overwrites neighbors.
+        HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
        for (size_t row_a : range_M) {
          const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
          BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
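The assertion above encodes an invariant on how K is split: every range except the one ending at A.Cols() must be a whole number of vectors, so only the final range's zero-padding spills into the row's padding. A standalone sketch of such a split (the partitioner below is illustrative, not the real MMParA logic):

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

struct Range { size_t begin, end; };

// Splits [0, k) into ranges whose sizes are multiples of `lanes`, except
// possibly the last one, mirroring the HWY_DASSERT in the hunk above.
std::vector<Range> SplitK(size_t k, size_t lanes, size_t target) {
  const size_t step = target / lanes * lanes;  // round target down to vectors
  std::vector<Range> ranges;
  for (size_t begin = 0; begin < k; begin += step) {
    const size_t end = (begin + step < k) ? begin + step : k;
    ranges.push_back({begin, end});
  }
  return ranges;
}

int main() {
  const size_t k = 588, lanes = 32;  // as in EmbedImagePatches
  for (const Range& r : SplitK(k, lanes, 256)) {
    const size_t cols = r.end - r.begin;
    // Same condition as the assertion: vector multiple, or the last range.
    assert(cols % lanes == 0 || r.end == k);
    std::printf("[%zu, %zu) cols=%zu\n", r.begin, r.end, cols);
  }
}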
@@ -1169,9 +1170,9 @@ class MMPerPackage {
     MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
     // If already BF16, maybe return a view:
     if constexpr (hwy::IsSame<TA, BF16>()) {
-      // Only if no zero-padding required.
+      // Only if vector multiple and padded (see `DoDecompressA`).
       const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
-      if (HWY_LIKELY(A.Cols() % NBF == 0)) {
+      if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
        // Actually const, but RowPtr is also used for partial which is not.
        return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
      }
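A compact, self-contained sketch of the view-or-copy decision this condition implements: reuse A directly only when it is already BF16, its column count is a whole number of vectors, and its rows carry padding; otherwise A must be decompressed into the padded pkg_A_ buffer. The IsPacked stand-in below assumes "packed" means stride == cols; that reading, and all names here, are assumptions for illustration, not the real gemma.cpp API.

#include <cstddef>
#include <cstdio>

struct MatInfo {
  size_t cols, stride;                          // stride > cols iff rows padded
  bool IsPacked() const { return stride == cols; }
};

// Mirrors the condition in the hunk: already BF16, whole vectors per row,
// and row padding present (see `DoDecompressA`).
bool CanViewDirectly(const MatInfo& a, size_t lanes, bool is_bf16) {
  return is_bf16 && a.cols % lanes == 0 && !a.IsPacked();
}

int main() {
  const size_t lanes = 32;
  std::printf("%d\n", CanViewDirectly({1152, 1184}, lanes, true));  // 1: view
  std::printf("%d\n", CanViewDirectly({588, 640}, lanes, true));    // 0: copy
}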
@@ -1241,7 +1242,7 @@ class MMPerPackage {

   const MMArgs args_;  // copy for locality
   const size_t pkg_idx_;
-  RowPtrBF A_;  // points into A or pkg_A.
+  RowPtrBF A_;  // view into A or pkg_A_, both of which are padded.

   const IndexRange range_np_;
   // From MMConfig:
@@ -240,6 +240,7 @@ class MMStorage {
        // Same stride independent of the actual C.Cols() so we can pre-bind.
        partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
    // Per-package allocation so each can decompress A into its own copy.
+    // Must be padded, see `DoDecompressA`.
    parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
      pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
          "pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
@@ -665,6 +666,11 @@ struct MatMulEnv {
   MMStorage storage;
   MMKeys keys;
   std::vector<MMPerKey> per_key;
+
+  // Pass to MatPtr::AllocateAndAttachRowPtrs.
+  // Per-tensor allocations to make it likelier that asan detects bugs such as
+  // use after free, overrun, and dangling references.
+  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };

 // Arguments to MatMul() that are independent of the A/B/C types.
@@ -91,7 +91,8 @@ float MaxAbs(const MatStorageT<float>& a) {
 // B is already transposed.
 template <typename TA, typename TB, typename TC>
 void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                 const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C, int line) {
+                 const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
+                 MatMulEnv& env, int line) {
   const hn::ScalableTag<float> df;
   const size_t cols = A.Cols();
   const size_t B_rows = B.Rows();
@@ -101,6 +102,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                              MatPadding::kOdd);
   MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
                              MatPadding::kOdd);
+  c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
   MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
                                   MatPadding::kOdd);
   for (size_t m = 0; m < A.Rows(); ++m) {
@@ -225,8 +227,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,

   MatStorageT<TA> A(GenerateMat<TA>(A_extents, pool));
   MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, pool));
-  MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
-  MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
+  MatStorageT<TC> C_slow("C_slow", C_extents, MatPadding::kOdd);
+  MatStorageT<TC> C("C", C_extents, MatPadding::kOdd);
+  C.AllocateAndAttachRowPtrs(env.row_ptrs);

   MatStorageT<float> add_storage =
       add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
@@ -238,7 +241,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   // A few reps to get coverage of the various autotuned code paths.
   for (size_t rep = 0; rep < 16; ++rep) {
     MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
-    AssertClose(A, BT, C_slow, C, line);
+    AssertClose(A, BT, C_slow, C, env, line);
     if (per_key->autotune.Best()) break;
   }
 }
@@ -44,7 +44,6 @@ class PaliGemmaTest : public ::testing::Test {
   void TestQuestion(const char* question, const char* expected_substring);

   std::unique_ptr<ImageTokens> image_tokens_;
-  std::vector<uint8_t*> image_row_ptrs_;
 };

 void PaliGemmaTest::InitVit(const std::string& path) {
@@ -54,11 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
   image_tokens_ = std::make_unique<ImageTokens>(
       "image", Extents2D(config.vit_config.seq_len, config.model_dim),
       MatPadding::kPacked);
-  image_row_ptrs_.resize(image_tokens_->Rows());
-  for (size_t r = 0; r < image_tokens_->Rows(); ++r) {
-    image_row_ptrs_[r] = image_tokens_->RowBytes(r);
-  }
-  image_tokens_->AttachRowPtrs(image_row_ptrs_.data());
+  image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
   Image image;
   HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));