Fix excessive KC/MC from prior change

This could lead to a stack overflow in B_storage.

Also: do not require a specific type for query_norm_scale,
update batch sizes for attention tensors, and
make Mat shape/type checks more verbose.

PiperOrigin-RevId: 824987689
Jan Wassenberg 2025-10-28 05:32:30 -07:00 committed by Copybara-Service
parent 5a05857deb
commit 3cc0139ebb
6 changed files with 53 additions and 16 deletions
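
Why an excessive KC can overflow the stack: the packed B tile is written to
B_storage, whose capacity is fixed at compile time. A minimal sketch of the
failure mode, assuming B_storage is a stack array sized for kMaxKC (the
constants and packing loop below are illustrative, not the actual gemma.cpp
definitions):

    #include <cstddef>
    #include <cstdio>

    constexpr size_t kMaxKC = 4096;  // assumed compile-time cap on KC
    constexpr size_t kNR = 4;        // assumed B tile width

    void PackB(const float* b, size_t kc) {
      // Capacity is fixed at compile time; kc > kMaxKC would write past the
      // end of this array, i.e. a stack buffer overflow.
      float B_storage[kMaxKC * kNR];
      if (kc > kMaxKC) {
        fprintf(stderr, "kc %zu exceeds kMaxKC %zu\n", kc, kMaxKC);
        return;  // the fixed KC()/MC() below report such configs as infeasible
      }
      for (size_t i = 0; i < kc * kNR; ++i) B_storage[i] = b[i];
      (void)B_storage;  // ... the MatMul kernel would consume the tile here ...
    }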


@@ -100,6 +100,8 @@ struct AttentionActivations {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+
+    // `inv_timescale*` are not batched.
   }

   MatStorageT<float> q;  // query
@@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
     inv_timescale_global = activations.inv_timescale_global;
   }

+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+    // `inv_timescale*` are not batched.
+  }
+
   const ModelConfig& config;
   MatPtrT<float> q;
   MatPtrT<float> q_T;
@@ -203,6 +215,9 @@ struct Activations {
     ffw_out.OverrideRows(batch_size);

     attention_storage.SetBatchSize(batch_size);
+    // `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
+    // their row override is not updated when the underlying storage changes.
+    attention.SetBatchSize(batch_size);
   }

   const LayerConfig& layer_config;
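
The comment above is the crux of the batch-size fix: `AttentionActivationsPtrs`
holds `MatPtrT` views whose row counts were copied from the storage when the
pointers were initialized, so a later `OverrideRows` on the storage alone
leaves the views stale. A simplified sketch of that aliasing behavior, with
stand-in types rather than the real MatStorageT/MatPtrT:

    #include <cstddef>

    struct Storage {
      size_t rows = 64;
      void OverrideRows(size_t r) { rows = r; }
    };

    struct View {
      size_t rows;  // snapshot taken at construction, not a reference
      explicit View(const Storage& s) : rows(s.rows) {}
      void OverrideRows(size_t r) { rows = r; }
    };

    int main() {
      Storage storage;
      View view(storage);
      storage.OverrideRows(16);  // view.rows is still 64!
      view.OverrideRows(16);     // hence the added attention.SetBatchSize() call
      return view.rows == storage.rows ? 0 : 1;
    }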


@@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
                                   MatPtrT<float>& q,
-                                  const MatPtrT<float>& query_norm_scale,
+                                  const MatPtr& query_norm_scale,
                                   const size_t layer_idx,
                                   const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {
@@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
-                    const size_t layer_idx,
-                    const MatPtrT<float>& query_norm_scale,
+                    const size_t layer_idx, const MatPtr& query_norm_scale,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
                     ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
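
Loosening `const MatPtrT<float>&` to the type-erased `const MatPtr&` means the
signature no longer fixes the element type of query_norm_scale; the concrete
type can instead be dispatched where the scale is read, as `CallUpcasted` does
in the RMSNormBatched hunk further down. A sketch of that dispatch pattern,
with stand-in types rather than the actual gemma.cpp classes:

    #include <cstdint>

    enum class Type { kF32, kBF16 };

    struct ErasedMat {  // stand-in for the type-erased MatPtr
      Type type;
      const void* data;
    };

    // Invokes `func` with a correctly typed pointer; a stand-in for upcasting
    // dispatch, not the real CallUpcasted.
    template <typename Func>
    void DispatchType(const ErasedMat& m, Func&& func) {
      switch (m.type) {
        case Type::kF32:
          func(static_cast<const float*>(m.data));
          break;
        case Type::kBF16:
          func(static_cast<const uint16_t*>(m.data));  // raw BF16 bits
          break;
      }
    }

    // Usage: a generic lambda handles either element type, e.g.
    //   DispatchType(scale, [](const auto* p) { /* read scale values */ });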


@@ -30,7 +30,7 @@ namespace gcpp {
   namespace NAMESPACE {                                                    \
   void RMSNormAndPositionalEncoding(                                       \
       size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q,          \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx,            \
+      const MatPtr& query_norm_scale, size_t layer_idx,                    \
       const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
                                                                            \
   void SingleFlashAttention(size_t start_pos, size_t last_pos,             \
@@ -45,8 +45,7 @@ namespace gcpp {
       size_t total_tasks, size_t target_parallelism);                       \
                                                                             \
   void FlashAttention(size_t num_tokens, size_t target_parallelism,         \
-                      size_t layer_idx,                                     \
-                      const MatPtrT<float>& query_norm_scale,               \
+                      size_t layer_idx, const MatPtr& query_norm_scale,     \
                       AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       ThreadingContext& ctx);                               \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \


@@ -175,9 +175,14 @@ class GenerateCandidates {
   // The number of A and B columns to read between updating `C`.
   SizeVec KC(size_t mr, MMOrder order) const {
-    // Must return the actual value: although ignored by `RangesOfKC`, this will
-    // be used in MC() and NC().
-    if (IsOneKC(order)) return SizeVec(1, K_);
+    if (IsOneKC(order)) {
+      // A single KC range is infeasible when K exceeds the max. The caller
+      // will skip all configs with `order`.
+      if (K_ > kMaxKC) return SizeVec();
+      // Must return the actual value: although ignored by `RangesOfKC`, this
+      // will be used in MC() and NC().
+      return SizeVec(1, K_);
+    }
     // `LoopKC` handles up to `mr` rows of A.
     const size_t rows_a = HWY_MIN(max_M_, mr);
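
The new empty return relies on a caller-side convention: an empty candidate
list marks every config with this `order` as infeasible, instead of silently
clamping to an out-of-range KC. A minimal sketch of that contract (the function
and loop are illustrative, not the actual autotuner code):

    #include <cstddef>
    #include <vector>

    using SizeVec = std::vector<size_t>;

    // Returns false when the candidate list is empty, so the caller skips this
    // order entirely rather than enqueueing an oversized KC.
    bool AddCandidatesForOrder(const SizeVec& all_kc) {
      if (all_kc.empty()) return false;
      for (size_t kc : all_kc) {
        (void)kc;  // ... enqueue an autotune candidate with this kc ...
      }
      return true;
    }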
@@ -227,13 +232,21 @@ class GenerateCandidates {
   // The number of (L2 resident) A rows for `A2C0` to loop over.
   SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
-    // Must return the actual value: although ignored by `RangesOfMC`, this will
-    // be used in NC().
-    if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_);
+    if (max_M_ <= mr) return SizeVec(1, max_M_);
+    if (IsOneMC(order)) {
+      // A single MC range is infeasible when M exceeds the max. The caller
+      // will skip all configs with `order`.
+      if (max_M_ > kMaxMC) return SizeVec();
+      // Must return the actual value: although ignored by `RangesOfMC`, this
+      // will be used in NC().
+      return SizeVec(1, max_M_);
+    }
     // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
     // it is typically inclusive.
     const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
+    // `kc` was chosen to fit in L1, hence this should not exceed L2.
+    HWY_ASSERT(bytes_b <= cache_.L2Bytes());

     // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
     // packed B. We want `mc * kc` elements of A to fit in L2, alongside
@@ -242,7 +255,7 @@ class GenerateCandidates {
     size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
     mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
     mc_max = HWY_MIN(mc_max, max_M_);
-    HWY_DASSERT(mc_max != 0);
+    HWY_ASSERT(mc_max != 0);

     SizeVec all_mc;
     all_mc.reserve(6);
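
Upgrading HWY_DASSERT to HWY_ASSERT here (and adding the bytes_b check above)
keeps these guards active in release builds, where a silent mc_max of zero
could feed the oversized-tile path this commit is fixing. A simplified sketch
of the usual Highway semantics, with stand-in macro names:

    #include <cstdio>
    #include <cstdlib>

    #define MY_ASSERT(cond)                              \
      do {                                               \
        if (!(cond)) {                                   \
          fprintf(stderr, "Assert %s failed\n", #cond);  \
          abort();                                       \
        }                                                \
      } while (0)

    #ifdef NDEBUG
    #define MY_DASSERT(cond) ((void)0)  // compiled out of release builds
    #else
    #define MY_DASSERT(cond) MY_ASSERT(cond)
    #endif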


@@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
                     size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
-  HWY_DASSERT(activations.SameShape(out));
+  activations.DebugCheckSameShape(out);

   CallUpcasted(&weights, [&](const auto* weights_t) {
     ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,


@@ -181,7 +181,15 @@ class MatPtr : public IFields {
   Extents2D Extents() const { return Extents2D(Rows(), cols_); }
   bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
   bool SameShape(const MatPtr& other) const {
-    return Rows() == other.Rows() && cols_ == other.cols_;
+    return Rows() == other.Rows() && Cols() == other.Cols();
   }
+  void DebugCheckSameShape(const MatPtr& other) const {
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (!SameShape(other)) {
+        HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
+                  Rows(), Cols(), other.Rows(), other.Cols());
+      }
+    }
+  }

   // Future calls to `Rows()` during this class' lifetime (not serialized)
   // will return this value. Used to set the actual number of rows for
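
Compared with HWY_DASSERT(a.SameShape(b)), the new helper names the mismatching
tensor and prints both shapes instead of only the failed condition. An
illustrative call site (not part of this diff):

    // Debug builds abort with the tensor name and both shapes on mismatch,
    // e.g. "att_out: shape mismatch 32 x 2048 vs 16 x 2048" (values illustrative).
    void CheckedCopy(const MatPtr& src, MatPtr& dst) {
      src.DebugCheckSameShape(dst);
      // ... copy src rows into dst ...
    }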
@@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
     if (GetType() == Type::kUnknown) {
       SetType(TypeEnum<MatT>());
     } else {
-      HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
+      if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
+        HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
+                  TypeName<MatT>(), TypeName(other.GetType()));
+      }
     }
   }

   MatPtrT& operator=(const MatPtr& other) {
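
The verbose check replaces a bare HWY_ASSERT with an abort message naming both
element types. An illustrative failure, derived from the format string above
(the concrete names printed by TypeName are assumptions):

    // Viewing BF16-typed storage as MatPtrT<float> would now abort with
    // something like:
    //   Type mismatch: MatT f32, constructing from bf16
    // rather than only the failed-condition text from HWY_ASSERT.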