Fix excessive KC/MC from prior change

This could lead to stack overflow in B_storage.

Also do not require specific type for query_norm_scale,
update batch sizes for attention tensors,
more verbose Mat shape/type checks.

PiperOrigin-RevId: 824987689
This commit is contained in:
Jan Wassenberg 2025-10-28 05:32:30 -07:00 committed by Copybara-Service
parent 5a05857deb
commit 3cc0139ebb
6 changed files with 53 additions and 16 deletions

View File

@ -100,6 +100,8 @@ struct AttentionActivations {
att.OverrideRows(batch_size); att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size); att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size); att_sums.OverrideRows(batch_size);
// `inv_timescale*` are not batched.
} }
MatStorageT<float> q; // query MatStorageT<float> q; // query
@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
inv_timescale_global = activations.inv_timescale_global; inv_timescale_global = activations.inv_timescale_global;
} }
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
// `inv_timescale*` are not batched.
}
const ModelConfig& config; const ModelConfig& config;
MatPtrT<float> q; MatPtrT<float> q;
MatPtrT<float> q_T; MatPtrT<float> q_T;
@ -203,6 +215,9 @@ struct Activations {
ffw_out.OverrideRows(batch_size); ffw_out.OverrideRows(batch_size);
attention_storage.SetBatchSize(batch_size); attention_storage.SetBatchSize(batch_size);
// `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
// their row override is not updated when the underlying storage changes.
attention.SetBatchSize(batch_size);
} }
const LayerConfig& layer_config; const LayerConfig& layer_config;

View File

@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
// Updates q in place for RMSNorm and positional encoding. // Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<float>& q, MatPtrT<float>& q,
const MatPtrT<float>& query_norm_scale, const MatPtr& query_norm_scale,
const size_t layer_idx, const size_t layer_idx,
const AttentionActivationsPtrs& activations, const AttentionActivationsPtrs& activations,
ThreadingContext& ctx) { ThreadingContext& ctx) {
@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
// grouped together so that mode 1 or 2 can be used, and choosing which of the // grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency. // 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism, void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const size_t layer_idx, const MatPtr& query_norm_scale,
const MatPtrT<float>& query_norm_scale,
AttentionActivationsPtrs& activations, QBatch& qbatch, AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);

View File

@ -30,7 +30,7 @@ namespace gcpp {
namespace NAMESPACE { \ namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \ void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \ size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \ const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\ \
void SingleFlashAttention(size_t start_pos, size_t last_pos, \ void SingleFlashAttention(size_t start_pos, size_t last_pos, \
@ -45,8 +45,7 @@ namespace gcpp {
size_t total_tasks, size_t target_parallelism); \ size_t total_tasks, size_t target_parallelism); \
\ \
void FlashAttention(size_t num_tokens, size_t target_parallelism, \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, \ size_t layer_idx, const MatPtr& query_norm_scale, \
const MatPtrT<float>& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \ AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \ ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \

View File

@ -175,9 +175,14 @@ class GenerateCandidates {
// The number of A and B columns to read between updating `C`. // The number of A and B columns to read between updating `C`.
SizeVec KC(size_t mr, MMOrder order) const { SizeVec KC(size_t mr, MMOrder order) const {
// Must return the actual value: although ignored by `RangesOfKC`, this will if (IsOneKC(order)) {
// be used in MC() and NC(). // A single KC range is infeasible when K exceeds the max. The caller
if (IsOneKC(order)) return SizeVec(1, K_); // will skip all configs with `order`.
if (K_ > kMaxKC) return SizeVec();
// Must return the actual value: although ignored by `RangesOfKC`, this
// will be used in MC() and NC().
return SizeVec(1, K_);
}
// `LoopKC` handles up to `mr` rows of A. // `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(max_M_, mr); const size_t rows_a = HWY_MIN(max_M_, mr);
@ -227,13 +232,21 @@ class GenerateCandidates {
// The number of (L2 resident) A rows for `A2C0` to loop over. // The number of (L2 resident) A rows for `A2C0` to loop over.
SizeVec MC(size_t mr, size_t kc, MMOrder order) const { SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
// Must return the actual value: although ignored by `RangesOfMC`, this will if (max_M_ <= mr) return SizeVec(1, max_M_);
// be used in NC(). if (IsOneMC(order)) {
if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_); // A single MC range is infeasible when M exceeds the max. The caller
// will skip all configs with `order`.
if (max_M_ > kMaxMC) return SizeVec();
// Must return the actual value: although ignored by `RangesOfMC`, this
// will be used in NC().
return SizeVec(1, max_M_);
}
// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
// it is typically inclusive. // it is typically inclusive.
const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16)); const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
// `kc` was chosen to fit in L1, hence this should not exceed L2.
HWY_ASSERT(bytes_b <= cache_.L2Bytes());
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
// packed B. We want `mc * kc` elements of A to fit in L2, alongside // packed B. We want `mc * kc` elements of A to fit in L2, alongside
@ -242,7 +255,7 @@ class GenerateCandidates {
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc); size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC)); mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
mc_max = HWY_MIN(mc_max, max_M_); mc_max = HWY_MIN(mc_max, max_M_);
HWY_DASSERT(mc_max != 0); HWY_ASSERT(mc_max != 0);
SizeVec all_mc; SizeVec all_mc;
all_mc.reserve(6); all_mc.reserve(6);

View File

@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
size_t cluster_idx = 0) { size_t cluster_idx = 0) {
HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Rows() == 1);
HWY_DASSERT(weights.Cols() == activations.Cols()); HWY_DASSERT(weights.Cols() == activations.Cols());
HWY_DASSERT(activations.SameShape(out)); activations.DebugCheckSameShape(out);
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,

View File

@ -181,7 +181,15 @@ class MatPtr : public IFields {
Extents2D Extents() const { return Extents2D(Rows(), cols_); } Extents2D Extents() const { return Extents2D(Rows(), cols_); }
bool IsEmpty() const { return Rows() == 0 || cols_ == 0; } bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
bool SameShape(const MatPtr& other) const { bool SameShape(const MatPtr& other) const {
return Rows() == other.Rows() && cols_ == other.cols_; return Rows() == other.Rows() && Cols() == other.Cols();
}
void DebugCheckSameShape(const MatPtr& other) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (!SameShape(other)) {
HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
Rows(), Cols(), other.Rows(), Cols());
}
}
} }
// Future calls to `Rows()` during this class' lifetime (not serialized) // Future calls to `Rows()` during this class' lifetime (not serialized)
// will return this value. Used to set the actual number of rows for // will return this value. Used to set the actual number of rows for
@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
if (GetType() == Type::kUnknown) { if (GetType() == Type::kUnknown) {
SetType(TypeEnum<MatT>()); SetType(TypeEnum<MatT>());
} else { } else {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>()); if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
TypeName<MatT>(), TypeName(other.GetType()));
}
} }
} }
MatPtrT& operator=(const MatPtr& other) { MatPtrT& operator=(const MatPtr& other) {