mirror of https://github.com/google/gemma.cpp.git
Replace last ConstMat with MatPtr
This is to reduce the number of MatMul overloads in preparation for de-templatizing.

PiperOrigin-RevId: 758288589

parent 0a6a7e4cd6
commit 38a08d8095
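For illustration, here is a minimal sketch of the overload reduction the commit message describes, using simplified, hypothetical stand-in types (these are not the real `MatPtrT`/`ConstMat` from ops/matmul.h, and the real `MatMul` signature takes more arguments): while B has its own view type, every `MatMul` entry point needs a second forwarding overload; once B is passed as the same `MatPtrT` type used everywhere else, a single signature remains, which is the only one that must later become type-erased.

// Hypothetical, simplified sketch -- not the real gemma.cpp types or signatures.
#include <cstddef>

template <typename T>
struct MatPtrT {  // stand-in for the matrix view used everywhere else
  const T* ptr = nullptr;
  size_t rows = 0, cols = 0, stride = 0;
};

template <typename T>
struct ConstMat {  // stand-in for the extra B-only view removed by this commit
  explicit ConstMat(const MatPtrT<T>& m)
      : ptr(m.ptr), rows(m.rows), cols(m.cols), stride(m.stride) {}
  const T* ptr;
  size_t rows, cols, stride;
};

// Before: two overloads per type combination; the second only forwards.
template <typename TA, typename TB, typename TC>
void MatMulOld(const MatPtrT<TA>& A, const ConstMat<TB>& B, TC* C) { /* ... */ }
template <typename TA, typename TB, typename TC>
void MatMulOld(const MatPtrT<TA>& A, const MatPtrT<TB>& B, TC* C) {
  MatMulOld(A, ConstMat<TB>(B), C);  // forwarding overload, deleted by this change
}

// After: one overload; only this signature needs de-templatizing later.
template <typename TA, typename TB, typename TC>
void MatMulNew(const MatPtrT<TA>& A, const MatPtrT<TB>& B, TC* C) { /* ... */ }

int main() {
  MatPtrT<float> A, B;
  float C[1] = {0.0f};
  MatMulOld(A, B, C);  // resolves to the forwarding overload
  MatMulNew(A, B, C);  // single entry point after the change
  return 0;
}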
@@ -328,20 +328,20 @@ class GemmaAttention {
   // Computes Q.K scores, which are "logits" (or scores) stored to att.
   // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const float* HWY_RESTRICT q, const ConstMat<float>& k,
+                        const float* HWY_RESTRICT q, const MatPtrT<float>& k,
                         float* HWY_RESTRICT att) {
     const size_t qkv_dim = layer_config_.qkv_dim;
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos] = score;
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT k_ptr = k.ptr + k.Row(cache_pos);
+        const float* HWY_RESTRICT k_ptr = k.Row(cache_pos);
         const float score = Dot(q, k_ptr, qkv_dim);
         att[pos % activations_.seq_len] = score;
       }

@@ -354,7 +354,7 @@ class GemmaAttention {
   // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
   HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
                                const float* HWY_RESTRICT att,
-                               const ConstMat<float>& v,
+                               const MatPtrT<float>& v,
                                float* HWY_RESTRICT att_out) const {
     const size_t qkv_dim = layer_config_.qkv_dim;
     hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));

@@ -362,13 +362,13 @@ class GemmaAttention {
     if (HWY_LIKELY(last_pos < activations_.seq_len)) {
       // Slightly faster: no wraparound.
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(pos);
         MulByConstAndAdd(att[pos], v_ptr, att_out, qkv_dim);
       }
     } else {
       for (size_t pos = start_pos; pos <= last_pos; ++pos) {
         const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const float* HWY_RESTRICT v_ptr = v.ptr + v.Row(cache_pos);
+        const float* HWY_RESTRICT v_ptr = v.Row(cache_pos);
         MulByConstAndAdd(att[pos % activations_.seq_len], v_ptr, att_out,
                          qkv_dim);
       }

@@ -378,7 +378,7 @@ class GemmaAttention {
  public:
   // Calculates the attention outputs for a single q.
   HWY_INLINE void SingleDotSoftmaxWeightedSum(
-      float* HWY_RESTRICT q, const ConstMat<float>& k, const ConstMat<float>& v,
+      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
       float* HWY_RESTRICT att, float* HWY_RESTRICT att_out,
       const float query_scale, const size_t pos, const size_t start_pos,
       const size_t last_pos) {

@@ -432,13 +432,14 @@ class GemmaAttention {
         KVCache& kv_cache = kv_caches_[query_idx];
         const size_t kv_head_offset =
             layer_ * cache_layer_size_ + head_offset;
-        ConstMat<float> k(kv_cache.kv_cache.get() + kv_head_offset,
-                          Extents2D(kv_cache.seq_len, qkv_dim),
-                          /*stride=*/cache_pos_size_);
-        ConstMat<float> v(
-            kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-            Extents2D(kv_cache.seq_len, qkv_dim),
-            /*stride=*/cache_pos_size_);
+        MatPtrT<float> k("k_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
+                 /*stride=*/cache_pos_size_);
+        MatPtrT<float> v("v_view",
+                         Extents2D(kv_cache.seq_len, qkv_dim));
+        v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
+                 /*stride=*/cache_pos_size_);

         // Find the token position in the query and calculate the range
         // of cache positions to attend to.
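The K and V views built above share one interleaved cache buffer: both have `seq_len` rows of `qkv_dim` elements, a row stride of `cache_pos_size_`, and start `qkv_dim` elements apart. A standalone sketch of that layout follows; `StridedView` is a hypothetical stand-in for `MatPtrT`, and the sizes are illustrative only.

// Simplified sketch of the strided K/V views over one interleaved KV cache.
#include <cstddef>
#include <vector>

struct StridedView {
  const float* base;          // first element of row 0
  size_t rows, cols, stride;  // stride >= cols, in elements
  const float* Row(size_t r) const { return base + r * stride; }
};

int main() {
  const size_t seq_len = 4, qkv_dim = 8;
  // Each cache position stores K then V for this head, so the per-position
  // size (the row stride of both views) is at least 2 * qkv_dim.
  const size_t cache_pos_size = 2 * qkv_dim;
  std::vector<float> kv_cache(seq_len * cache_pos_size, 0.0f);
  const size_t kv_head_offset = 0;  // offset of this layer/head within a position

  StridedView k{kv_cache.data() + kv_head_offset, seq_len, qkv_dim, cache_pos_size};
  StridedView v{kv_cache.data() + kv_head_offset + qkv_dim, seq_len, qkv_dim,
                cache_pos_size};
  // k.Row(pos) and v.Row(pos) address the K/V vectors of position `pos`,
  // matching how QDotK and WeightedSumV walk the cache.
  (void)k;
  (void)v;
  return 0;
}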
@@ -115,8 +115,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
-        env.parallel);
+  BindB(allocator, sizeof(TC), b_trans, env.parallel);
   BindC(allocator, A_extents.rows, C, env.parallel);

   Tristate use_spinning = Tristate::kDefault;

@@ -882,20 +882,17 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const ConstMat<TB>& B,
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
                                const RowPtr<TC>& C) const {
-    // TODO: include NUQ tables? NumPacked in ConstMat?
-    const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
-
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, num_packed_B, C);
+        return DoNT(B, C);
       case MMOrder::kNT_K:
-        return DoNT_K(B, num_packed_B, C);
+        return DoNT_K(B, C);
       case MMOrder::kNT_MT:
-        return DoNT_MT(B, num_packed_B, C);
+        return DoNT_MT(B, C);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, num_packed_B, C);
+        return DoNT_MT_K(B, C);
       default:
        HWY_UNREACHABLE;
     }

@@ -916,8 +913,7 @@ class MMPerPackage {

   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
-                       const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);

@@ -941,7 +937,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
                      args_, C);

@@ -953,8 +949,7 @@ class MMPerPackage {

   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                         const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);

@@ -977,7 +972,7 @@ class MMPerPackage {
       {
         MMZone zone;
         zone.MaybeEnter("MM.NT_K.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+        DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                      C);

@@ -1018,8 +1013,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
-                          const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);

@@ -1042,7 +1036,7 @@ class MMPerPackage {
      {
        MMZone zone;
        zone.MaybeEnter("MM.NT_MT.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_K, B_view);
+        DecompressB(B, row_b, range_K, B_view);
      }
      MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
                     args_, C);

@@ -1055,8 +1049,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
-                            const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();

@@ -1078,7 +1071,7 @@ class MMPerPackage {
      {
        MMZone zone;
        zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
-        DecompressB(B, num_packed_B, row_b, range_kc, B_view);
+        DecompressB(B, row_b, range_kc, B_view);
      }
      MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
                     C);

@@ -1210,18 +1203,18 @@ class MMPerPackage {
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
   template <typename TB>
-  HWY_INLINE void DecompressB(const ConstMat<TB>& B, size_t num_packed_B,
-                              const size_t row_b, const IndexRange& range_kc,
+  HWY_INLINE void DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                              const IndexRange& range_kc,
                               const RowPtrBF& B_view) const {
     const hn::ScalableTag<BF16> dbf;

-    const PackedSpan<const TB> B_span = MakeSpan(B.ptr, num_packed_B);
+    const PackedSpan<const TB> B_span = B.PaddedSpan();

     const size_t kc = range_kc.Num();
     const size_t col0 = range_kc.begin();

     for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = B.Row(row_b + r) + col0;
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
       BF16* HWY_RESTRICT to = B_view.Row(r);
       DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
       // Verify that we zero-padded.
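With `ConstMat::ofs` and the separately passed `num_packed_B` gone, the packed offset into B is computed directly from the row index and stride, and the span bound comes from the padded size (rows * stride). A minimal sketch of that arithmetic is below; `PaddedMat` is a hypothetical stand-in, not the real `PackedSpan`, and the sizes are illustrative.

// Sketch of the row/col -> packed-offset mapping for a padded, row-major matrix.
#include <cassert>
#include <cstddef>
#include <vector>

struct PaddedMat {
  std::vector<float> data;    // rows * stride elements, stride >= cols
  size_t rows, cols, stride;
  size_t NumPadded() const { return rows * stride; }  // span bound (was num_packed_B)
};

int main() {
  PaddedMat B{std::vector<float>(6 * 8), /*rows=*/6, /*cols=*/5, /*stride=*/8};
  const size_t row_b = 2, col0 = 3;
  const size_t packed_ofs = row_b * B.stride + col0;  // replaces B.Row(row_b) + col0
  assert(packed_ofs < B.NumPadded());                 // stays within the padded span
  return 0;
}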
@@ -1264,7 +1257,7 @@ struct MMImpl {
   // Called from `MatMul` from two places: either with the next autotune config,
   // or with the best config.
   template <typename TA, typename TB, typename TC>
-  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     const RowPtr<TC>& C, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;

@@ -1300,7 +1293,7 @@ struct MMImpl {
 //
 // Uses considerable stack space: at least 40 KiB per thread.
 template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               const RowPtr<TC>& C) {
   const Allocator& allocator = env.ctx.allocator;

@@ -1327,8 +1320,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;

-  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
-                    env.storage.Partial());
+  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
+                    add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
     MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
     return &per_key;

@@ -1383,13 +1376,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
   return &per_key;
 }

-template <typename TA, typename TB, typename TC>
-HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                              const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
-  return MatMul(A, ConstMat<TB>(B), add, env, C);
-}
-
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
ops/matmul.h (57 lines changed)

@@ -681,67 +681,18 @@ struct MMZone {
 };
 #endif  // PROFILER_ENABLED

-// Used for the A and B arguments of `MatMul`, which are always const.
-// This differs from `RowPtr` in supporting the `ofs` required for compressed T.
-// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
-template <typename T>
-struct ConstMat {
-  ConstMat() = default;
-  ConstMat(const T* ptr, Extents2D extents, size_t stride)
-      : ptr(ptr), extents(extents), stride(stride), ofs(0) {
-    HWY_DASSERT(ptr != nullptr);
-    HWY_DASSERT(stride >= extents.cols);
-  }
-  // Non-explicit so that we can pass `MatPtr` directly to MatMul.
-  ConstMat(const MatPtrT<T>& m)
-      : ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
-    scale = m.Scale();
-  }
-
-  size_t Row(size_t r) const {
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      if (r >= extents.rows) {
-        HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
-      }
-    }
-    return ofs + r * stride;
-  }
-
-  const Extents2D& Extents() const { return extents; }
-  size_t Stride() const { return stride; }
-  float Scale() const { return scale; }
-  // So that matvec-inl.h can use the same interface as MatPtrT:
-  size_t Rows() const { return extents.rows; }
-  size_t Cols() const { return extents.cols; }
-
-  const T* HWY_RESTRICT ptr;
-  Extents2D extents;
-  size_t stride;
-
-  // `scale` allows expanding the smaller range of `SfpStream` to the original
-  // values. MatFromWeights sets this from `MatPtr`.
-  float scale = 1.0f;
-
-  // Offset to add to `ptr`; separate because T=NuqStream does not support
-  // pointer arithmetic. This is in units of weights, and does not have anything
-  // to do with the interleaved NUQ tables. It should be computed via `Row()`
-  // to take into account the stride.
-  size_t ofs;
-};
-
 template <typename TB>
-void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
-           const ConstMat<TB>& B, MMParallel& parallel) {
+void BindB(const Allocator& allocator, size_t sizeof_TC, const MatPtrT<TB>& B,
+           MMParallel& parallel) {
   if (!allocator.ShouldBind()) return;

   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
+      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
   const size_t quantum = allocator.Quantum<TB>();
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
     const size_t node = parallel.Node(pkg_idx);
-    uintptr_t begin =
-        reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
+    uintptr_t begin = reinterpret_cast<uintptr_t>(B.Row(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
     // B is not yet guaranteed to have padded rows, so only bind the
     // subset that is page-aligned.
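`BindB` now derives everything it needs from the `MatPtrT` itself: the number of B rows (previously the separate `N` argument) and the byte range of each per-package row slice. A sketch of that range computation follows, under the same row-major assumption (`stride >= cols`); the types are hypothetical and no actual NUMA binding is performed.

// Hypothetical sketch of the per-package address-range math in BindB.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct MatView {
  const float* base;
  size_t rows, cols, stride;  // elements per row including padding
  const float* Row(size_t r) const { return base + r * stride; }
};

int main() {
  const size_t rows = 8, cols = 6, stride = 8;
  std::vector<float> storage(rows * stride, 0.0f);
  MatView B{storage.data(), rows, cols, stride};

  // One "package" is assigned rows [begin_row, begin_row + num_rows).
  const size_t begin_row = 4, num_rows = 4;
  const uintptr_t begin = reinterpret_cast<uintptr_t>(B.Row(begin_row));
  const uintptr_t end = begin + num_rows * B.stride * sizeof(float);
  // A real implementation would round [begin, end) to page boundaries and
  // bind only the page-aligned subset to the package's NUMA node.
  std::printf("bind range: %zu bytes\n", static_cast<size_t>(end - begin));
  return 0;
}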
@@ -47,16 +47,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

-// Adapter so that gemma-inl.h can pass ConstMat.
-// TODO: remove after changing ComputeQKV to MatMul.
-template <typename WT, typename VT>
-HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
-                     size_t num) {
-  const hn::ScalableTag<VT> d;
-  HWY_DASSERT(num <= w.Stride());  // Single row, else padding is an issue.
-  const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
-  return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
-}
 // For callers that pass `MatPtrT`, which is not necessarily packed - callers
 // should use Stride() to compute `w_ofs`.
 template <typename WT, typename VT>

@@ -66,7 +56,7 @@ HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
   return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
 }

-// ArrayT is either MatPtrT or ConstMat.
+// ArrayT is MatPtrT.

 // Simple version without tiling nor threading, but two offsets/outputs and
 // always with addition.
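The surviving `Dot` overload reads one row of a possibly padded weight matrix at offset `w_ofs = row * Stride()` and multiplies the result by the per-tensor scale. Below is a scalar sketch of that contract; the types are hypothetical, whereas the real code operates on Highway vectors and compressed spans.

// Scalar sketch of a scaled dot product over one row of a padded matrix.
#include <cstddef>
#include <cstdio>
#include <vector>

struct PaddedMat {
  std::vector<float> data;    // rows * stride elements
  size_t rows, cols, stride;  // stride >= cols
  float scale;                // e.g. expands the reduced range of compressed weights
};

float RowDot(const PaddedMat& w, size_t row, const float* vec, size_t num) {
  const size_t w_ofs = row * w.stride;  // callers use Stride(), not cols
  float sum = 0.0f;
  for (size_t i = 0; i < num; ++i) sum += w.data[w_ofs + i] * vec[i];
  return w.scale * sum;
}

int main() {
  PaddedMat w{std::vector<float>(2 * 4, 1.0f), /*rows=*/2, /*cols=*/3,
              /*stride=*/4, /*scale=*/0.5f};
  const float vec[3] = {1.0f, 2.0f, 3.0f};
  std::printf("%f\n", RowDot(w, /*row=*/1, vec, /*num=*/3));  // 0.5 * 6 = 3
  return 0;
}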