mirror of https://github.com/google/gemma.cpp.git
Replace last ConstMat with MatPtr
This is to reduce the number of MatMul overloads in preparation for de-templatizing.

PiperOrigin-RevId: 758288589
This commit is contained in:
parent 0a6a7e4cd6
commit 38a08d8095
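The hunks below replace the remaining `ConstMat<TB>` parameters with `MatPtrT<TB>` and delete the thin forwarding overload at the end of the MatMul file, leaving a single entry point. A simplified before/after of the signatures, assembled from the diff itself (`HWY_RESTRICT` and the surrounding template declarations elided):

```cpp
// Before this commit: two MatMul overloads; the second only wrapped B.
//   MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
//                    const float* add, MatMulEnv& env, const RowPtr<TC>& C);
//   MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
//                    const float* add, MatMulEnv& env, const RowPtr<TC>& C) {
//     return MatMul(A, ConstMat<TB>(B), add, env, C);
//   }
//
// After: one overload taking MatPtrT directly, the shape a later
// de-templatized API can keep.
//   MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
//                    const float* add, MatMulEnv& env, const RowPtr<TC>& C);
```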
@@ -328,20 +328,20 @@ class GemmaAttention {
   // Computes Q.K scores, which are "logits" (or scores) stored to att.
   // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        const float* HWY_RESTRICT q, const MatPtrT<float>& k,
                         float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(cache_pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos % activations_.seq_len] = score;
       }
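The dropped `k.ptr +` reflects the key difference between the two view types: `ConstMat::Row()` (removed later in this commit) returns an element offset, `ofs + r * stride`, whereas `MatPtrT::Row()` returns a pointer to the row. A toy sketch of the new access pattern; `ToyRowView` is illustrative only, not the real `MatPtrT`:

```cpp
#include <cstddef>

// Minimal stand-in for a strided, non-owning row view such as MatPtrT<float>.
struct ToyRowView {
  const float* ptr;  // first element of row 0
  size_t cols;       // valid elements per row (qkv_dim)
  size_t stride;     // elements between row starts, >= cols
  // Returns a pointer to row r, unlike ConstMat::Row which returned an offset.
  const float* Row(size_t r) const { return ptr + r * stride; }
};

// Mirrors the QDotK inner loop above: Row(pos) already yields the row pointer.
float ToyQDotK(const float* q, const ToyRowView& k, size_t pos) {
  const float* k_ptr = k.Row(pos);
  float score = 0.0f;
  for (size_t i = 0; i < k.cols; ++i) score += q[i] * k_ptr[i];
  return score;
}
```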
@@ -354,7 +354,7 @@ class GemmaAttention {
   // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
                                const float* HWY_RESTRICT att,
-                               const ConstMat<float>& v,
+                               const MatPtrT<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -362,13 +362,13 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(pos);
         MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(cache_pos);
         MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }
@@ -378,7 +378,7 @@ class GemmaAttention {
  public:
   // Calculates the attention outputs for a single q.
   HWY_INLINE void SingleDotSoftmaxWeightedSum(
-      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
       float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
       const float query_scale, const size_t pos, const size_t start_pos,
       const size_t last_pos) {
@@ -432,13 +432,14 @@ class GemmaAttention {
         KVCache& kv_cache = kv_caches_[query_idx];
         const size_t kv_head_offset =
             layer_ * cache_layer_size_ + head_offset;
-        ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
-                          Extents2D(kv_cache.seq_len, qkv_dim),
-                          /*stride=*/cache_pos_size_);
-        ConstMat<float> v(
-            kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-            Extents2D(kv_cache.seq_len, qkv_dim),
-            /*stride=*/cache_pos_size_);
+        MatPtrT<float> k("k_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                 /*stride=*/cache_pos_size_);
+        MatPtrT<float> v("v_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                 /*stride=*/cache_pos_size_);

         // Find the token position in the query and calculate the range
         // of cache positions to attend to.
@@ -115,8 +115,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
-        env.parallel);
+  BindB(allocator, sizeof(TC), b_trans, env.parallel);
   BindC(allocator, A_extents.rows, C, env.parallel);

   Tristate use_spinning = Tristate::kDefault;
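`BindB` no longer takes the row count `N` separately because the `MatPtrT` view carries its extents; the new overload (see the ops/matmul.h hunk below) reads `B.Rows()` instead. The call-site change in isolation:

```cpp
// Before: pass the row count and wrap b_trans in a temporary ConstMat.
//   BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
//         env.parallel);
// After: b_trans (a MatPtrT) is passed as-is and supplies its own Rows().
//   BindB(allocator, sizeof(TC), b_trans, env.parallel);
```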
@@ -882,20 +882,17 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const ConstMat<TB>& B,
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
                                const RowPtr<TC>& C) const {
-    // TODO: include NUQ tables? NumPacked in ConstMat?
-    const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
-
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, num_packed_B, C);
+        return DoNT(B, C);
       case MMOrder::kNT_K:
-        return DoNT_K(B, num_packed_B, C);
+        return DoNT_K(B, C);
       case MMOrder::kNT_MT:
-        return DoNT_MT(B, num_packed_B, C);
+        return DoNT_MT(B, C);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, num_packed_B, C);
+        return DoNT_MT_K(B, C);
       default:
        HWY_UNREACHABLE;
     }
@@ -916,8 +913,7 @@ class MMPerPackage {

   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
-                       const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -941,7 +937,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
                      args_, C);
@@ -953,8 +949,7 @@ class MMPerPackage {

   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                         const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -977,7 +972,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_K.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+        DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                      C);
@@ -1018,8 +1013,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
-                          const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1042,7 +1036,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_MT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                      args_, C);
@@ -1055,8 +1049,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                            const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1078,7 +1071,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+        DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                      C);
@@ -1210,18 +1203,18 @@ class MMPerPackage {
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE void DecompressB(const ConstMat<TB>& B, size_t num_packed_B,
-                              const size_t row_b, const IndexRange& range_kc,
+  HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                              const IndexRange& range_kc,
                               const RowPtrBF& B_view) const {
     const hn::ScalableTag<BF16> dbf;

-    const PackedSpan<const TB> B_span = MakeSpan(B.ptr, num_packed_B);
+    const PackedSpan<const TB> B_span = B.PaddedSpan();

     const size_t kc = range_kc.Num();
     const size_t col0 = range_kc.begin();

     for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = B.Row(row_b + r) + col0;
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
       BF16* HWY_RESTRICT to = B_view.Row(r);
       DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
       // Verify that we zero-padded.
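`num_packed_B` was threaded through `operator()` and every `DoNT*` variant only so `DecompressB` could bound the span it hands to `DecompressAndZeroPad`; `MatPtrT::PaddedSpan()` now supplies that bound, and the packed offset is computed from `Stride()` because `Row()` no longer returns an offset. A hedged restatement of the equivalence, based on the removed expressions (the actual `PaddedSpan()` implementation is not shown in this diff):

```cpp
// Before: caller-computed element count, span built by hand.
//   const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
//   const PackedSpan<const TB> B_span = MakeSpan(B.ptr, num_packed_B);
//   const size_t packed_ofs = B.Row(row_b + r) + col0;  // Row() was an offset
//
// After: the view reports its own stride-padded extent, presumably
// Stride() * Rows() packed elements from row 0 (MatPtrT has no separate ofs).
//   const PackedSpan<const TB> B_span = B.PaddedSpan();
//   const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
```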
@@ -1264,7 +1257,7 @@ struct MMImpl {
   // Called from `MatMul` from two places: either with the next autotune config,
   // or with the best config.
   template <typename TA, typename TB, typename TC>
-  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     const RowPtr<TC>& C, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
@@ -1300,7 +1293,7 @@ struct MMImpl {
 //
 // Uses considerable stack space: at least 40 KiB per thread.
 template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               const RowPtr<TC>& C) {
   const Allocator& allocator = env.ctx.allocator;
@@ -1327,8 +1320,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;

-  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
-                    env.storage.Partial());
+  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
+                    add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
     MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
     return &per_key;
@@ -1383,13 +1376,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   return &per_key;
 }

-template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                              const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
-  return MatMul(A, ConstMat<TB>(B), add, env, C);
-}
-
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
ops/matmul.h
@@ -681,67 +681,18 @@ struct MMZone {
 };
 #endif  // PROFILER_ENABLED

-// Used for the A and B arguments of `MatMul`, which are always const.
-// This differs from `RowPtr` in supporting the `ofs` required for compressed T.
-// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
-template <typename T>
-struct ConstMat {
-  ConstMat() = default;
-  ConstMat(const T* ptr, Extents2D extents, size_t stride)
-      : ptr(ptr), extents(extents), stride(stride), ofs(0) {
-    HWY_DASSERT(ptr != nullptr);
-    HWY_DASSERT(stride >= extents.cols);
-  }
-  // Non-explicit so that we can pass `MatPtr` directly to MatMul.
-  ConstMat(const MatPtrT<T>& m)
-      : ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
-    scale = m.Scale();
-  }
-
-  size_t Row(size_t r) const {
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      if (r >= extents.rows) {
-        HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
-      }
-    }
-    return ofs + r * stride;
-  }
-
-  const Extents2D& Extents() const { return extents; }
-  size_t Stride() const { return stride; }
-  float Scale() const { return scale; }
-  // So that matvec-inl.h can use the same interface as MatPtrT:
-  size_t Rows() const { return extents.rows; }
-  size_t Cols() const { return extents.cols; }
-
-  const T* HWY_RESTRICT ptr;
-  Extents2D extents;
-  size_t stride;
-
-  // `scale` allows expanding the smaller range of `SfpStream` to the original
-  // values. MatFromWeights sets this from `MatPtr`.
-  float scale = 1.0f;
-
-  // Offset to add to `ptr`; separate because T=NuqStream does not support
-  // pointer arithmetic. This is in units of weights, and does not have anything
-  // to do with the interleaved NUQ tables. It should be computed via `Row()`
-  // to take into account the stride.
-  size_t ofs;
-};
-
 template <typename TB>
-void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
-           const ConstMat<TB>& B, MMParallel& parallel) {
+void BindB(const Allocator& allocator, size_t sizeof_TC, const MatPtrT<TB>& B,
+           MMParallel& parallel) {
   if (!allocator.ShouldBind()) return;

   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
+      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
   const size_t quantum = allocator.Quantum<TB>();
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
-    uintptr_t begin =
-        reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
+    uintptr_t begin = reinterpret_cast<uintptr_t>(B.Row(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
     // B is not yet guaranteed to have padded rows, so only bind the
     // subset that is page-aligned.
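With `ConstMat` deleted, the MatMul path depends only on `MatPtrT`. A hedged summary of the `MatPtrT<T>` surface this commit relies on, inferred from the call sites in the diff rather than copied from its declaration:

```cpp
// Illustrative signatures only; the real declaration lives elsewhere.
//   MatPtrT(const char* name, Extents2D extents);    // e.g. ("k_view", {rows, cols})
//   void SetPtr(T* ptr, size_t stride);              // attach external memory + row stride
//   const T* Row(size_t r) const;                    // pointer to row r (ConstMat::Row was an offset)
//   size_t Rows() const;  size_t Stride() const;
//   float Scale() const;                             // replaces ConstMat's public `scale` member
//   PackedSpan<const T> PaddedSpan() const;          // replaces MakeSpan(ptr, num_packed)
```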
@@ -47,16 +47,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

-// Adapter so that gemma-inl.h can pass ConstMat.
-// TODO: remove after changing ComputeQKV to MatMul.
-template <typename WT, typename VT>
-HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
-                     size_t num) {
-  const hn::ScalableTag<VT> d;
-  HWY_DASSERT(num <= w.Stride());  // Single row, else padding is an issue.
-  const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
-  return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
-}
 // For callers that pass `MatPtrT`, which is not necessarily packed - callers
 // should use Stride() to compute `w_ofs`.
 template <typename WT, typename VT>
@@ -66,7 +56,7 @@ HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
   return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
 }

-// ArrayT is either MatPtrT or ConstMat.
+// ArrayT is MatPtrT.

 // Simple version without tiling nor threading, but two offsets/outputs and
 // always with addition.
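With the `ConstMat` adapter for `Dot` removed, callers pass a `MatPtrT` and compute the packed offset themselves, as the surviving overload's comment instructs. A usage sketch; `row` and `x` are illustrative names, and `Cols()` is an assumed accessor, none of them part of this diff:

```cpp
// w is a MatPtrT<WT>, x a vector of Cols() elements.
//   const size_t w_ofs = row * w.Stride();             // row offset in packed units
//   const float y = Dot(w, w_ofs, x, /*num=*/w.Cols());
// Dot applies w.Scale() and reads through w.PaddedSpan(), per the hunk above.
```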