mirror of https://github.com/google/gemma.cpp.git
3x speedup of EmbedImagePatches - GEMM, not GEMV.
Required fixes to handling of non-vector aligned A. Also move row ptrs to MatMulEnv. PiperOrigin-RevId: 767029036
This commit is contained in:
parent
9f74a1a098
commit
6897313080
|
|
@ -90,13 +90,13 @@ struct Activations {
|
|||
// For MatMul outputs, precompute their row pointers.
|
||||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
// fill them in each MatMul call.
|
||||
q.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
logits.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C1.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C2.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
// TODO: also init rows for image_tokens.
|
||||
x.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
q.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
logits.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
att_sums.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
C1.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
C2.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
ffw_out.AllocateAndAttachRowPtrs(env->row_ptrs);
|
||||
|
||||
// Note that BindC on any MatMul output considerably slows down Prefill.
|
||||
}
|
||||
|
|
@ -160,9 +160,6 @@ struct Activations {
|
|||
MatStorageT<float> inv_timescale_global;
|
||||
|
||||
MatMulEnv* env;
|
||||
// Per-tensor allocations to make it likelier that asan detects bugs such as
|
||||
// use after free, overrun, and dangling references.
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -1105,34 +1105,18 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
||||
HWY_DASSERT(activations.x.Cols() == model_dim);
|
||||
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
image_patches[i] = hwy::AllocateAligned<float>(patch_size);
|
||||
image.GetPatch(i, image_patches[i].get());
|
||||
}
|
||||
// img/embedding/kernel has original shape (14, 14, 3, 1152)
|
||||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||
// image_patches is (256, 14 * 14 * 3)
|
||||
// This could be done as one MatMul like:
|
||||
// MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
|
||||
// kPatchSize), MatPadding::kPacked);
|
||||
// [Get patches]
|
||||
// CallMatMul(
|
||||
// image_patches,
|
||||
// weights.vit_img_embedding_kernel,
|
||||
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
|
||||
// activations.x);
|
||||
// However, MatMul currently requires that
|
||||
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
|
||||
// which is not the case here. We should relax that requirement on MatMul and
|
||||
// then use the above. For now, we rely on MatVecAdd instead.
|
||||
CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) {
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(),
|
||||
weights.vit_img_embedding_bias.PackedScale1(),
|
||||
activations.x.Row(i), activations.env->ctx.pools.Pool(0));
|
||||
}
|
||||
});
|
||||
// Must be padded, see `DoDecompressA`.
|
||||
MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
|
||||
MatPadding::kOdd);
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
image.GetPatch(i, image_patches.Row(i));
|
||||
}
|
||||
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
||||
weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
|
||||
activations.x);
|
||||
// Add position embeddings.
|
||||
CallUpcastedActivation(&weights.vit_img_pos_embedding,
|
||||
[&](const auto* weights_t) {
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
config.model_dim)
|
||||
: Extents2D(0, 0),
|
||||
MatPadding::kOdd);
|
||||
image_tokens.AllocateAndAttachRowPtrs(gemma.Env().row_ptrs);
|
||||
if (have_image) {
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
|
||||
config.wrapping == PromptWrapping::GEMMA_VLM);
|
||||
|
|
|
|||
|
|
@ -1106,8 +1106,8 @@ class MMPerPackage {
|
|||
});
|
||||
}
|
||||
|
||||
// Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable
|
||||
// type (i.e., not NUQ) so we can use pointer arithmetic.
|
||||
// Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a
|
||||
// seekable type (i.e., not NUQ) so we can use pointer arithmetic.
|
||||
template <typename TA>
|
||||
HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
|
||||
const IndexRange all_M(0, A.Rows());
|
||||
|
|
@ -1122,8 +1122,9 @@ class MMPerPackage {
|
|||
const IndexRange& range_K) HWY_ATTR {
|
||||
const size_t col0 = range_K.begin();
|
||||
const size_t cols = range_K.Num();
|
||||
// otherwise, padding overwrites neighbors
|
||||
HWY_DASSERT(cols % NBF == 0 || cols == A.Cols());
|
||||
// Must be a vector multiple, or the last range before row padding,
|
||||
// otherwise `DecompressAndZeroPad` overwrites neighbors.
|
||||
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
|
||||
for (size_t row_a : range_M) {
|
||||
const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
|
||||
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
|
||||
|
|
@ -1169,9 +1170,9 @@ class MMPerPackage {
|
|||
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
|
||||
// If already BF16, maybe return a view:
|
||||
if constexpr (hwy::IsSame<TA, BF16>()) {
|
||||
// Only if no zero-padding required.
|
||||
// Only if vector multiple and padded (see `DoDecompressA`).
|
||||
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
|
||||
if (HWY_LIKELY(A.Cols() % NBF == 0)) {
|
||||
if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
|
||||
// Actually const, but RowPtr is also used for partial which is not.
|
||||
return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
|
||||
}
|
||||
|
|
@ -1241,7 +1242,7 @@ class MMPerPackage {
|
|||
|
||||
const MMArgs args_; // copy for locality
|
||||
const size_t pkg_idx_;
|
||||
RowPtrBF A_; // points into A or pkg_A.
|
||||
RowPtrBF A_; // view into A or pkg_A_, both of which are padded.
|
||||
|
||||
const IndexRange range_np_;
|
||||
// From MMConfig:
|
||||
|
|
|
|||
|
|
@ -240,6 +240,7 @@ class MMStorage {
|
|||
// Same stride independent of the actual C.Cols() so we can pre-bind.
|
||||
partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
|
||||
// Per-package allocation so each can decompress A into its own copy.
|
||||
// Must be padded, see `DoDecompressA`.
|
||||
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
|
||||
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
|
||||
"pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
|
||||
|
|
@ -665,6 +666,11 @@ struct MatMulEnv {
|
|||
MMStorage storage;
|
||||
MMKeys keys;
|
||||
std::vector<MMPerKey> per_key;
|
||||
|
||||
// Pass to MatPtr::AllocateAndAttachRowPtrs.
|
||||
// Per-tensor allocations to make it likelier that asan detects bugs such as
|
||||
// use after free, overrun, and dangling references.
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||
};
|
||||
|
||||
// Arguments to MatMul() that are independent of the A/B/C types.
|
||||
|
|
|
|||
|
|
@ -91,7 +91,8 @@ float MaxAbs(const MatStorageT<float>& a) {
|
|||
// B is already transposed.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C, int line) {
|
||||
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
|
||||
MatMulEnv& env, int line) {
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t cols = A.Cols();
|
||||
const size_t B_rows = B.Rows();
|
||||
|
|
@ -101,6 +102,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
MatPadding::kOdd);
|
||||
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
|
||||
MatPadding::kOdd);
|
||||
c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
|
||||
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
|
||||
MatPadding::kOdd);
|
||||
for (size_t m = 0; m < A.Rows(); ++m) {
|
||||
|
|
@ -225,8 +227,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
|||
|
||||
MatStorageT<TA> A(GenerateMat<TA>(A_extents, pool));
|
||||
MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, pool));
|
||||
MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
|
||||
MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
|
||||
MatStorageT<TC> C_slow("C_slow", C_extents, MatPadding::kOdd);
|
||||
MatStorageT<TC> C("C", C_extents, MatPadding::kOdd);
|
||||
C.AllocateAndAttachRowPtrs(env.row_ptrs);
|
||||
|
||||
MatStorageT<float> add_storage =
|
||||
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
|
||||
|
|
@ -238,7 +241,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
|||
// A few reps to get coverage of the various autotuned code paths.
|
||||
for (size_t rep = 0; rep < 16; ++rep) {
|
||||
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
|
||||
AssertClose(A, BT, C_slow, C, line);
|
||||
AssertClose(A, BT, C_slow, C, env, line);
|
||||
if (per_key->autotune.Best()) break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
void TestQuestion(const char* question, const char* expected_substring);
|
||||
|
||||
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
std::vector<uint8_t*> image_row_ptrs_;
|
||||
};
|
||||
|
||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
|
|
@ -54,11 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
|
|||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||
MatPadding::kPacked);
|
||||
image_row_ptrs_.resize(image_tokens_->Rows());
|
||||
for (size_t r = 0; r < image_tokens_->Rows(); ++r) {
|
||||
image_row_ptrs_[r] = image_tokens_->RowBytes(r);
|
||||
}
|
||||
image_tokens_->AttachRowPtrs(image_row_ptrs_.data());
|
||||
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
|
||||
Image image;
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
|
|
|
|||
Loading…
Reference in New Issue