mirror of https://github.com/google/gemma.cpp.git
Fix excessive KC/MC from prior change
This could lead to stack overflow in B_storage. Also do not require specific type for query_norm_scale, update batch sizes for attention tensors, more verbose Mat shape/type checks. PiperOrigin-RevId: 824987689
This commit is contained in:
parent
5a05857deb
commit
3cc0139ebb
|
|
@ -100,6 +100,8 @@ struct AttentionActivations {
|
|||
att.OverrideRows(batch_size);
|
||||
att_out.OverrideRows(batch_size);
|
||||
att_sums.OverrideRows(batch_size);
|
||||
|
||||
// `inv_timescale*` are not batched.
|
||||
}
|
||||
|
||||
MatStorageT<float> q; // query
|
||||
|
|
@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
|
|||
inv_timescale_global = activations.inv_timescale_global;
|
||||
}
|
||||
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
q.OverrideRows(batch_size);
|
||||
// q_T rows are always qkv_dim!
|
||||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
att.OverrideRows(batch_size);
|
||||
att_out.OverrideRows(batch_size);
|
||||
att_sums.OverrideRows(batch_size);
|
||||
// `inv_timescale*` are not batched.
|
||||
}
|
||||
|
||||
const ModelConfig& config;
|
||||
MatPtrT<float> q;
|
||||
MatPtrT<float> q_T;
|
||||
|
|
@ -203,6 +215,9 @@ struct Activations {
|
|||
ffw_out.OverrideRows(batch_size);
|
||||
|
||||
attention_storage.SetBatchSize(batch_size);
|
||||
// `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
|
||||
// their row override is not updated when the underlying storage changes.
|
||||
attention.SetBatchSize(batch_size);
|
||||
}
|
||||
|
||||
const LayerConfig& layer_config;
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
|||
// Updates q in place for RMSNorm and positional encoding.
|
||||
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||
MatPtrT<float>& q,
|
||||
const MatPtrT<float>& query_norm_scale,
|
||||
const MatPtr& query_norm_scale,
|
||||
const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
ThreadingContext& ctx) {
|
||||
|
|
@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
|
|||
// grouped together so that mode 1 or 2 can be used, and choosing which of the
|
||||
// 3 modes to use for best efficiency.
|
||||
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
||||
const size_t layer_idx,
|
||||
const MatPtrT<float>& query_norm_scale,
|
||||
const size_t layer_idx, const MatPtr& query_norm_scale,
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ namespace gcpp {
|
|||
namespace NAMESPACE { \
|
||||
void RMSNormAndPositionalEncoding( \
|
||||
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
|
||||
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
|
||||
const MatPtr& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
|
|
@ -45,8 +45,7 @@ namespace gcpp {
|
|||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||
size_t layer_idx, \
|
||||
const MatPtrT<float>& query_norm_scale, \
|
||||
size_t layer_idx, const MatPtr& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
|
|
|
|||
|
|
@ -175,9 +175,14 @@ class GenerateCandidates {
|
|||
|
||||
// The number of A and B columns to read between updating `C`.
|
||||
SizeVec KC(size_t mr, MMOrder order) const {
|
||||
// Must return the actual value: although ignored by `RangesOfKC`, this will
|
||||
// be used in MC() and NC().
|
||||
if (IsOneKC(order)) return SizeVec(1, K_);
|
||||
if (IsOneKC(order)) {
|
||||
// A single KC range is infeasible when K exceeds the max. The caller
|
||||
// will skip all configs with `order`.
|
||||
if (K_ > kMaxKC) return SizeVec();
|
||||
// Must return the actual value: although ignored by `RangesOfKC`, this
|
||||
// will be used in MC() and NC().
|
||||
return SizeVec(1, K_);
|
||||
}
|
||||
// `LoopKC` handles up to `mr` rows of A.
|
||||
const size_t rows_a = HWY_MIN(max_M_, mr);
|
||||
|
||||
|
|
@ -227,13 +232,21 @@ class GenerateCandidates {
|
|||
|
||||
// The number of (L2 resident) A rows for `A2C0` to loop over.
|
||||
SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
|
||||
// Must return the actual value: although ignored by `RangesOfMC`, this will
|
||||
// be used in NC().
|
||||
if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_);
|
||||
if (max_M_ <= mr) return SizeVec(1, max_M_);
|
||||
if (IsOneMC(order)) {
|
||||
// A single MC range is infeasible when M exceeds the max. The caller
|
||||
// will skip all configs with `order`.
|
||||
if (max_M_ > kMaxMC) return SizeVec();
|
||||
// Must return the actual value: although ignored by `RangesOfMC`, this
|
||||
// will be used in NC().
|
||||
return SizeVec(1, max_M_);
|
||||
}
|
||||
|
||||
// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
|
||||
// it is typically inclusive.
|
||||
const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
|
||||
// `kc` was chosen to fit in L1, hence this should not exceed L2.
|
||||
HWY_ASSERT(bytes_b <= cache_.L2Bytes());
|
||||
|
||||
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
|
||||
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
|
||||
|
|
@ -242,7 +255,7 @@ class GenerateCandidates {
|
|||
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
|
||||
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
|
||||
mc_max = HWY_MIN(mc_max, max_M_);
|
||||
HWY_DASSERT(mc_max != 0);
|
||||
HWY_ASSERT(mc_max != 0);
|
||||
|
||||
SizeVec all_mc;
|
||||
all_mc.reserve(6);
|
||||
|
|
|
|||
|
|
@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
|
|||
size_t cluster_idx = 0) {
|
||||
HWY_DASSERT(weights.Rows() == 1);
|
||||
HWY_DASSERT(weights.Cols() == activations.Cols());
|
||||
HWY_DASSERT(activations.SameShape(out));
|
||||
activations.DebugCheckSameShape(out);
|
||||
|
||||
CallUpcasted(&weights, [&](const auto* weights_t) {
|
||||
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
|
||||
|
|
|
|||
15
util/mat.h
15
util/mat.h
|
|
@ -181,7 +181,15 @@ class MatPtr : public IFields {
|
|||
Extents2D Extents() const { return Extents2D(Rows(), cols_); }
|
||||
bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
|
||||
bool SameShape(const MatPtr& other) const {
|
||||
return Rows() == other.Rows() && cols_ == other.cols_;
|
||||
return Rows() == other.Rows() && Cols() == other.Cols();
|
||||
}
|
||||
void DebugCheckSameShape(const MatPtr& other) const {
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (!SameShape(other)) {
|
||||
HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
|
||||
Rows(), Cols(), other.Rows(), Cols());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Future calls to `Rows()` during this class' lifetime (not serialized)
|
||||
// will return this value. Used to set the actual number of rows for
|
||||
|
|
@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
|
|||
if (GetType() == Type::kUnknown) {
|
||||
SetType(TypeEnum<MatT>());
|
||||
} else {
|
||||
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
|
||||
if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
|
||||
HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
|
||||
TypeName<MatT>(), TypeName(other.GetType()));
|
||||
}
|
||||
}
|
||||
}
|
||||
MatPtrT& operator=(const MatPtr& other) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue