diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc
index 0c76b2d..a40624b 100644
--- a/compression/nuq_test.cc
+++ b/compression/nuq_test.cc
@@ -517,6 +517,7 @@ HWY_AFTER_NAMESPACE();
 #if HWY_ONCE
 namespace gcpp {
 HWY_BEFORE_TEST(NuqTest);
+#if GEMMA_ENABLE_NUQ
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
@@ -530,6 +531,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
+#else
+GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NuqTest);
+#endif  // GEMMA_ENABLE_NUQ
 HWY_AFTER_TEST();
 }  // namespace gcpp
 #endif  // HWY_ONCE
diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py
index 034fcea..2ed0916 100644
--- a/compression/python/compression_test.py
+++ b/compression/python/compression_test.py
@@ -70,12 +70,6 @@ class CompressionTest(absltest.TestCase):
     info_256.name = "ignored_256"
     info_256.axes = [0]
     info_256.shape = [256]
-    writer.insert(
-        "tensor_nuq",
-        np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
-        configs.Type.kNUQ,
-        info_256,
-    )
     writer.insert(
         "tensor_sfp",
         np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
@@ -97,7 +91,7 @@ class CompressionTest(absltest.TestCase):
 
     config = configs.ModelConfig(
         configs.Model.GEMMA_TINY,
-        configs.Type.kNUQ,
+        configs.Type.kSFP,
         configs.PromptWrapping.GEMMA_IT,
     )
     tokenizer_path = ""  # no tokenizer required for testing
@@ -108,7 +102,7 @@
 
     reader = compression.SbsReader(temp_file.full_path)
     self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
-    self.assertEqual(reader.config.weight, configs.Type.kNUQ)
+    self.assertEqual(reader.config.weight, configs.Type.kSFP)
 
     mat = reader.find_mat("tensor0")
     self.assertEqual(mat.cols, 192)
@@ -128,12 +122,6 @@
     self.assertEqual(mat.type, configs.Type.kSFP)
     self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2)
 
-    mat = reader.find_mat("tensor_nuq")
-    self.assertEqual(mat.cols, 256)
-    self.assertEqual(mat.rows, 1)
-    self.assertEqual(mat.type, configs.Type.kNUQ)
-    self.assertAlmostEqual(mat.scale, 1.0)
-
     mat = reader.find_mat("tensor_sfp")
     self.assertEqual(mat.cols, 256)
     self.assertEqual(mat.rows, 1)
diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h
index d10625c..bddac6b 100644
--- a/compression/test_util-inl.h
+++ b/compression/test_util-inl.h
@@ -62,7 +62,9 @@ void ForeachPackedAndRawType() {
   ForeachRawType();
   ForeachRawType();
   ForeachRawType();
-  ForeachRawType();
+  if constexpr (GEMMA_ENABLE_NUQ) {
+    ForeachRawType();
+  }
 }
 
 // Generates inputs: deterministic, within max SfpStream range.
diff --git a/compression/types.h b/compression/types.h
index a699b8f..015560e 100644
--- a/compression/types.h
+++ b/compression/types.h
@@ -29,6 +29,11 @@
 
 namespace gcpp {
 
+// Only used in experiments, hence disable in default builds.
+#ifndef GEMMA_ENABLE_NUQ
+#define GEMMA_ENABLE_NUQ 0
+#endif
+
 // Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
 // inputs that combines the advantages of e4m3 and e5m2 into a single format.
 // It supports seeking at a granularity of 1 and decoding to bf16/f32.
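`GEMMA_ENABLE_NUQ` (added to `compression/types.h` above, default 0) gates NUQ in two ways: a preprocessor `#if` around the NUQ test registrations, and an `if constexpr` inside templated helpers such as `ForeachPackedAndRawType` so that NUQ-only code is never instantiated. Below is a minimal standalone sketch of both gating styles, assuming nothing beyond the macro itself; `ForEachEnabledType` and its callback are illustrative names, not part of this patch.

```cpp
// Sketch only: the two GEMMA_ENABLE_NUQ gating styles used by this patch.
#include <cstdio>

#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0  // NUQ is experimental; off in default builds.
#endif

// Compile-time gate inside a template, mirroring ForeachPackedAndRawType():
// when the macro is 0 the NUQ branch is discarded, so NUQ-only symbols are
// never instantiated and the linker never sees them.
template <class Func>
void ForEachEnabledType(const Func& func) {
  func("f32");
  func("bf16");
  func("sfp");
  if constexpr (GEMMA_ENABLE_NUQ) {
    func("nuq");
  }
}

int main() {
  ForEachEnabledType([](const char* name) { std::printf("%s\n", name); });
#if GEMMA_ENABLE_NUQ
  std::printf("NUQ test registrations would go here.\n");
#else
  // Mirrors nuq_test.cc: keep the parameterized-test harness satisfied even
  // though no NUQ tests are registered.
  std::printf("NUQ disabled; skipping registrations.\n");
#endif
  return 0;
}
```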
diff --git a/gemma/activations.h b/gemma/activations.h index a9ba5e7..5c3e99b 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -38,8 +38,7 @@ struct Activations { is_griffin(config.model == Model::GRIFFIN_2B), x("x", Extents2D(batch_size, config.model_dim), pad_), - q("q", - Extents2D(batch_size, layer_config.heads * layer_config.QStride()), + q("q", Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim), pad_), logits("logits", Extents2D(batch_size, config.vocab_size), pad_), @@ -82,6 +81,25 @@ struct Activations { env(env) { HWY_ASSERT(batch_size != 0); + // For MatMul outputs, precompute their row pointers. + const auto init_row_ptrs = [&](MatPtrT& mat) { + row_ptrs.push_back(hwy::AllocateAligned(mat.Rows())); + uint8_t** ptrs = row_ptrs.back().get(); + for (size_t r = 0; r < mat.Rows(); ++r) { + ptrs[r] = mat.RowBytes(r); + } + mat.AttachRowPtrs(ptrs); + }; + // If we forget any MatMul outputs here, debug builds print a warning but + // fill them in each MatMul call. + init_row_ptrs(q); + init_row_ptrs(logits); + init_row_ptrs(att_sums); + init_row_ptrs(C1); + init_row_ptrs(C2); + init_row_ptrs(ffw_out); + // TODO: also init rows for image_tokens. + // Note that BindC on any MatMul output considerably slows down Prefill. } @@ -144,6 +162,9 @@ struct Activations { MatStorageT inv_timescale_global; MatMulEnv* env; + // Per-tensor allocations to make it likelier that asan detects bugs such as + // use after free, overrun, and dangling references. + std::vector> row_ptrs; }; } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index caffc81..db3941d 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -259,16 +259,12 @@ struct LayerConfig : public IFields { // Multi-Head Attention? bool IsMHA() const { return heads == kv_heads; } - // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, - // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); } - uint32_t model_dim = 0; uint32_t griffin_dim = 0; uint32_t ff_hidden_dim = 0; uint32_t heads = 0; uint32_t kv_heads = 0; - uint32_t qkv_dim = 0; + uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). uint32_t conv1d_width = 0; // Griffin only bool ff_biases = false; bool softmax_attn_output_biases = false; // for Griffin diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index bd80e17..3f53134 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -249,64 +249,38 @@ class GemmaAttention { } } - // Fills activations.q and computes KV. For is_mha_, a single MatMul suffices - // and we later copy KV from q to KVCache. Otherwise, a second MatMul writes - // KV directly to KVCache. + // Fills activations.q and writes to KV cache. HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) { PROFILER_ZONE("Gen.Attention.QKV"); - const size_t model_dim = layer_config_.model_dim; const size_t qkv_dim = layer_config_.qkv_dim; - const size_t heads = layer_config_.heads; const size_t kv_heads = layer_config_.kv_heads; - // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim, - // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows. - // We must shrink to the actual size because MatMul verifies - // `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all - // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be - // computed in the second MatMul. 
- const size_t w1_rows = heads * layer_config_.QStride(); - HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows); + // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, + // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1, - /*add=*/nullptr, *activations_.env, - RowPtrFromMat(activations_.q)); + /*add=*/nullptr, *activations_.env, activations_.q); - if (is_mha_) { - // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. - } else { - // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v). - const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim; - HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols); - - // Single query and no wraparound means we can use a matmul and write - // directly into the KV cache with a stride of cache_pos_size_. - if (num_queries_ == 1 && - queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) { - const size_t kv_ofs = - queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_; - float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; - RowPtrF kv_rows(kv, w_rows_kv_cols); - kv_rows.SetStride(cache_pos_size_); - MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2, - /*add=*/nullptr, *activations_.env, kv_rows); - } else { - // Proceed row by row because there will be wraparound. - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const float* x = activations_.pre_att_rms_out.Row(interleaved_idx); - const size_t query_idx = interleaved_idx % num_queries_; - const size_t batch_idx = interleaved_idx / num_queries_; - KVCache& kv_cache = kv_caches_[query_idx]; - const size_t cache_pos = - div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx); - const size_t kv_offset = - cache_pos * cache_pos_size_ + layer_ * cache_layer_size_; - float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x, - kv, pool_); - } - } - } // !is_mha_ + // Set up MatMul row pointers for writing to KV, which consists of + // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound + // because rows are computed modulo seq_len. + MatPtrT kv_rows("kv", + Extents2D(activations_.pre_att_rms_out.Rows(), + layer_weights_.qkv_einsum_w2.Rows())); + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const size_t query_idx = interleaved_idx % num_queries_; + const size_t batch_idx = interleaved_idx / num_queries_; + const size_t cache_pos = + div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx); + const size_t kv_offset = + cache_pos * cache_pos_size_ + layer_ * cache_layer_size_; + activations_.env->storage.OutRow(interleaved_idx) = + reinterpret_cast(kv_caches_[query_idx].kv_cache.get() + + kv_offset); + } + kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0)); + MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2, + /*add=*/nullptr, *activations_.env, kv_rows); // Apply positional encodings for K (and copy KV to cache if MHA). pool_.Run(0, kv_heads * num_interleaved, @@ -322,13 +296,6 @@ class GemmaAttention { head * qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - // If MHA, copy computed K and V into KVCache. 
- if (is_mha_) { - const float* HWY_RESTRICT mha_kv = - activations_.q.Row(interleaved_idx) + head * q_stride_ + - qkv_dim; - hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv)); - } // Apply further processing to K. if (layer_weights_.key_norm_scale.HasPtr()) { @@ -435,7 +402,7 @@ class GemmaAttention { const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; float* HWY_RESTRICT q = - activations_.q.Row(interleaved_idx) + head * q_stride_; + activations_.q.Row(interleaved_idx) + head * qkv_dim; float* HWY_RESTRICT att = activations_.att.Row(interleaved_idx) + head * activations_.seq_len; float* HWY_RESTRICT att_out = @@ -490,7 +457,7 @@ class GemmaAttention { ? layer_weights_.attention_output_biases.PackedScale1() : nullptr; MatMulStatic(activations_.att_out, layer_weights_.att_weights, add, - *activations_.env, RowPtrFromMat(activations_.att_sums)); + *activations_.env, activations_.att_sums); } public: @@ -548,15 +515,14 @@ class GemmaAttention { num_tokens_(num_tokens), layer_(layer), layer_config_(layer_weights->layer_config), - q_stride_(layer_config_.QStride()), cache_layer_size_(layer_weights->layer_config.CacheLayerSize()), cache_pos_size_(activations.cache_pos_size), - is_mha_(layer_config_.IsMHA()), activations_(activations), layer_weights_(*layer_weights), div_seq_len_(div_seq_len), kv_caches_(kv_caches), pool_(ctx.pools.Pool(0)) { + HWY_DASSERT(!layer_config_.IsMHA()); // No longer supported. HWY_DASSERT(num_queries_ <= kv_caches_.size()); HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0, "query heads must be a multiple of key-value heads"); @@ -576,10 +542,8 @@ class GemmaAttention { const size_t num_tokens_; const size_t layer_; const LayerConfig& layer_config_; - const size_t q_stride_ = 0; const size_t cache_layer_size_ = 0; const size_t cache_pos_size_ = 0; - const bool is_mha_ = false; Activations& activations_; const LayerWeightsPtrs& layer_weights_; @@ -627,7 +591,7 @@ class VitAttention { HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w, layer_weights_.vit.qkv_einsum_b.PackedScale1(), - *activations_.env, RowPtrFromMat(qkv)); + *activations_.env, qkv); } // TODO(philculliton): transition fully to MatMul. @@ -667,7 +631,7 @@ class VitAttention { }); // this produces C, a (num_tokens_, seq_len) matrix of dot products - MatMulStatic(Q, K, nullptr, *activations_.env, RowPtrFromMat(C)); + MatMulStatic(Q, K, nullptr, *activations_.env, C); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { float* HWY_RESTRICT c = C.Row(task); @@ -733,9 +697,8 @@ class VitAttention { // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. - auto att_sums = RowPtrFromMat(activations_.att_sums); MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias, - *activations_.env, att_sums); + *activations_.env, activations_.att_sums); } public: @@ -827,9 +790,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, // Compute the hidden layer activations. 
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, - bias1, *activations.env, RowPtrFromMat(activations.C1)); + bias1, *activations.env, activations.C1); MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, - bias2, *activations.env, RowPtrFromMat(activations.C2)); + bias2, *activations.env, activations.C2); // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_weights->layer_config.activation, activations.C1, @@ -837,7 +800,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, // Hidden layer -> output layer. MatMulStatic(activations.C1, layer_weights->linear_w, output_bias, - *activations.env, RowPtrFromMat(activations.ffw_out)); + *activations.env, activations.ffw_out); } // Same as FFWNoVit, but with different layer_weights members and no second @@ -855,14 +818,14 @@ HWY_NOINLINE void FFWVit(Activations& activations, // Compute the hidden layer activations. MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, - bias1, *activations.env, RowPtrFromMat(activations.C1)); + bias1, *activations.env, activations.C1); // Activation (Gelu), store in C1. ActivationBatched(layer_weights->layer_config.activation, activations.C1); // Hidden layer -> output layer. MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias, - *activations.env, RowPtrFromMat(activations.ffw_out)); + *activations.env, activations.ffw_out); } // `batch_idx` indicates which row of `x` to write to. @@ -1176,10 +1139,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // kPatchSize), MatPadding::kPacked); // [Get patches] // MatMulStatic( - // MatFromBatch(kVitSeqLen, image_patches), - // MatFromWeights(weights.vit_img_embedding_kernel), + // image_patches, + // weights.vit_img_embedding_kernel, // weights.vit_img_embedding_bias.PackedScale1(), *activations.env, - // RowPtrF(activations.x.Row(0), kVitModelDim)); + // activations.x); // However, MatMul currently requires that // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 // which is not the case here. We should relax that requirement on MatMul and @@ -1228,7 +1191,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, // Apply head embedding into image_tokens of size of the LLM kModelDim. MatMulStatic(activations.x, weights.vit_img_head_kernel, weights.vit_img_head_bias.PackedScale1(), *activations.env, - RowPtrFromMat(image_tokens)); + image_tokens); } // Generates one token for each query. `queries_token` is the previous token @@ -1367,8 +1330,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights, PROFILER_ZONE("Gen.EmbeddingMatmul"); // Compute logits from last layer activations. MatMulStatic(activations.x, weights.embedder_input_embedding, - /*add=*/nullptr, *activations.env, - RowPtrFromMat(activations.logits)); + /*add=*/nullptr, *activations.env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { diff --git a/gemma/weights.h b/gemma/weights.h index 53be5cd..3173cb2 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -393,14 +393,7 @@ struct LayerWeightsPtrs { // MHA, and otherwise might not be the same type. if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; - const size_t w1_rows = layer_config.heads * layer_config.QStride(); - - if (layer_config.IsMHA()) { // MHA only requires w1. 
- qkv_einsum_w1 = qkv_einsum_w; - HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows); - return; - } - + const size_t w1_rows = layer_config.heads * layer_config.qkv_dim; const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim; HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows); diff --git a/io/io.cc b/io/io.cc index 8c6d484..eea3bd6 100644 --- a/io/io.cc +++ b/io/io.cc @@ -15,10 +15,6 @@ // Safe to be first, does not include POSIX headers. #include "hwy/detect_compiler_arch.h" -// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to -// check this in source code because we support multiple build systems. -#if !HWY_OS_WIN - // Request POSIX 2008, including `pread()` and `posix_fadvise()`. This also // implies `_POSIX_C_SOURCE`. #if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 @@ -30,6 +26,14 @@ #undef _FILE_OFFSET_BITS #define _FILE_OFFSET_BITS 64 +#include + +#include +#include + +#include "io/io.h" +#include "hwy/base.h" // HWY_ASSERT + #if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \ (!defined(__ANDROID_API__) || __ANDROID_API__ >= 24) #define GEMMA_IO_PREADV 1 @@ -44,6 +48,11 @@ #define GEMMA_IO_FADVISE 0 #endif +// FilePosix should only be compiled on non-Windows. It is easier to +// check this in source code because we support multiple build systems. Note +// that IOBatch at the end of this TU is still compiled on all platforms. +#if !HWY_OS_WIN + #if GEMMA_IO_PREADV // Replacement for the _BSD_SOURCE specified by preadv documentation. #ifndef _DEFAULT_SOURCE @@ -55,7 +64,6 @@ #include // open #include // IOV_MAX -#include #include #include // SEEK_END - unistd isn't enough for IDE. #include @@ -64,12 +72,7 @@ #include // O_RDONLY #include // read, write, close -#include -#include - -#include "io/io.h" #include "util/allocator.h" -#include "hwy/base.h" // HWY_ASSERT namespace gcpp { @@ -168,6 +171,12 @@ std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode) { return std::make_unique(fd); } +} // namespace gcpp + +#endif // !HWY_OS_WIN + +namespace gcpp { + std::unique_ptr OpenFileOrAbort(const Path& filename, const char* mode) { std::unique_ptr file = OpenFileOrNull(filename, "r"); if (!file) { @@ -237,4 +246,3 @@ uint64_t IOBatch::Read(const File& file) const { } } // namespace gcpp -#endif // !HWY_OS_WIN diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 1f6aa19..b60a258 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -91,8 +91,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - MatStorageT c_slow_mat("c_slow_batch", C_extents, MatPadding::kOdd); - MatStorageT c_mat("c_batch", C_extents, MatPadding::kOdd); + MatStorageT C_slow("c_slow_batch", C_extents, MatPadding::kOdd); + MatStorageT C("c_batch", C_extents, MatPadding::kOdd); MatStorageT add_storage("add", Extents2D(), MatPadding::kPacked); if (add) { @@ -104,7 +104,6 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { MatStorageT b_trans = GenerateTransposedMat(B_extents, pool); const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C = RowPtrFromMat(c_mat); // Fewer reps for large batch sizes, which take longer. const size_t num_samples = M < 32 ? 20 : 12; @@ -115,7 +114,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. 
BindB(b_trans, sizeof(TC), env.parallel); - BindC(c_mat, env.parallel); + BindC(C, env.parallel); Tristate use_spinning = Tristate::kDefault; env.ctx.pools.MaybeStartSpinning(use_spinning); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index ddcddd5..5c84f8e 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -80,6 +80,7 @@ hn::Vec TCFromF32(DC dc, hn::Vec vf) { return hn::DemoteTo(dc, vf); } +// Type-safe wrapper over uint8_t row pointers referenced by MatPtrT. template class CRows { public: @@ -1183,7 +1184,10 @@ class MMPerPackage { if constexpr (hwy::IsSame()) { // Only if no zero-padding required. const size_t NBF = hn::Lanes(hn::ScalableTag()); - if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(A); + if (HWY_LIKELY(A.Cols() % NBF == 0)) { + // Actually const, but RowPtr is also used for partial which is not. + return RowPtrBF(const_cast(A.Row(0)), A.Cols(), A.Stride()); + } } if (HWY_LIKELY(autotune.Best())) { @@ -1312,7 +1316,21 @@ struct MMImpl { template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, - CRows C_rows) { + MatPtrT& C) { + CRows C_rows(C.GetRowPtrs()); + if (HWY_UNLIKELY(!C.GetRowPtrs())) { + if constexpr (HWY_IS_DEBUG_BUILD) { + fprintf(stderr, + "MatMul perf warning: setting row pointers because " + "C.AttachRowPtrs() was not called.\n"); + } + HWY_DASSERT(C.HasPtr()); + for (size_t r = 0; r < C.Rows(); ++r) { + env.storage.OutRow(r) = reinterpret_cast(C.Row(r)); + } + C_rows = CRows(&env.storage.OutRow(0)); + } + const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); const size_t K = A.Cols(); @@ -1392,19 +1410,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, return &per_key; } -// Adapter that fills the row array. This is the common case, whereas only -// GemmaAttention::ComputeQKV uses the arbitrary output rows feature. -template -HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, - const float* HWY_RESTRICT add, MatMulEnv& env, - const RowPtr& C) { - HWY_DASSERT(B.Rows() == C.Cols()); - for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) { - env.storage.OutRow(row_ac) = reinterpret_cast(C.Row(row_ac)); - } - return MatMul(A, B, add, env, CRows(&env.storage.OutRow(0))); -} - // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/ops/matmul.h b/ops/matmul.h index ee6037b..06cb3f1 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -176,6 +176,44 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel); // C is BF16/float, or double for partial. void BindC(MatPtr& C, MMParallel& parallel); +// Lightweight view into `MatStorageT`. +#pragma pack(push, 1) // power of two size +template +class RowPtr { + public: + RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) + : row0_(row0), + cols_(static_cast(cols)), + stride_(static_cast(stride)) { + HWY_DASSERT(stride >= cols); + } + + T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } + size_t Cols() const { return static_cast(cols_); } + + size_t Stride() const { return static_cast(stride_); } + void SetStride(size_t stride) { + HWY_DASSERT(stride >= Cols()); + stride_ = stride; + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
+ RowPtr View(size_t r, size_t c, size_t cols) const { + HWY_DASSERT(c < Cols()); + HWY_DASSERT(cols <= Cols() - c); + return RowPtr(Row(r) + c, cols, stride_); + } + + private: + T* HWY_RESTRICT row0_; + uint32_t cols_; + uint32_t stride_; +}; +#pragma pack(pop) + +using RowPtrBF = RowPtr; +using RowPtrD = RowPtr; + // Per-package storage for packed A, and one global C-shaped `partial` for // accumulating partial dot products (sections of K). class MMStorage { diff --git a/ops/matmul_static-inl.h b/ops/matmul_static-inl.h index da17c51..28b21cf 100644 --- a/ops/matmul_static-inl.h +++ b/ops/matmul_static-inl.h @@ -28,7 +28,7 @@ #define GEMMA_MATMUL_DEFINE_ONE(TA, TB, TC) \ MMPerKey* MatMulStatic(const MatPtrT& A, const MatPtrT& B, \ const float* HWY_RESTRICT add, MatMulEnv& env, \ - const RowPtr& C) { \ + MatPtrT& C) { \ return MatMul(A, B, add, env, C); \ } diff --git a/ops/matmul_static.h b/ops/matmul_static.h index e16d340..c06b87a 100644 --- a/ops/matmul_static.h +++ b/ops/matmul_static.h @@ -35,7 +35,7 @@ #define GEMMA_MATMUL_DECL_ONE(TA, TB, TC) \ MMPerKey* MatMulStatic(const MatPtrT& A, const MatPtrT& B, \ const float* HWY_RESTRICT add, MatMulEnv& env, \ - const RowPtr& C); + MatPtrT& C); // Passed to HWY_VISIT_TARGETS; declares all overloads for all targets. #define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \ diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 69ecc6e..6d3cf54 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -91,7 +91,7 @@ float MaxAbs(const MatStorageT& a) { // B is already transposed. template void AssertClose(const MatPtrT& A, const MatPtrT& B, - const RowPtr& C_slow, const RowPtr& C, int line) { + const MatPtrT& C_slow, const MatPtrT& C, int line) { const hn::ScalableTag df; const size_t cols = A.Cols(); const size_t B_rows = B.Rows(); @@ -161,7 +161,7 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, template HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, const float* HWY_RESTRICT add_row, MatMulEnv& env, - const RowPtr& C) { + MatPtrT& C) { // TA can be any Packed except NuqStream because it uses pointer // arithmetic, because it is the second argument to Dot, which does not // support a v_ofs. @@ -223,25 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D C_extents(rows_ac, cols_bc); - MatStorageT a(GenerateMat(A_extents, pool)); - MatStorageT b_trans(GenerateTransposedMat(B_extents, pool)); - MatStorageT c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd); - MatStorageT c_batch("c_batch", C_extents, MatPadding::kOdd); + MatStorageT A(GenerateMat(A_extents, pool)); + MatStorageT BT(GenerateTransposedMat(B_extents, pool)); + MatStorageT C_slow("c_slow_batch", C_extents, MatPadding::kOdd); + MatStorageT C("c_batch", C_extents, MatPadding::kOdd); MatStorageT add_storage = add ? GenerateMat(Extents2D(1, cols_bc), pool) : MatStorageT("add", Extents2D(), MatPadding::kPacked); add_storage.SetScale(1.0f); - const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C_slow = RowPtrFromMat(c_slow_batch); - const RowPtr C = RowPtrFromMat(c_batch); - MatMulSlow(a, b_trans, add_row, env, C_slow); + MatMulSlow(A, BT, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. 
for (size_t rep = 0; rep < 16; ++rep) { - MMPerKey* per_key = MatMulStatic(a, b_trans, add_row, env, C); - AssertClose(a, b_trans, C_slow, C, line); + MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C); + AssertClose(A, BT, C_slow, C, line); if (per_key->autotune.Best()) break; } } diff --git a/util/mat.h b/util/mat.h index a3de89b..5b15df9 100644 --- a/util/mat.h +++ b/util/mat.h @@ -33,6 +33,18 @@ namespace gcpp { +// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. +template +class CRows { + public: + CRows(TC** C_rows) : C_rows_(C_rows) {} + + TC* HWY_RESTRICT operator[](size_t row_idx) const { return C_rows_[row_idx]; } + + private: + TC** C_rows_; +}; + // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector // or matrix). Base class of the non-type-erased `MatPtrT`. Use this class // to store hetereogeneous tensor references in a vector. @@ -63,13 +75,29 @@ class MatPtr : public IFields { ptr_ = ptr; stride_ = static_cast(stride); + // If row pointers were already attached, `SetPtr` would invalidate them. + HWY_DASSERT_M(row_ptrs_ == nullptr, "Do not call after AttachRowPtrs."); + // NUQ streams must not be padded because that would change the position of // the group tables. - if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked()); + if (type_ == Type::kNUQ) { + HWY_ASSERT_M(GEMMA_ENABLE_NUQ, "Set GEMMA_ENABLE_NUQ=1."); + HWY_ASSERT(IsPacked()); + } } bool HasPtr() const { return ptr_ != nullptr; } + // Caller has initialized Rows() pointers in row_ptrs[]. + void AttachRowPtrs(uint8_t** row_ptrs) { + row_ptrs_ = row_ptrs; + for (size_t r = 0; r < Rows(); ++r) { + HWY_DASSERT(row_ptrs[r] != nullptr); + } + } + + uint8_t** GetRowPtrs() const { return row_ptrs_; } + // A single row counts as packed because there is no padding between rows. bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); } @@ -195,6 +223,11 @@ class MatPtr : public IFields { // this object. void* ptr_ = nullptr; // not serialized + // Points to an array of pointers, one per row, or nullptr if `AttachRowPtrs` + // was not called. Only used for MatMul output tensors, hence we + // minimize the cost for other tensors by only holding a non-owning pointer. + uint8_t** row_ptrs_ = nullptr; // not serialized + // Offset by which to advance pointers to the next row, >= `cols_`. uint32_t stride_; @@ -261,6 +294,13 @@ class MatPtrT : public MatPtr { template decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, Args&&... args) { +#if GEMMA_ENABLE_NUQ + if (base->GetType() == Type::kNUQ) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } +#endif // GEMMA_ENABLE_NUQ + if (base->GetType() == Type::kF32) { return func(dynamic_cast*>(base), std::forward(args)...); @@ -270,9 +310,6 @@ decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, } else if (base->GetType() == Type::kSFP) { return func(dynamic_cast*>(base), std::forward(args)...); - } else if (base->GetType() == Type::kNUQ) { - return func(dynamic_cast*>(base), - std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); } @@ -283,6 +320,15 @@ template decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2, const Func& func, Args&&... 
args) { HWY_ASSERT(base1->GetType() == base2->GetType()); + +#if GEMMA_ENABLE_NUQ + if (base1->GetType() == Type::kNUQ) { + return func(dynamic_cast*>(base1), + dynamic_cast*>(base2), + std::forward(args)...); + } +#endif // GEMMA_ENABLE_NUQ + if (base1->GetType() == Type::kF32) { return func(dynamic_cast*>(base1), dynamic_cast*>(base2), @@ -295,10 +341,6 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2, return func(dynamic_cast*>(base1), dynamic_cast*>(base2), std::forward(args)...); - } else if (base1->GetType() == Type::kNUQ) { - return func(dynamic_cast*>(base1), - dynamic_cast*>(base2), - std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType())); } @@ -384,55 +426,5 @@ class MatStorageT : public MatPtrT { MatOwner owner_; }; -// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with -// seekable (non-NUQ) T. -#pragma pack(push, 1) // power of two size -template -class RowPtr { - public: - RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) - : row0_(row0), - cols_(static_cast(cols)), - stride_(static_cast(stride)) { - HWY_DASSERT(stride >= cols); - } - - RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {} - - T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } - size_t Cols() const { return static_cast(cols_); } - - size_t Stride() const { return static_cast(stride_); } - void SetStride(size_t stride) { - HWY_DASSERT(stride >= Cols()); - stride_ = stride; - } - - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - RowPtr View(size_t r, size_t c, size_t cols) const { - HWY_DASSERT(c < Cols()); - HWY_DASSERT(cols <= Cols() - c); - return RowPtr(Row(r) + c, cols, stride_); - } - - private: - T* HWY_RESTRICT row0_; - uint32_t cols_; - uint32_t stride_; -}; -#pragma pack(pop) - -using RowPtrBF = RowPtr; -using RowPtrF = RowPtr; -using RowPtrD = RowPtr; - -template -RowPtr RowPtrFromMat(const MatPtrT& row_vectors) { - // RowPtr is non-const for MatMul C, but is also used for A which is const. - // Callers are responsible for checking their usage of RowPtr. - return RowPtr(const_cast(row_vectors.Row(0)), row_vectors.Cols(), - row_vectors.Stride()); -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
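With `RowPtrFromMat` removed from `util/mat.h`, MatMul outputs are now passed as `MatPtrT&`, and callers may attach per-row output pointers once up front; otherwise debug builds print a perf warning and the pointers are refilled on every call. The sketch below shows that caller-side pattern. It assumes an f32 instantiation of `MatMulStatic` is available and that the `MatStorageT`/`Extents2D` constructors match their use in `ops/matmul_test.cc`; `AttachOutputRows` and `ExampleMatMul` are hypothetical helpers, and the Highway per-target dispatch around `MatMulStatic` is omitted.

```cpp
// Sketch of the new MatMul output convention introduced by this patch.
#include <cstddef>
#include <cstdint>

#include "hwy/aligned_allocator.h"  // hwy::AllocateAligned
#include "ops/matmul_static.h"
#include "util/mat.h"

namespace gcpp {

// Precompute and attach per-row output pointers once, as Activations now does
// in its constructor; MatMul then writes through C.GetRowPtrs() instead of
// filling env.storage.OutRow() on every call. `row_ptrs` must outlive all
// MatMul calls that use `c` as an output.
void AttachOutputRows(MatPtrT<float>& c,
                      hwy::AlignedFreeUniquePtr<uint8_t*[]>& row_ptrs) {
  row_ptrs = hwy::AllocateAligned<uint8_t*>(c.Rows());
  for (size_t r = 0; r < c.Rows(); ++r) {
    row_ptrs[r] = reinterpret_cast<uint8_t*>(c.Row(r));
  }
  c.AttachRowPtrs(row_ptrs.get());
}

void ExampleMatMul(const MatPtrT<float>& a, const MatPtrT<float>& b_trans,
                   MatMulEnv& env) {
  MatStorageT<float> c("c", Extents2D(a.Rows(), b_trans.Rows()),
                       MatPadding::kOdd);
  hwy::AlignedFreeUniquePtr<uint8_t*[]> row_ptrs;
  AttachOutputRows(c, row_ptrs);
  // The output is a MatPtrT, not a RowPtr. Rows need not be contiguous:
  // GemmaAttention::ComputeQKV attaches one pointer per KV-cache slot so a
  // single MatMul can scatter its rows directly into the cache.
  MatMulStatic(a, b_trans, /*add=*/nullptr, env, c);
}

}  // namespace gcpp
```

Note that attaching row pointers is independent of NUMA binding: per the comment added in `gemma/activations.h`, calling `BindC` on MatMul outputs is deliberately avoided for activations because it considerably slows down Prefill.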