diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index a861155..cd10dfb 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -328,20 +328,20 @@ class GemmaAttention {
   // Computes Q.K scores, which are "logits" (or scores) stored to att.
   // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        const float* HWY_RESTRICT q, const MatPtrT<float>& k,
                         float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(cache_pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos % activations_.seq_len] = score;
       }
@@ -354,7 +354,7 @@ class GemmaAttention {
   // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
                                const float* HWY_RESTRICT att,
-                               const ConstMat<float>& v,
+                               const MatPtrT<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -362,13 +362,13 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(pos);
         MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(cache_pos);
         MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
@@ -378,7 +378,7 @@ class GemmaAttention {
  public:
   // Calculates the attention outputs for a single q.
   HWY_INLINE void SingleDotSoftmaxWeightedSum(
-      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
       float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
       const float query_scale, const size_t pos, const size_t start_pos,
       const size_t last_pos) {
@@ -432,13 +432,14 @@ class GemmaAttention {
 
         KVCache& kv_cache = kv_caches_[query_idx];
         const size_t kv_head_offset = layer_ * cache_layer_size_ + head_offset;
-        ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
-                          Extents2D(kv_cache.seq_len, qkv_dim),
-                          /*stride=*/cache_pos_size_);
-        ConstMat<float> v(
-            kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-            Extents2D(kv_cache.seq_len, qkv_dim),
-            /*stride=*/cache_pos_size_);
+        MatPtrT<float> k("k_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                 /*stride=*/cache_pos_size_);
+        MatPtrT<float> v("v_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                 /*stride=*/cache_pos_size_);
 
         // Find the token position in the query and calculate the range
         // of cache positions to attend to.
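
Note: the gemma-inl.h hunks above swap the ad-hoc `ConstMat<float>` views of the KV cache for `MatPtrT<float>` views created via `SetPtr`. The difference the callers rely on is that `ConstMat::Row(r)` returned an element offset (`ofs + r * stride`) that had to be added to `ptr`, whereas `MatPtrT::Row(r)` returns a pointer to the row. A minimal sketch of that addressing convention, using a hypothetical `StridedView` rather than gemma.cpp's actual `MatPtrT`, is:

    // Illustrative only: `StridedView` is a stand-in for MatPtrT<T>, not gemma.cpp code.
    #include <cstddef>

    template <typename T>
    struct StridedView {
      const T* base = nullptr;  // element (0, 0)
      size_t stride = 0;        // elements between the starts of consecutive rows
      // Returns a pointer to row r, as the new k.Row(pos) / v.Row(pos) calls expect.
      const T* Row(size_t r) const { return base + r * stride; }
    };

    // Usage mirroring the new QDotK: att[pos] = Dot(q, k.Row(pos), qkv_dim);
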
diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc
index 52f19ea..b171180 100644
--- a/ops/bench_matmul.cc
+++ b/ops/bench_matmul.cc
@@ -115,8 +115,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(allocator, B_extents.rows, sizeof(TC), ConstMat(b_trans),
-        env.parallel);
+  BindB(allocator, sizeof(TC), b_trans, env.parallel);
   BindC(allocator, A_extents.rows, C, env.parallel);
 
   Tristate use_spinning = Tristate::kDefault;
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 4682855..fc245d8 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -882,20 +882,17 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const ConstMat<TB>& B,
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
                                const RowPtr<TC>& C) const {
-    // TODO: include NUQ tables? NumPacked in ConstMat?
-    const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
-
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, num_packed_B, C);
+        return DoNT(B, C);
      case MMOrder::kNT_K:
-        return DoNT_K(B, num_packed_B, C);
+        return DoNT_K(B, C);
      case MMOrder::kNT_MT:
-        return DoNT_MT(B, num_packed_B, C);
+        return DoNT_MT(B, C);
      case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, num_packed_B, C);
+        return DoNT_MT_K(B, C);
      default:
        HWY_UNREACHABLE;
    }
@@ -916,8 +913,7 @@ class MMPerPackage {
 
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
-                       const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -941,7 +937,7 @@
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
                      args_, C);
@@ -953,8 +949,7 @@
 
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                         const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -977,7 +972,7 @@
        {
          MMZone zone;
          zone.MaybeEnter("MM.NT_K.DecB", args_);
-          DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+          DecompressB(B, row_b, range_kc, B_view);
        }
        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
                       args_, C);
@@ -1018,8 +1013,7 @@
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
-                          const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1042,7 +1036,7 @@
        {
          MMZone zone;
          zone.MaybeEnter("MM.NT_MT.DecB", args_);
-          DecompressB(B, num_packed_B, row_b, range_K, B_view);
+          DecompressB(B, row_b, range_K, B_view);
        }
        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                       args_, C);
@@ -1055,8 +1049,7 @@
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                            const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1078,7 +1071,7 @@
        {
          MMZone zone;
          zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
-          DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+          DecompressB(B, row_b, range_kc, B_view);
        }
        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
                       args_, C);
@@ -1210,18 +1203,18 @@
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE void DecompressB(const ConstMat<TB>& B, size_t num_packed_B,
-                              const size_t row_b, const IndexRange& range_kc,
+  HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                              const IndexRange& range_kc,
                               const RowPtrBF& B_view) const {
     const hn::ScalableTag<BF16> dbf;
-    const PackedSpan<const TB> B_span = MakeSpan(B.ptr, num_packed_B);
+    const PackedSpan<const TB> B_span = B.PaddedSpan();
     const size_t kc = range_kc.Num();
     const size_t col0 = range_kc.begin();
 
     for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = B.Row(row_b + r) + col0;
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
       BF16* HWY_RESTRICT to = B_view.Row(r);
       DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
       // Verify that we zero-padded.
@@ -1264,7 +1257,7 @@ struct MMImpl {
   // Called from `MatMul` from two places: either with the next autotune config,
   // or with the best config.
   template <typename TA, typename TB, typename TC>
-  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     const RowPtr<TC>& C, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
@@ -1300,7 +1293,7 @@ struct MMImpl {
 //
 // Uses considerable stack space: at least 40 KiB per thread.
 template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               const RowPtr<TC>& C) {
   const Allocator& allocator = env.ctx.allocator;
@@ -1327,8 +1320,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune& tuner = per_key.autotune;
-  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
-                    env.storage.Partial());
+  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
+                    add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
     MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
     return &per_key;
   }
@@ -1383,13 +1376,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   return &per_key;
 }
 
-template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                              const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
-  return MatMul(A, ConstMat(B), add, env, C);
-}
-
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
diff --git a/ops/matmul.h b/ops/matmul.h
index e00a6ec..2ad369e 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -681,67 +681,18 @@ struct MMZone {
 };
 #endif  // PROFILER_ENABLED
 
-// Used for the A and B arguments of `MatMul`, which are always const.
-// This differs from `RowPtr` in supporting the `ofs` required for compressed T.
-// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
-template <typename T>
-struct ConstMat {
-  ConstMat() = default;
-  ConstMat(const T* ptr, Extents2D extents, size_t stride)
-      : ptr(ptr), extents(extents), stride(stride), ofs(0) {
-    HWY_DASSERT(ptr != nullptr);
-    HWY_DASSERT(stride >= extents.cols);
-  }
-
-  // Non-explicit so that we can pass `MatPtr` directly to MatMul.
-  ConstMat(const MatPtrT<T>& m)
-      : ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
-    scale = m.Scale();
-  }
-
-  size_t Row(size_t r) const {
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      if (r >= extents.rows) {
-        HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
-      }
-    }
-    return ofs + r * stride;
-  }
-
-  const Extents2D& Extents() const { return extents; }
-  size_t Stride() const { return stride; }
-  float Scale() const { return scale; }
-
-  // So that matvec-inl.h can use the same interface as MatPtrT:
-  size_t Rows() const { return extents.rows; }
-  size_t Cols() const { return extents.cols; }
-
-  const T* HWY_RESTRICT ptr;
-  Extents2D extents;
-  size_t stride;
-
-  // `scale` allows expanding the smaller range of `SfpStream` to the original
-  // values. MatFromWeights sets this from `MatPtr`.
-  float scale = 1.0f;
-
-  // Offset to add to `ptr`; separate because T=NuqStream does not support
-  // pointer arithmetic. This is in units of weights, and does not have anything
-  // to do with the interleaved NUQ tables. It should be computed via `Row()`
-  // to take into account the stride.
-  size_t ofs;
-};
-
 template <typename TB>
-void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
-           const ConstMat<TB>& B, MMParallel& parallel) {
+void BindB(const Allocator& allocator, size_t sizeof_TC, const MatPtrT<TB>& B,
+           MMParallel& parallel) {
   if (!allocator.ShouldBind()) return;
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
+      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
   const size_t quantum = allocator.Quantum();
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
-    uintptr_t begin =
-        reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
+    uintptr_t begin = reinterpret_cast<uintptr_t>(B.Row(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
     // B is not yet guaranteed to have padded rows, so only bind the
     // subset that is page-aligned.
diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h
index 83e70c8..8be84ec 100644
--- a/ops/matvec-inl.h
+++ b/ops/matvec-inl.h
@@ -47,16 +47,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-// Adapter so that gemma-inl.h can pass ConstMat.
-// TODO: remove after changing ComputeQKV to MatMul.
-template <typename WT, typename VT>
-HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
-                     size_t num) {
-  const hn::ScalableTag<float> d;
-  HWY_DASSERT(num <= w.Stride());  // Single row, else padding is an issue.
-  const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
-  return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
-}
 // For callers that pass `MatPtrT`, which is not necessarily packed - callers
 // should use Stride() to compute `w_ofs`.
 template <typename WT, typename VT>
@@ -66,7 +56,7 @@ HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
                      size_t num) {
   const hn::ScalableTag<float> d;
   return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
 }
-// ArrayT is either MatPtrT or ConstMat.
+// ArrayT is MatPtrT.
 // Simple version without tiling nor threading, but two offsets/outputs and
 // always with addition.
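
Note: `BindB` now takes the `MatPtrT<TB>` directly and derives the number of B rows from `B.Rows()`, so the separate `N` argument is gone; the bench_matmul.cc call above shrinks to `BindB(allocator, sizeof(TC), b_trans, env.parallel)`. The trailing comment in the hunk ("only bind the subset that is page-aligned") refers to clipping `[begin, end)` to the allocator's quantum before binding it to a NUMA node. A minimal sketch of that clipping, with hypothetical names rather than the actual gemma.cpp helper, is:

    // Hypothetical helper, not gemma.cpp code: shrink [begin, end) to the largest
    // sub-range whose start and end are multiples of `quantum` (e.g. page size).
    #include <cstddef>
    #include <cstdint>

    inline bool AlignedSubrange(uintptr_t begin, uintptr_t end, size_t quantum,
                                uintptr_t& out_begin, size_t& out_bytes) {
      const uintptr_t aligned_begin = (begin + quantum - 1) / quantum * quantum;
      const uintptr_t aligned_end = end / quantum * quantum;
      if (aligned_end <= aligned_begin) return false;  // nothing page-aligned to bind
      out_begin = aligned_begin;
      out_bytes = static_cast<size_t>(aligned_end - aligned_begin);
      return true;
    }
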