1.16x decode speedup: remove last MatVec in Attention

Precompute row pointers for MatMul outputs.
Remove no-longer-used MHA support; QStride -> qkv_dim.
Remove RowPtr from the MatMul interface; use only MatPtrT.
Require an opt-in define (GEMMA_ENABLE_NUQ) for NUQ to speed up builds.
Also fix io.cc on Windows.

PiperOrigin-RevId: 766228108
Jan Wassenberg 2025-06-02 09:39:57 -07:00 committed by Copybara-Service
parent c4a75abe43
commit cf4d7ceb82
16 changed files with 221 additions and 211 deletions
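The net effect of the interface changes below, as a minimal sketch (not part of the commit; the types and calls are taken from the hunks, while A, B, env and the dimensions are assumed to exist): MatMul outputs are now plain MatPtrT tensors whose row pointers are attached once, instead of RowPtr views constructed per call.

  // Sketch only, assuming the MatPtrT/MatMulStatic interfaces shown below.
  MatStorageT<float> out("out", Extents2D(batch_size, out_cols), MatPadding::kOdd);
  // Precompute row pointers once; the array must outlive all MatMul calls.
  hwy::AlignedFreeUniquePtr<uint8_t*[]> ptrs =
      hwy::AllocateAligned<uint8_t*>(out.Rows());
  for (size_t r = 0; r < out.Rows(); ++r) ptrs[r] = out.RowBytes(r);
  out.AttachRowPtrs(ptrs.get());
  // The output tensor is passed directly; RowPtrFromMat is gone.
  MatMulStatic(A, B, /*add=*/nullptr, env, out);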

View File

@ -517,6 +517,7 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(NuqTest);
#if GEMMA_ENABLE_NUQ
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
@ -530,6 +531,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NuqTest);
#endif // GEMMA_ENABLE_NUQ
HWY_AFTER_TEST();
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -70,12 +70,6 @@ class CompressionTest(absltest.TestCase):
info_256.name = "ignored_256"
info_256.axes = [0]
info_256.shape = [256]
writer.insert(
"tensor_nuq",
np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
configs.Type.kNUQ,
info_256,
)
writer.insert(
"tensor_sfp",
np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
@ -97,7 +91,7 @@ class CompressionTest(absltest.TestCase):
config = configs.ModelConfig(
configs.Model.GEMMA_TINY,
configs.Type.kNUQ,
configs.Type.kSFP,
configs.PromptWrapping.GEMMA_IT,
)
tokenizer_path = "" # no tokenizer required for testing
@ -108,7 +102,7 @@ class CompressionTest(absltest.TestCase):
reader = compression.SbsReader(temp_file.full_path)
self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
self.assertEqual(reader.config.weight, configs.Type.kNUQ)
self.assertEqual(reader.config.weight, configs.Type.kSFP)
mat = reader.find_mat("tensor0")
self.assertEqual(mat.cols, 192)
@ -128,12 +122,6 @@ class CompressionTest(absltest.TestCase):
self.assertEqual(mat.type, configs.Type.kSFP)
self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2)
mat = reader.find_mat("tensor_nuq")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kNUQ)
self.assertAlmostEqual(mat.scale, 1.0)
mat = reader.find_mat("tensor_sfp")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)

View File

@ -62,8 +62,10 @@ void ForeachPackedAndRawType() {
ForeachRawType<BF16, TestT>();
ForeachRawType<float, TestT>();
ForeachRawType<SfpStream, TestT>();
if constexpr (GEMMA_ENABLE_NUQ) {
ForeachRawType<NuqStream, TestT>();
}
}
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>

View File

@ -29,6 +29,11 @@
namespace gcpp {
// Only used in experiments, hence disable in default builds.
#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0
#endif
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.

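NUQ support is now opt-in at compile time; a short sketch of the two guard forms used in this commit (the macro name is from this hunk, the exact build-flag spelling is an assumption):

  // Defaults to 0; pass e.g. -DGEMMA_ENABLE_NUQ=1 to the compiler to enable.
  #if GEMMA_ENABLE_NUQ              // preprocessor guard, as in nuq_test.cc
  HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
  #endif
  if constexpr (GEMMA_ENABLE_NUQ) { // constexpr guard, as in the type test
    ForeachRawType<NuqStream, TestT>();
  }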
View File

@ -38,8 +38,7 @@ struct Activations {
is_griffin(config.model == Model::GRIFFIN_2B),
x("x", Extents2D(batch_size, config.model_dim), pad_),
q("q",
Extents2D(batch_size, layer_config.heads * layer_config.QStride()),
q("q", Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
@ -82,6 +81,25 @@ struct Activations {
env(env) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
uint8_t** ptrs = row_ptrs.back().get();
for (size_t r = 0; r < mat.Rows(); ++r) {
ptrs[r] = mat.RowBytes(r);
}
mat.AttachRowPtrs(ptrs);
};
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
init_row_ptrs(q);
init_row_ptrs(logits);
init_row_ptrs(att_sums);
init_row_ptrs(C1);
init_row_ptrs(C2);
init_row_ptrs(ffw_out);
// TODO: also init rows for image_tokens.
// Note that BindC on any MatMul output considerably slows down Prefill.
}
@ -144,6 +162,9 @@ struct Activations {
MatStorageT<float> inv_timescale_global;
MatMulEnv* env;
// Per-tensor allocations to make it likelier that asan detects bugs such as
// use after free, overrun, and dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
} // namespace gcpp

View File

@ -259,16 +259,12 @@ struct LayerConfig : public IFields {
// Multi-Head Attention?
bool IsMHA() const { return heads == kv_heads; }
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
uint32_t model_dim = 0;
uint32_t griffin_dim = 0;
uint32_t ff_hidden_dim = 0;
uint32_t heads = 0;
uint32_t kv_heads = 0;
uint32_t qkv_dim = 0;
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
uint32_t conv1d_width = 0; // Griffin only
bool ff_biases = false;
bool softmax_attn_output_biases = false; // for Griffin

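Since QStride() is gone, Q, K and V are always stored as contiguous qkv_dim vectors; a worked example with hypothetical dimensions (not from any real config):

  const size_t heads = 8, kv_heads = 2, qkv_dim = 256;  // illustration only
  // activations.q holds one qkv_dim vector per query head:
  const size_t q_cols = heads * qkv_dim;           // 2048, was heads * QStride()
  // The K/V MatMul output holds kv_heads pairs of (k, v) vectors:
  const size_t kv_cols = kv_heads * 2 * qkv_dim;   // 1024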
View File

@ -249,64 +249,38 @@ class GemmaAttention {
}
}
// Fills activations.q and computes KV. For is_mha_, a single MatMul suffices
// and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
// KV directly to KVCache.
// Fills activations.q and writes to KV cache.
HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.QKV");
const size_t model_dim = layer_config_.model_dim;
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
const size_t kv_heads = layer_config_.kv_heads;
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
// We must shrink to the actual size because MatMul verifies
// `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all
// rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
// computed in the second MatMul.
const size_t w1_rows = heads * layer_config_.QStride();
HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env,
RowPtrFromMat(activations_.q));
/*add=*/nullptr, *activations_.env, activations_.q);
if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else {
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols);
// Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of cache_pos_size_.
if (num_queries_ == 1 &&
queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
const size_t kv_ofs =
queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
} else {
// Proceed row by row because there will be wraparound.
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
// because rows are computed modulo seq_len.
MatPtrT<float> kv_rows("kv",
Extents2D(activations_.pre_att_rms_out.Rows(),
layer_weights_.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const float* x = activations_.pre_att_rms_out.Row(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
KVCache& kv_cache = kv_caches_[query_idx];
const size_t cache_pos =
div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx);
const size_t kv_offset =
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x,
kv, pool_);
activations_.env->storage.OutRow(interleaved_idx) =
reinterpret_cast<uint8_t*>(kv_caches_[query_idx].kv_cache.get() +
kv_offset);
}
}
} // !is_mha_
kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0));
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
// Apply positional encodings for K (and copy KV to cache if MHA).
pool_.Run(0, kv_heads * num_interleaved,
@ -322,13 +296,6 @@ class GemmaAttention {
head * qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// If MHA, copy computed K and V into KVCache.
if (is_mha_) {
const float* HWY_RESTRICT mha_kv =
activations_.q.Row(interleaved_idx) + head * q_stride_ +
qkv_dim;
hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
}
// Apply further processing to K.
if (layer_weights_.key_norm_scale.HasPtr()) {
@ -435,7 +402,7 @@ class GemmaAttention {
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
float* HWY_RESTRICT q =
activations_.q.Row(interleaved_idx) + head * q_stride_;
activations_.q.Row(interleaved_idx) + head * qkv_dim;
float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
head * activations_.seq_len;
float* HWY_RESTRICT att_out =
@ -490,7 +457,7 @@ class GemmaAttention {
? layer_weights_.attention_output_biases.PackedScale1()
: nullptr;
MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
*activations_.env, RowPtrFromMat(activations_.att_sums));
*activations_.env, activations_.att_sums);
}
public:
@ -548,15 +515,14 @@ class GemmaAttention {
num_tokens_(num_tokens),
layer_(layer),
layer_config_(layer_weights->layer_config),
q_stride_(layer_config_.QStride()),
cache_layer_size_(layer_weights->layer_config.CacheLayerSize()),
cache_pos_size_(activations.cache_pos_size),
is_mha_(layer_config_.IsMHA()),
activations_(activations),
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(ctx.pools.Pool(0)) {
HWY_DASSERT(!layer_config_.IsMHA()); // No longer supported.
HWY_DASSERT(num_queries_ <= kv_caches_.size());
HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
@ -576,10 +542,8 @@ class GemmaAttention {
const size_t num_tokens_;
const size_t layer_;
const LayerConfig& layer_config_;
const size_t q_stride_ = 0;
const size_t cache_layer_size_ = 0;
const size_t cache_pos_size_ = 0;
const bool is_mha_ = false;
Activations& activations_;
const LayerWeightsPtrs<T>& layer_weights_;
@ -627,7 +591,7 @@ class VitAttention {
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
layer_weights_.vit.qkv_einsum_b.PackedScale1(),
*activations_.env, RowPtrFromMat(qkv));
*activations_.env, qkv);
}
// TODO(philculliton): transition fully to MatMul.
@ -667,7 +631,7 @@ class VitAttention {
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
MatMulStatic(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
MatMulStatic(Q, K, nullptr, *activations_.env, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task);
@ -733,9 +697,8 @@ class VitAttention {
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
auto att_sums = RowPtrFromMat(activations_.att_sums);
MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
*activations_.env, att_sums);
*activations_.env, activations_.att_sums);
}
public:
@ -827,9 +790,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
// Compute the hidden layer activations.
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
bias1, *activations.env, RowPtrFromMat(activations.C1));
bias1, *activations.env, activations.C1);
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
bias2, *activations.env, RowPtrFromMat(activations.C2));
bias2, *activations.env, activations.C2);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_weights->layer_config.activation, activations.C1,
@ -837,7 +800,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
// Hidden layer -> output layer.
MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
*activations.env, RowPtrFromMat(activations.ffw_out));
*activations.env, activations.ffw_out);
}
// Same as FFWNoVit, but with different layer_weights members and no second
@ -855,14 +818,14 @@ HWY_NOINLINE void FFWVit(Activations& activations,
// Compute the hidden layer activations.
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
bias1, *activations.env, RowPtrFromMat(activations.C1));
bias1, *activations.env, activations.C1);
// Activation (Gelu), store in C1.
ActivationBatched(layer_weights->layer_config.activation, activations.C1);
// Hidden layer -> output layer.
MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
*activations.env, RowPtrFromMat(activations.ffw_out));
*activations.env, activations.ffw_out);
}
// `batch_idx` indicates which row of `x` to write to.
@ -1176,10 +1139,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// kPatchSize), MatPadding::kPacked);
// [Get patches]
// MatMulStatic(
// MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel),
// image_patches,
// weights.vit_img_embedding_kernel,
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
// RowPtrF(activations.x.Row(0), kVitModelDim));
// activations.x);
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and
@ -1228,7 +1191,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMulStatic(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
RowPtrFromMat(image_tokens));
image_tokens);
}
// Generates one token for each query. `queries_token` is the previous token
@ -1367,8 +1330,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
MatMulStatic(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env,
RowPtrFromMat(activations.logits));
/*add=*/nullptr, *activations.env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {

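In place of the removed per-row MatVec, the K/V projection is now a single MatMul whose output rows alias KV-cache slots; a condensed sketch of that row-pointer setup (names follow the hunk above with member suffixes dropped; the enclosing class and assertions are omitted):

  // kv_heads pairs of (k, v) per output row; wraparound is handled per row
  // because each cache position is computed modulo seq_len.
  MatPtrT<float> kv_rows("kv",
                         Extents2D(num_interleaved, kv_heads * 2 * qkv_dim));
  for (size_t i = 0; i < num_interleaved; ++i) {
    const size_t query_idx = i % num_queries;
    const size_t batch_idx = i / num_queries;
    const size_t cache_pos =
        div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
    const size_t kv_offset =
        cache_pos * cache_pos_size + layer * cache_layer_size;
    env.storage.OutRow(i) = reinterpret_cast<uint8_t*>(
        kv_caches[query_idx].kv_cache.get() + kv_offset);
  }
  kv_rows.AttachRowPtrs(&env.storage.OutRow(0));
  MatMulStatic(pre_att_rms_out, qkv_einsum_w2, /*add=*/nullptr, env, kv_rows);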
View File

@ -393,14 +393,7 @@ struct LayerWeightsPtrs {
// MHA, and otherwise might not be the same type.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
const size_t w1_rows = layer_config.heads * layer_config.QStride();
if (layer_config.IsMHA()) { // MHA only requires w1.
qkv_einsum_w1 = qkv_einsum_w;
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
return;
}
const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);

View File

@ -15,10 +15,6 @@
// Safe to be first, does not include POSIX headers.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
// check this in source code because we support multiple build systems.
#if !HWY_OS_WIN
// Request POSIX 2008, including `pread()` and `posix_fadvise()`. This also
// implies `_POSIX_C_SOURCE`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
@ -30,6 +26,14 @@
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#include <stddef.h>
#include <memory>
#include <string>
#include "io/io.h"
#include "hwy/base.h" // HWY_ASSERT
#if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \
(!defined(__ANDROID_API__) || __ANDROID_API__ >= 24)
#define GEMMA_IO_PREADV 1
@ -44,6 +48,11 @@
#define GEMMA_IO_FADVISE 0
#endif
// FilePosix should only be compiled on non-Windows. It is easier to
// check this in source code because we support multiple build systems. Note
// that IOBatch at the end of this TU is still compiled on all platforms.
#if !HWY_OS_WIN
#if GEMMA_IO_PREADV
// Replacement for the _BSD_SOURCE specified by preadv documentation.
#ifndef _DEFAULT_SOURCE
@ -55,7 +64,6 @@
#include <fcntl.h> // open
#include <limits.h> // IOV_MAX
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
#include <sys/types.h>
@ -64,12 +72,7 @@
#include <sys/stat.h> // O_RDONLY
#include <unistd.h> // read, write, close
#include <memory>
#include <string>
#include "io/io.h"
#include "util/allocator.h"
#include "hwy/base.h" // HWY_ASSERT
namespace gcpp {
@ -168,6 +171,12 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
return std::make_unique<FilePosix>(fd);
}
} // namespace gcpp
#endif // !HWY_OS_WIN
namespace gcpp {
std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
std::unique_ptr<File> file = OpenFileOrNull(filename, "r");
if (!file) {
@ -237,4 +246,3 @@ uint64_t IOBatch::Read(const File& file) const {
}
} // namespace gcpp
#endif // !HWY_OS_WIN

View File

@ -91,8 +91,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N);
MatStorageT<TC> c_slow_mat("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> c_mat("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
if (add) {
@ -104,7 +104,6 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C = RowPtrFromMat(c_mat);
// Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12;
@ -115,7 +114,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(b_trans, sizeof(TC), env.parallel);
BindC(c_mat, env.parallel);
BindC(C, env.parallel);
Tristate use_spinning = Tristate::kDefault;
env.ctx.pools.MaybeStartSpinning(use_spinning);

View File

@ -80,6 +80,7 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
return hn::DemoteTo(dc, vf);
}
// Type-safe wrapper over uint8_t row pointers referenced by MatPtrT.
template <typename TC>
class CRows {
public:
@ -1183,7 +1184,10 @@ class MMPerPackage {
if constexpr (hwy::IsSame<TA, BF16>()) {
// Only if no zero-padding required.
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(A);
if (HWY_LIKELY(A.Cols() % NBF == 0)) {
// Actually const, but RowPtr is also used for partial which is not.
return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
}
}
if (HWY_LIKELY(autotune.Best())) {
@ -1312,7 +1316,21 @@ struct MMImpl {
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
CRows<TC> C_rows) {
MatPtrT<TC>& C) {
CRows<TC> C_rows(C.GetRowPtrs());
if (HWY_UNLIKELY(!C.GetRowPtrs())) {
if constexpr (HWY_IS_DEBUG_BUILD) {
fprintf(stderr,
"MatMul perf warning: setting row pointers because "
"C.AttachRowPtrs() was not called.\n");
}
HWY_DASSERT(C.HasPtr());
for (size_t r = 0; r < C.Rows(); ++r) {
env.storage.OutRow(r) = reinterpret_cast<uint8_t*>(C.Row(r));
}
C_rows = CRows<TC>(&env.storage.OutRow(0));
}
const Allocator& allocator = env.ctx.allocator;
const size_t M = A.Rows();
const size_t K = A.Cols();
@ -1392,19 +1410,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
return &per_key;
}
// Adapter that fills the row array. This is the common case, whereas only
// GemmaAttention::ComputeQKV uses the arbitrary output rows feature.
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) {
HWY_DASSERT(B.Rows() == C.Cols());
for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(C.Row(row_ac));
}
return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

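Callers that do not pre-attach row pointers still work; a sketch of the two modes of the new overload (assumed from the hunk above; A, B, env and the C tensors are placeholders):

  // Fast path: C_attached.AttachRowPtrs() was called once up front.
  MatMulStatic(A, B, /*add=*/nullptr, env, C_attached);
  // Fallback: GetRowPtrs() == nullptr, so MatMul fills env.storage.OutRow(r)
  // from C_plain.Row(r) on every call; debug builds print a perf warning.
  MatMulStatic(A, B, /*add=*/nullptr, env, C_plain);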
View File

@ -176,6 +176,44 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
void BindC(MatPtr& C, MMParallel& parallel);
// Lightweight view into `MatStorageT`.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
public:
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
cols_(static_cast<uint32_t>(cols)),
stride_(static_cast<uint32_t>(stride)) {
HWY_DASSERT(stride >= cols);
}
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
size_t Cols() const { return static_cast<size_t>(cols_); }
size_t Stride() const { return static_cast<size_t>(stride_); }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < Cols());
HWY_DASSERT(cols <= Cols() - c);
return RowPtr<T>(Row(r) + c, cols, stride_);
}
private:
T* HWY_RESTRICT row0_;
uint32_t cols_;
uint32_t stride_;
};
#pragma pack(pop)
using RowPtrBF = RowPtr<BF16>;
using RowPtrD = RowPtr<double>;
// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
class MMStorage {

View File

@ -28,7 +28,7 @@
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
const RowPtr<TC>& C) { \
MatPtrT<TC>& C) { \
return MatMul(A, B, add, env, C); \
}

View File

@ -35,7 +35,7 @@
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \
MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B, \
const float* HWY_RESTRICT add, MatMulEnv& env, \
const RowPtr<TC>& C);
MatPtrT<TC>& C);
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \

View File

@ -91,7 +91,7 @@ float MaxAbs(const MatStorageT<float>& a) {
// B is already transposed.
template <typename TA, typename TB, typename TC>
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C, int line) {
const hn::ScalableTag<float> df;
const size_t cols = A.Cols();
const size_t B_rows = B.Rows();
@ -161,7 +161,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const float* HWY_RESTRICT add_row, MatMulEnv& env,
const RowPtr<TC>& C) {
MatPtrT<TC>& C) {
// TA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs.
@ -223,25 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc);
MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<TA> A(GenerateMat<TA>(A_extents, pool));
MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, pool));
MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
add_storage.SetScale(1.0f);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C_slow = RowPtrFromMat(c_slow_batch);
const RowPtr<TC> C = RowPtrFromMat(c_batch);
MatMulSlow(a, b_trans, add_row, env, C_slow);
MatMulSlow(A, BT, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMulStatic(a, b_trans, add_row, env, C);
AssertClose(a, b_trans, C_slow, C, line);
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
AssertClose(A, BT, C_slow, C, line);
if (per_key->autotune.Best()) break;
}
}

View File

@ -33,6 +33,18 @@
namespace gcpp {
// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr.
template <typename TC>
class CRows {
public:
CRows(TC** C_rows) : C_rows_(C_rows) {}
TC* HWY_RESTRICT operator[](size_t row_idx) const { return C_rows_[row_idx]; }
private:
TC** C_rows_;
};
// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
// to store hetereogeneous tensor references in a vector.
@ -63,13 +75,29 @@ class MatPtr : public IFields {
ptr_ = ptr;
stride_ = static_cast<uint32_t>(stride);
// If row pointers were already attached, `SetPtr` would invalidate them.
HWY_DASSERT_M(row_ptrs_ == nullptr, "Do not call after AttachRowPtrs.");
// NUQ streams must not be padded because that would change the position of
// the group tables.
if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked());
if (type_ == Type::kNUQ) {
HWY_ASSERT_M(GEMMA_ENABLE_NUQ, "Set GEMMA_ENABLE_NUQ=1.");
HWY_ASSERT(IsPacked());
}
}
bool HasPtr() const { return ptr_ != nullptr; }
// Caller has initialized Rows() pointers in row_ptrs[].
void AttachRowPtrs(uint8_t** row_ptrs) {
row_ptrs_ = row_ptrs;
for (size_t r = 0; r < Rows(); ++r) {
HWY_DASSERT(row_ptrs[r] != nullptr);
}
}
uint8_t** GetRowPtrs() const { return row_ptrs_; }
// A single row counts as packed because there is no padding between rows.
bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); }
@ -195,6 +223,11 @@ class MatPtr : public IFields {
// this object.
void* ptr_ = nullptr; // not serialized
// Points to an array of pointers, one per row, or nullptr if `AttachRowPtrs`
// was not called. Only used for MatMul output tensors, hence we
// minimize the cost for other tensors by only holding a non-owning pointer.
uint8_t** row_ptrs_ = nullptr; // not serialized
// Offset by which to advance pointers to the next row, >= `cols_`.
uint32_t stride_;
@ -261,6 +294,13 @@ class MatPtrT : public MatPtr {
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
Args&&... args) {
#if GEMMA_ENABLE_NUQ
if (base->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
std::forward<Args>(args)...);
}
#endif // GEMMA_ENABLE_NUQ
if (base->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base),
std::forward<Args>(args)...);
@ -270,9 +310,6 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
} else if (base->GetType() == Type::kSFP) {
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
std::forward<Args>(args)...);
} else if (base->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
@ -283,6 +320,15 @@ template <class Func, typename... Args>
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const Func& func, Args&&... args) {
HWY_ASSERT(base1->GetType() == base2->GetType());
#if GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
dynamic_cast<const MatPtrT<NuqStream>*>(base2),
std::forward<Args>(args)...);
}
#endif // GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base1),
dynamic_cast<const MatPtrT<float>*>(base2),
@ -295,10 +341,6 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
dynamic_cast<const MatPtrT<SfpStream>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
dynamic_cast<const MatPtrT<NuqStream>*>(base2),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
}
@ -384,55 +426,5 @@ class MatStorageT : public MatPtrT<MatT> {
MatOwner owner_;
};
// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
// seekable (non-NUQ) T.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
public:
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
cols_(static_cast<uint32_t>(cols)),
stride_(static_cast<uint32_t>(stride)) {
HWY_DASSERT(stride >= cols);
}
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
size_t Cols() const { return static_cast<size_t>(cols_); }
size_t Stride() const { return static_cast<size_t>(stride_); }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < Cols());
HWY_DASSERT(cols <= Cols() - c);
return RowPtr<T>(Row(r) + c, cols, stride_);
}
private:
T* HWY_RESTRICT row0_;
uint32_t cols_;
uint32_t stride_;
};
#pragma pack(pop)
using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;
template <typename T>
RowPtr<T> RowPtrFromMat(const MatPtrT<T>& row_vectors) {
// RowPtr is non-const for MatMul C, but is also used for A which is const.
// Callers are responsible for checking their usage of RowPtr.
return RowPtr<T>(const_cast<T*>(row_vectors.Row(0)), row_vectors.Cols(),
row_vectors.Stride());
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
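A minimal usage sketch of the NUQ-gated CallUpcasted dispatch (the tensor and the print are hypothetical; the accessors are those visible in this diff):

  // Dispatches on the runtime type; MatPtrT<NuqStream> is only a candidate
  // when compiled with GEMMA_ENABLE_NUQ=1.
  const MatPtr* base = &some_tensor;  // hypothetical type-erased reference
  CallUpcasted(base, [](const auto* t) {
    fprintf(stderr, "%s: %zu x %zu\n", TypeName(t->GetType()), t->Rows(),
            t->Cols());
  });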