mirror of https://github.com/google/gemma.cpp.git
1.16x decode speedup: remove last MatVec in Attention
Precompute row pointers. Remove no longer used MHA support; QStride -> qkv_dim.
Remove RowPtr from MatMul interface, use only MatPtrT. Require opt-in define for
NUQ to speed up builds. Also fix io.cc on Windows.

PiperOrigin-RevId: 766228108
Parent: c4a75abe43
Commit: cf4d7ceb82
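The central change is that MatMul outputs are now plain MatPtrT tensors whose row pointers are computed once and attached up front, rather than RowPtr views rebuilt on every call. Below is a minimal standalone sketch of that idea, not gemma.cpp code: Mat and WriteRows are hypothetical simplifications standing in for MatPtrT/AttachRowPtrs and for MatMul writing its output rows through the attached pointers.

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for MatPtrT<float>: a row-padded matrix that can precompute and
// attach one pointer per row.
struct Mat {
  Mat(size_t rows, size_t cols, size_t stride)
      : rows(rows), cols(cols), stride(stride), data(rows * stride) {}
  float* Row(size_t r) { return data.data() + r * stride; }

  // Done once (in the real code: the Activations init_row_ptrs lambda plus
  // AttachRowPtrs), so per-call row address setup disappears from the hot path.
  void AttachRowPtrs() {
    row_ptrs.resize(rows);
    for (size_t r = 0; r < rows; ++r) row_ptrs[r] = Row(r);
  }

  size_t rows, cols, stride;
  std::vector<float> data;
  std::vector<float*> row_ptrs;  // non-owning array in the real code
};

// Stand-in for MatMul writing its output: it only sees the attached pointers,
// so output rows need not be contiguous or evenly strided.
void WriteRows(Mat& c, float value) {
  for (size_t r = 0; r < c.rows; ++r) {
    float* row = c.row_ptrs[r];
    for (size_t col = 0; col < c.cols; ++col) row[col] = value + float(r);
  }
}

int main() {
  Mat c(/*rows=*/4, /*cols=*/3, /*stride=*/8);  // padded rows
  c.AttachRowPtrs();
  WriteRows(c, 1.0f);
  printf("c[2][0] = %.1f\n", c.Row(2)[0]);  // prints 3.0
  return 0;
}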
@@ -517,6 +517,7 @@ HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(NuqTest);
#if GEMMA_ENABLE_NUQ
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);

@@ -530,6 +531,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NuqTest);
#endif  // GEMMA_ENABLE_NUQ
HWY_AFTER_TEST();
}  // namespace gcpp
#endif  // HWY_ONCE

@@ -70,12 +70,6 @@ class CompressionTest(absltest.TestCase):
info_256.name = "ignored_256"
info_256.axes = [0]
info_256.shape = [256]
writer.insert(
    "tensor_nuq",
    np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
    configs.Type.kNUQ,
    info_256,
)
writer.insert(
    "tensor_sfp",
    np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),

@@ -97,7 +91,7 @@ class CompressionTest(absltest.TestCase):
config = configs.ModelConfig(
    configs.Model.GEMMA_TINY,
    configs.Type.kNUQ,
    configs.Type.kSFP,
    configs.PromptWrapping.GEMMA_IT,
)
tokenizer_path = ""  # no tokenizer required for testing

@@ -108,7 +102,7 @@ class CompressionTest(absltest.TestCase):
reader = compression.SbsReader(temp_file.full_path)

self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
self.assertEqual(reader.config.weight, configs.Type.kNUQ)
self.assertEqual(reader.config.weight, configs.Type.kSFP)

mat = reader.find_mat("tensor0")
self.assertEqual(mat.cols, 192)

@@ -128,12 +122,6 @@ class CompressionTest(absltest.TestCase):
self.assertEqual(mat.type, configs.Type.kSFP)
self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2)

mat = reader.find_mat("tensor_nuq")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)
self.assertEqual(mat.type, configs.Type.kNUQ)
self.assertAlmostEqual(mat.scale, 1.0)

mat = reader.find_mat("tensor_sfp")
self.assertEqual(mat.cols, 256)
self.assertEqual(mat.rows, 1)

@@ -62,7 +62,9 @@ void ForeachPackedAndRawType() {
  ForeachRawType<BF16, TestT>();
  ForeachRawType<float, TestT>();
  ForeachRawType<SfpStream, TestT>();
  ForeachRawType<NuqStream, TestT>();
  if constexpr (GEMMA_ENABLE_NUQ) {
    ForeachRawType<NuqStream, TestT>();
  }
}

// Generates inputs: deterministic, within max SfpStream range.

@@ -29,6 +29,11 @@
namespace gcpp {

// Only used in experiments, hence disable in default builds.
#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0
#endif

// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.

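NUQ is now behind this opt-in macro: default builds keep GEMMA_ENABLE_NUQ at 0, so NuqStream-only code paths and their test instantiations compile away and builds get faster. A brief sketch of the intended opt-in, assuming the flag is passed on the compiler command line (e.g. -DGEMMA_ENABLE_NUQ=1); MaybeUseNuq below is a made-up name for illustration.

// Same default as in the hunk above: off unless the build opts in.
#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0
#endif

void MaybeUseNuq() {
  if constexpr (GEMMA_ENABLE_NUQ) {
    // NUQ-only work, e.g. NuqStream encode/decode or its tests.
  } else {
    // Default builds skip NUQ entirely.
  }
}

int main() { MaybeUseNuq(); }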
@@ -38,8 +38,7 @@ struct Activations {
      is_griffin(config.model == Model::GRIFFIN_2B),
      x("x", Extents2D(batch_size, config.model_dim), pad_),
      q("q",
        Extents2D(batch_size, layer_config.heads * layer_config.QStride()),
      q("q", Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
        pad_),
      logits("logits", Extents2D(batch_size, config.vocab_size), pad_),

@@ -82,6 +81,25 @@ struct Activations {
      env(env) {
  HWY_ASSERT(batch_size != 0);

  // For MatMul outputs, precompute their row pointers.
  const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
    row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
    uint8_t** ptrs = row_ptrs.back().get();
    for (size_t r = 0; r < mat.Rows(); ++r) {
      ptrs[r] = mat.RowBytes(r);
    }
    mat.AttachRowPtrs(ptrs);
  };
  // If we forget any MatMul outputs here, debug builds print a warning but
  // fill them in each MatMul call.
  init_row_ptrs(q);
  init_row_ptrs(logits);
  init_row_ptrs(att_sums);
  init_row_ptrs(C1);
  init_row_ptrs(C2);
  init_row_ptrs(ffw_out);
  // TODO: also init rows for image_tokens.

  // Note that BindC on any MatMul output considerably slows down Prefill.
}

@@ -144,6 +162,9 @@ struct Activations {
  MatStorageT<float> inv_timescale_global;

  MatMulEnv* env;

  // Per-tensor allocations to make it likelier that asan detects bugs such as
  // use after free, overrun, and dangling references.
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};

}  // namespace gcpp

@@ -259,16 +259,12 @@ struct LayerConfig : public IFields {
  // Multi-Head Attention?
  bool IsMHA() const { return heads == kv_heads; }

  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
  size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }

  uint32_t model_dim = 0;
  uint32_t griffin_dim = 0;
  uint32_t ff_hidden_dim = 0;
  uint32_t heads = 0;
  uint32_t kv_heads = 0;
  uint32_t qkv_dim = 0;
  uint32_t qkv_dim = 0;  // length of Q, K, V vectors (contiguous).
  uint32_t conv1d_width = 0;  // Griffin only
  bool ff_biases = false;
  bool softmax_attn_output_biases = false;  // for Griffin

@@ -249,64 +249,38 @@ class GemmaAttention {
    }
  }

  // Fills activations.q and computes KV. For is_mha_, a single MatMul suffices
  // and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
  // KV directly to KVCache.
  // Fills activations.q and writes to KV cache.
  HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
    PROFILER_ZONE("Gen.Attention.QKV");
    const size_t model_dim = layer_config_.model_dim;
    const size_t qkv_dim = layer_config_.qkv_dim;
    const size_t heads = layer_config_.heads;
    const size_t kv_heads = layer_config_.kv_heads;

    // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
    // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
    // We must shrink to the actual size because MatMul verifies
    // `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all
    // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
    // computed in the second MatMul.
    const size_t w1_rows = heads * layer_config_.QStride();
    HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
    // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
    // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
                 /*add=*/nullptr, *activations_.env,
                 RowPtrFromMat(activations_.q));
                 /*add=*/nullptr, *activations_.env, activations_.q);

    if (is_mha_) {
      // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
    } else {
      // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
      const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
      HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols);

      // Single query and no wraparound means we can use a matmul and write
      // directly into the KV cache with a stride of cache_pos_size_.
      if (num_queries_ == 1 &&
          queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
        const size_t kv_ofs =
            queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
        float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
        RowPtrF kv_rows(kv, w_rows_kv_cols);
        kv_rows.SetStride(cache_pos_size_);
        MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
                     /*add=*/nullptr, *activations_.env, kv_rows);
      } else {
        // Proceed row by row because there will be wraparound.
        for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
             ++interleaved_idx) {
          const float* x = activations_.pre_att_rms_out.Row(interleaved_idx);
          const size_t query_idx = interleaved_idx % num_queries_;
          const size_t batch_idx = interleaved_idx / num_queries_;
          KVCache& kv_cache = kv_caches_[query_idx];
          const size_t cache_pos =
              div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx);
          const size_t kv_offset =
              cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
          float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
          MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x,
                 kv, pool_);
        }
      }
    }  // !is_mha_
    // Set up MatMul row pointers for writing to KV, which consists of
    // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
    // because rows are computed modulo seq_len.
    MatPtrT<float> kv_rows("kv",
                           Extents2D(activations_.pre_att_rms_out.Rows(),
                                     layer_weights_.qkv_einsum_w2.Rows()));
    for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
         ++interleaved_idx) {
      const size_t query_idx = interleaved_idx % num_queries_;
      const size_t batch_idx = interleaved_idx / num_queries_;
      const size_t cache_pos =
          div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx);
      const size_t kv_offset =
          cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
      activations_.env->storage.OutRow(interleaved_idx) =
          reinterpret_cast<uint8_t*>(kv_caches_[query_idx].kv_cache.get() +
                                     kv_offset);
    }
    kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0));
    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
                 /*add=*/nullptr, *activations_.env, kv_rows);

    // Apply positional encodings for K (and copy KV to cache if MHA).
    pool_.Run(0, kv_heads * num_interleaved,

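The hunk above is the "remove last MatVec" part of the commit: instead of falling back to a per-row MatVec when positions wrap around the sequence length, the KV write becomes a single MatMul whose output row pointers are aimed directly at the (possibly wrapped) KV cache slots. A minimal standalone sketch of that wraparound handling follows; the flat buffer and names here are simplified stand-ins, not the actual KVCache layout.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t seq_len = 4;  // ring-buffer capacity (cache rows)
  const size_t kv_cols = 3;  // stand-in for kv_heads * 2 * qkv_dim
  std::vector<float> kv_cache(seq_len * kv_cols, 0.0f);

  // Three new tokens starting at position 3: positions 3, 4, 5 map to cache
  // rows 3, 0, 1 (modulo seq_len), i.e. the destination rows wrap around.
  const size_t start_pos = 3, num_tokens = 3;
  std::vector<float*> out_rows(num_tokens);
  for (size_t i = 0; i < num_tokens; ++i) {
    const size_t cache_row = (start_pos + i) % seq_len;
    out_rows[i] = kv_cache.data() + cache_row * kv_cols;
  }

  // Stand-in for the single MatMul: it only sees out_rows[i] and writes row i
  // there, so wrapped, non-contiguous destinations need no special casing.
  for (size_t i = 0; i < num_tokens; ++i) {
    for (size_t c = 0; c < kv_cols; ++c) out_rows[i][c] = float(start_pos + i);
  }

  printf("cache row 0 now holds position %.0f\n", kv_cache[0]);  // prints 4
  return 0;
}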
@@ -322,13 +296,6 @@ class GemmaAttention {
                 head * qkv_dim * 2;
    KVCache& kv_cache = kv_caches_[query_idx];
    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
    // If MHA, copy computed K and V into KVCache.
    if (is_mha_) {
      const float* HWY_RESTRICT mha_kv =
          activations_.q.Row(interleaved_idx) + head * q_stride_ +
          qkv_dim;
      hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
    }

    // Apply further processing to K.
    if (layer_weights_.key_norm_scale.HasPtr()) {

@@ -435,7 +402,7 @@ class GemmaAttention {
    const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;

    float* HWY_RESTRICT q =
        activations_.q.Row(interleaved_idx) + head * q_stride_;
        activations_.q.Row(interleaved_idx) + head * qkv_dim;
    float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) +
                              head * activations_.seq_len;
    float* HWY_RESTRICT att_out =

@@ -490,7 +457,7 @@ class GemmaAttention {
            ? layer_weights_.attention_output_biases.PackedScale1()
            : nullptr;
    MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
                 *activations_.env, RowPtrFromMat(activations_.att_sums));
                 *activations_.env, activations_.att_sums);
  }

 public:

@@ -548,15 +515,14 @@ class GemmaAttention {
        num_tokens_(num_tokens),
        layer_(layer),
        layer_config_(layer_weights->layer_config),
        q_stride_(layer_config_.QStride()),
        cache_layer_size_(layer_weights->layer_config.CacheLayerSize()),
        cache_pos_size_(activations.cache_pos_size),
        is_mha_(layer_config_.IsMHA()),
        activations_(activations),
        layer_weights_(*layer_weights),
        div_seq_len_(div_seq_len),
        kv_caches_(kv_caches),
        pool_(ctx.pools.Pool(0)) {
    HWY_DASSERT(!layer_config_.IsMHA());  // No longer supported.
    HWY_DASSERT(num_queries_ <= kv_caches_.size());
    HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                  "query heads must be a multiple of key-value heads");

@@ -576,10 +542,8 @@ class GemmaAttention {
  const size_t num_tokens_;
  const size_t layer_;
  const LayerConfig& layer_config_;
  const size_t q_stride_ = 0;
  const size_t cache_layer_size_ = 0;
  const size_t cache_pos_size_ = 0;
  const bool is_mha_ = false;

  Activations& activations_;
  const LayerWeightsPtrs<T>& layer_weights_;

@@ -627,7 +591,7 @@ class VitAttention {
    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
    MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
                 layer_weights_.vit.qkv_einsum_b.PackedScale1(),
                 *activations_.env, RowPtrFromMat(qkv));
                 *activations_.env, qkv);
  }

  // TODO(philculliton): transition fully to MatMul.

@@ -667,7 +631,7 @@ class VitAttention {
    });

    // this produces C, a (num_tokens_, seq_len) matrix of dot products
    MatMulStatic(Q, K, nullptr, *activations_.env, RowPtrFromMat(C));
    MatMulStatic(Q, K, nullptr, *activations_.env, C);

    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
      float* HWY_RESTRICT c = C.Row(task);

@@ -733,9 +697,8 @@ class VitAttention {
    // att_weights and att_out are concatenated heads, each of length
    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
    // matmul output is the sum over heads.
    auto att_sums = RowPtrFromMat(activations_.att_sums);
    MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
                 *activations_.env, att_sums);
                 *activations_.env, activations_.att_sums);
  }

 public:

@@ -827,9 +790,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,

  // Compute the hidden layer activations.
  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
               bias1, *activations.env, RowPtrFromMat(activations.C1));
               bias1, *activations.env, activations.C1);
  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
               bias2, *activations.env, RowPtrFromMat(activations.C2));
               bias2, *activations.env, activations.C2);

  // Activation (Gelu) and maybe multiply by gate. Store activations in act.
  ActivationBatched(layer_weights->layer_config.activation, activations.C1,

@@ -837,7 +800,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,

  // Hidden layer -> output layer.
  MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
               *activations.env, RowPtrFromMat(activations.ffw_out));
               *activations.env, activations.ffw_out);
}

// Same as FFWNoVit, but with different layer_weights members and no second

@@ -855,14 +818,14 @@ HWY_NOINLINE void FFWVit(Activations& activations,

  // Compute the hidden layer activations.
  MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
               bias1, *activations.env, RowPtrFromMat(activations.C1));
               bias1, *activations.env, activations.C1);

  // Activation (Gelu), store in C1.
  ActivationBatched(layer_weights->layer_config.activation, activations.C1);

  // Hidden layer -> output layer.
  MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
               *activations.env, RowPtrFromMat(activations.ffw_out));
               *activations.env, activations.ffw_out);
}

// `batch_idx` indicates which row of `x` to write to.

@@ -1176,10 +1139,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
  //                    kPatchSize), MatPadding::kPacked);
  // [Get patches]
  // MatMulStatic(
  //     MatFromBatch(kVitSeqLen, image_patches),
  //     MatFromWeights(weights.vit_img_embedding_kernel),
  //     image_patches,
  //     weights.vit_img_embedding_kernel,
  //     weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
  //     RowPtrF(activations.x.Row(0), kVitModelDim));
  //     activations.x);
  // However, MatMul currently requires that
  // A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
  // which is not the case here. We should relax that requirement on MatMul and

@@ -1228,7 +1191,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
  // Apply head embedding into image_tokens of size of the LLM kModelDim.
  MatMulStatic(activations.x, weights.vit_img_head_kernel,
               weights.vit_img_head_bias.PackedScale1(), *activations.env,
               RowPtrFromMat(image_tokens));
               image_tokens);
}

// Generates one token for each query. `queries_token` is the previous token

@@ -1367,8 +1330,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
    PROFILER_ZONE("Gen.EmbeddingMatmul");
    // Compute logits from last layer activations.
    MatMulStatic(activations.x, weights.embedder_input_embedding,
                 /*add=*/nullptr, *activations.env,
                 RowPtrFromMat(activations.logits));
                 /*add=*/nullptr, *activations.env, activations.logits);
  }
  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {

@@ -393,14 +393,7 @@ struct LayerWeightsPtrs {
    // MHA, and otherwise might not be the same type.
    if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;

    const size_t w1_rows = layer_config.heads * layer_config.QStride();

    if (layer_config.IsMHA()) {  // MHA only requires w1.
      qkv_einsum_w1 = qkv_einsum_w;
      HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
      return;
    }

    const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
    const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;

    HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);

io/io.cc (30 lines changed)
@@ -15,10 +15,6 @@
// Safe to be first, does not include POSIX headers.
#include "hwy/detect_compiler_arch.h"
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
// check this in source code because we support multiple build systems.
#if !HWY_OS_WIN

// Request POSIX 2008, including `pread()` and `posix_fadvise()`. This also
// implies `_POSIX_C_SOURCE`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700

@@ -30,6 +26,14 @@
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64

#include <stddef.h>

#include <memory>
#include <string>

#include "io/io.h"
#include "hwy/base.h"  // HWY_ASSERT

#if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \
    (!defined(__ANDROID_API__) || __ANDROID_API__ >= 24)
#define GEMMA_IO_PREADV 1

@@ -44,6 +48,11 @@
#define GEMMA_IO_FADVISE 0
#endif

// FilePosix should only be compiled on non-Windows. It is easier to
// check this in source code because we support multiple build systems. Note
// that IOBatch at the end of this TU is still compiled on all platforms.
#if !HWY_OS_WIN

#if GEMMA_IO_PREADV
// Replacement for the _BSD_SOURCE specified by preadv documentation.
#ifndef _DEFAULT_SOURCE

@@ -55,7 +64,6 @@
#include <fcntl.h>   // open
#include <limits.h>  // IOV_MAX
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>   // SEEK_END - unistd isn't enough for IDE.
#include <sys/types.h>

@@ -64,12 +72,7 @@
#include <sys/stat.h>  // O_RDONLY
#include <unistd.h>    // read, write, close

#include <memory>
#include <string>

#include "io/io.h"
#include "util/allocator.h"
#include "hwy/base.h"  // HWY_ASSERT

namespace gcpp {

@@ -168,6 +171,12 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
  return std::make_unique<FilePosix>(fd);
}

}  // namespace gcpp

#endif  // !HWY_OS_WIN

namespace gcpp {

std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
  std::unique_ptr<File> file = OpenFileOrNull(filename, "r");
  if (!file) {

@@ -237,4 +246,3 @@ uint64_t IOBatch::Read(const File& file) const {
}

}  // namespace gcpp
#endif  // !HWY_OS_WIN

@@ -91,8 +91,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  const Extents2D B_extents(N, K);  // already transposed
  const Extents2D C_extents(M, N);

  MatStorageT<TC> c_slow_mat("c_slow_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> c_mat("c_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);

  MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
  if (add) {

@@ -104,7 +104,6 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);

  const float* add_row = add ? add_storage.PackedScale1() : nullptr;
  const RowPtr<TC> C = RowPtrFromMat(c_mat);

  // Fewer reps for large batch sizes, which take longer.
  const size_t num_samples = M < 32 ? 20 : 12;

@@ -115,7 +114,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  // spinning may materially affect the choice of config. No harm in calling
  // BindB/C if there is a single package: they will be a no-op.
  BindB(b_trans, sizeof(TC), env.parallel);
  BindC(c_mat, env.parallel);
  BindC(C, env.parallel);

  Tristate use_spinning = Tristate::kDefault;
  env.ctx.pools.MaybeStartSpinning(use_spinning);

@@ -80,6 +80,7 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
  return hn::DemoteTo(dc, vf);
}

// Type-safe wrapper over uint8_t row pointers referenced by MatPtrT.
template <typename TC>
class CRows {
 public:

@@ -1183,7 +1184,10 @@ class MMPerPackage {
    if constexpr (hwy::IsSame<TA, BF16>()) {
      // Only if no zero-padding required.
      const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
      if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(A);
      if (HWY_LIKELY(A.Cols() % NBF == 0)) {
        // Actually const, but RowPtr is also used for partial which is not.
        return RowPtrBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
      }
    }

    if (HWY_LIKELY(autotune.Best())) {

@@ -1312,7 +1316,21 @@ struct MMImpl {
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                              const float* HWY_RESTRICT add, MatMulEnv& env,
                              CRows<TC> C_rows) {
                              MatPtrT<TC>& C) {
  CRows<TC> C_rows(C.GetRowPtrs());
  if (HWY_UNLIKELY(!C.GetRowPtrs())) {
    if constexpr (HWY_IS_DEBUG_BUILD) {
      fprintf(stderr,
              "MatMul perf warning: setting row pointers because "
              "C.AttachRowPtrs() was not called.\n");
    }
    HWY_DASSERT(C.HasPtr());
    for (size_t r = 0; r < C.Rows(); ++r) {
      env.storage.OutRow(r) = reinterpret_cast<uint8_t*>(C.Row(r));
    }
    C_rows = CRows<TC>(&env.storage.OutRow(0));
  }

  const Allocator& allocator = env.ctx.allocator;
  const size_t M = A.Rows();
  const size_t K = A.Cols();

@@ -1392,19 +1410,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
  return &per_key;
}

// Adapter that fills the row array. This is the common case, whereas only
// GemmaAttention::ComputeQKV uses the arbitrary output rows feature.
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                              const float* HWY_RESTRICT add, MatMulEnv& env,
                              const RowPtr<TC>& C) {
  HWY_DASSERT(B.Rows() == C.Cols());
  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(C.Row(row_ac));
  }
  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp

ops/matmul.h (38 lines changed)
@@ -176,6 +176,44 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
void BindC(MatPtr& C, MMParallel& parallel);

// Lightweight view into `MatStorageT`.
#pragma pack(push, 1)  // power of two size
template <typename T>
class RowPtr {
 public:
  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
      : row0_(row0),
        cols_(static_cast<uint32_t>(cols)),
        stride_(static_cast<uint32_t>(stride)) {
    HWY_DASSERT(stride >= cols);
  }

  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
  size_t Cols() const { return static_cast<size_t>(cols_); }

  size_t Stride() const { return static_cast<size_t>(stride_); }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = stride;
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < Cols());
    HWY_DASSERT(cols <= Cols() - c);
    return RowPtr<T>(Row(r) + c, cols, stride_);
  }

 private:
  T* HWY_RESTRICT row0_;
  uint32_t cols_;
  uint32_t stride_;
};
#pragma pack(pop)

using RowPtrBF = RowPtr<BF16>;
using RowPtrD = RowPtr<double>;

// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
class MMStorage {

@@ -28,7 +28,7 @@
#define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC)                              \
  MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B,     \
                         const float* HWY_RESTRICT add, MatMulEnv& env,  \
                         const RowPtr<TC>& C) {                          \
                         MatPtrT<TC>& C) {                               \
    return MatMul(A, B, add, env, C);                                    \
  }

@@ -35,7 +35,7 @@
#define GEMMA_MATMUL_DECL_ONE(TA, TB, TC)                                \
  MMPerKey* MatMulStatic(const MatPtrT<TA>& A, const MatPtrT<TB>& B,     \
                         const float* HWY_RESTRICT add, MatMulEnv& env,  \
                         const RowPtr<TC>& C);
                         MatPtrT<TC>& C);

// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \

@@ -91,7 +91,7 @@ float MaxAbs(const MatStorageT<float>& a) {
// B is already transposed.
template <typename TA, typename TB, typename TC>
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                 const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
                 const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C, int line) {
  const hn::ScalableTag<float> df;
  const size_t cols = A.Cols();
  const size_t B_rows = B.Rows();

@@ -161,7 +161,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
                           const float* HWY_RESTRICT add_row, MatMulEnv& env,
                           const RowPtr<TC>& C) {
                           MatPtrT<TC>& C) {
  // TA can be any Packed except NuqStream because it uses pointer
  // arithmetic, because it is the second argument to Dot, which does not
  // support a v_ofs.

@@ -223,25 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
  const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
  const Extents2D C_extents(rows_ac, cols_bc);

  MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
  MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
  MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TA> A(GenerateMat<TA>(A_extents, pool));
  MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, pool));
  MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);

  MatStorageT<float> add_storage =
      add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
          : MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
  add_storage.SetScale(1.0f);

  const float* add_row = add ? add_storage.PackedScale1() : nullptr;
  const RowPtr<TC> C_slow = RowPtrFromMat(c_slow_batch);
  const RowPtr<TC> C = RowPtrFromMat(c_batch);

  MatMulSlow(a, b_trans, add_row, env, C_slow);
  MatMulSlow(A, BT, add_row, env, C_slow);
  // A few reps to get coverage of the various autotuned code paths.
  for (size_t rep = 0; rep < 16; ++rep) {
    MMPerKey* per_key = MatMulStatic(a, b_trans, add_row, env, C);
    AssertClose(a, b_trans, C_slow, C, line);
    MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C);
    AssertClose(A, BT, C_slow, C, line);
    if (per_key->autotune.Best()) break;
  }
}

util/mat.h (108 lines changed)
@@ -33,6 +33,18 @@
namespace gcpp {

// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr.
template <typename TC>
class CRows {
 public:
  CRows(TC** C_rows) : C_rows_(C_rows) {}

  TC* HWY_RESTRICT operator[](size_t row_idx) const { return C_rows_[row_idx]; }

 private:
  TC** C_rows_;
};

// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
// to store hetereogeneous tensor references in a vector.

@@ -63,13 +75,29 @@ class MatPtr : public IFields {
    ptr_ = ptr;
    stride_ = static_cast<uint32_t>(stride);

    // If row pointers were already attached, `SetPtr` would invalidate them.
    HWY_DASSERT_M(row_ptrs_ == nullptr, "Do not call after AttachRowPtrs.");

    // NUQ streams must not be padded because that would change the position of
    // the group tables.
    if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked());
    if (type_ == Type::kNUQ) {
      HWY_ASSERT_M(GEMMA_ENABLE_NUQ, "Set GEMMA_ENABLE_NUQ=1.");
      HWY_ASSERT(IsPacked());
    }
  }

  bool HasPtr() const { return ptr_ != nullptr; }

  // Caller has initialized Rows() pointers in row_ptrs[].
  void AttachRowPtrs(uint8_t** row_ptrs) {
    row_ptrs_ = row_ptrs;
    for (size_t r = 0; r < Rows(); ++r) {
      HWY_DASSERT(row_ptrs[r] != nullptr);
    }
  }

  uint8_t** GetRowPtrs() const { return row_ptrs_; }

  // A single row counts as packed because there is no padding between rows.
  bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); }

@@ -195,6 +223,11 @@ class MatPtr : public IFields {
  // this object.
  void* ptr_ = nullptr;  // not serialized

  // Points to an array of pointers, one per row, or nullptr if `AttachRowPtrs`
  // was not called. Only used for MatMul output tensors, hence we
  // minimize the cost for other tensors by only holding a non-owning pointer.
  uint8_t** row_ptrs_ = nullptr;  // not serialized

  // Offset by which to advance pointers to the next row, >= `cols_`.
  uint32_t stride_;

@@ -261,6 +294,13 @@ class MatPtrT : public MatPtr {
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
                            Args&&... args) {
#if GEMMA_ENABLE_NUQ
  if (base->GetType() == Type::kNUQ) {
    return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
                std::forward<Args>(args)...);
  }
#endif  // GEMMA_ENABLE_NUQ

  if (base->GetType() == Type::kF32) {
    return func(dynamic_cast<const MatPtrT<float>*>(base),
                std::forward<Args>(args)...);

@@ -270,9 +310,6 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
  } else if (base->GetType() == Type::kSFP) {
    return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
                std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kNUQ) {
    return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
                std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }

@@ -283,6 +320,15 @@ template <class Func, typename... Args>
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
                                const Func& func, Args&&... args) {
  HWY_ASSERT(base1->GetType() == base2->GetType());

#if GEMMA_ENABLE_NUQ
  if (base1->GetType() == Type::kNUQ) {
    return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
                dynamic_cast<const MatPtrT<NuqStream>*>(base2),
                std::forward<Args>(args)...);
  }
#endif  // GEMMA_ENABLE_NUQ

  if (base1->GetType() == Type::kF32) {
    return func(dynamic_cast<const MatPtrT<float>*>(base1),
                dynamic_cast<const MatPtrT<float>*>(base2),

@@ -295,10 +341,6 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
    return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
                dynamic_cast<const MatPtrT<SfpStream>*>(base2),
                std::forward<Args>(args)...);
  } else if (base1->GetType() == Type::kNUQ) {
    return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
                dynamic_cast<const MatPtrT<NuqStream>*>(base2),
                std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
  }

@@ -384,55 +426,5 @@ class MatStorageT : public MatPtrT<MatT> {
  MatOwner owner_;
};

// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
// seekable (non-NUQ) T.
#pragma pack(push, 1)  // power of two size
template <typename T>
class RowPtr {
 public:
  RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
      : row0_(row0),
        cols_(static_cast<uint32_t>(cols)),
        stride_(static_cast<uint32_t>(stride)) {
    HWY_DASSERT(stride >= cols);
  }

  RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}

  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
  size_t Cols() const { return static_cast<size_t>(cols_); }

  size_t Stride() const { return static_cast<size_t>(stride_); }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = stride;
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < Cols());
    HWY_DASSERT(cols <= Cols() - c);
    return RowPtr<T>(Row(r) + c, cols, stride_);
  }

 private:
  T* HWY_RESTRICT row0_;
  uint32_t cols_;
  uint32_t stride_;
};
#pragma pack(pop)

using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;

template <typename T>
RowPtr<T> RowPtrFromMat(const MatPtrT<T>& row_vectors) {
  // RowPtr is non-const for MatMul C, but is also used for A which is const.
  // Callers are responsible for checking their usage of RowPtr.
  return RowPtr<T>(const_cast<T*>(row_vectors.Row(0)), row_vectors.Cols(),
                   row_vectors.Stride());
}

}  // namespace gcpp
#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_