mirror of https://github.com/google/gemma.cpp.git
Fix excessive KC/MC from prior change

This could lead to a stack overflow in B_storage. Also: do not require a
specific type for query_norm_scale, update batch sizes for attention tensors,
and make Mat shape/type checks more verbose.

PiperOrigin-RevId: 824987689
parent 5a05857deb
commit 3cc0139ebb
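Why an oversized KC is dangerous: the matmul packs a panel of B into a fixed-size, stack-allocated buffer, so a candidate `kc` beyond the compile-time cap writes past the end of the stack frame. A minimal sketch of the failure mode, with assumed names and sizes (not the actual gemma.cpp code):

```cpp
#include <cassert>
#include <cstddef>

constexpr size_t kNR = 4;        // B rows per tile (assumed)
constexpr size_t kMaxKC = 8192;  // cap that candidate kc must respect (assumed)

float PackB(size_t kc, const float* b_row) {
  assert(kc <= kMaxKC);  // without this cap, an oversized kc overflows below
  float B_storage[kNR * kMaxKC];  // stack-allocated panel, sized for kMaxKC
  for (size_t r = 0; r < kNR; ++r) {
    for (size_t c = 0; c < kc; ++c) B_storage[r * kMaxKC + c] = b_row[c];
  }
  return B_storage[0];
}
```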
@@ -100,6 +100,8 @@ struct AttentionActivations {
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
+
+    // `inv_timescale*` are not batched.
   }

   MatStorageT<float> q;  // query
@@ -137,6 +139,16 @@ struct AttentionActivationsPtrs {
     inv_timescale_global = activations.inv_timescale_global;
   }

+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+    // `inv_timescale*` are not batched.
+  }
+
   const ModelConfig& config;
   MatPtrT<float> q;
   MatPtrT<float> q_T;
@@ -203,6 +215,9 @@ struct Activations {
     ffw_out.OverrideRows(batch_size);

     attention_storage.SetBatchSize(batch_size);
+    // `AttentionActivationsPtrs` holds `MatPtrT` which also require updating;
+    // their row override is not updated when the underlying storage changes.
+    attention.SetBatchSize(batch_size);
   }

   const LayerConfig& layer_config;
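Why `Activations::SetBatchSize` must also call `attention.SetBatchSize`: a `MatPtrT` view caches its own row count, so overriding rows on the backing storage does not change views that were created earlier. A simplified stand-in illustrating that behavior (assumed semantics, not the real `MatPtr`):

```cpp
#include <cstddef>

struct MiniMat {  // stand-in for a MatPtr-like view
  size_t rows, cols;
  void OverrideRows(size_t r) { rows = r; }  // affects only this object
};

int main() {
  MiniMat storage{64, 2048};
  MiniMat view = storage;    // AttentionActivationsPtrs copies such views
  storage.OverrideRows(16);  // like attention_storage.SetBatchSize(16)
  // view.rows is still 64; it must be updated separately, which is what
  // the new AttentionActivationsPtrs::SetBatchSize does.
  view.OverrideRows(16);
  return view.rows == 16 ? 0 : 1;
}
```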
@@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
                                   MatPtrT<float>& q,
-                                  const MatPtrT<float>& query_norm_scale,
+                                  const MatPtr& query_norm_scale,
                                   const size_t layer_idx,
                                   const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {
@@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
-                    const size_t layer_idx,
-                    const MatPtrT<float>& query_norm_scale,
+                    const size_t layer_idx, const MatPtr& query_norm_scale,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
                     ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
@@ -30,7 +30,7 @@ namespace gcpp {
   namespace NAMESPACE { \
   void RMSNormAndPositionalEncoding( \
       size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const MatPtr& query_norm_scale, size_t layer_idx, \
       const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
   \
   void SingleFlashAttention(size_t start_pos, size_t last_pos, \
@@ -45,8 +45,7 @@ namespace gcpp {
       size_t total_tasks, size_t target_parallelism); \
   \
   void FlashAttention(size_t num_tokens, size_t target_parallelism, \
-                      size_t layer_idx, \
-                      const MatPtrT<float>& query_norm_scale, \
+                      size_t layer_idx, const MatPtr& query_norm_scale, \
                       AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       ThreadingContext& ctx); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
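Accepting a type-erased `const MatPtr&` instead of `MatPtrT<float>` lets callers pass `query_norm_scale` in whatever element type the checkpoint stores (e.g. BF16 rather than float), with dispatch to a typed view at the point of use. A hedged sketch of that pattern using the `CallUpcasted` helper that also appears in the `RMSNormBatched` hunk below; the surrounding function is hypothetical:

```cpp
// Hypothetical caller: CallUpcasted dispatches the runtime element type T of
// the type-erased MatPtr to a lambda receiving a typed const MatPtrT<T>*.
void ApplyNormScale(const MatPtr& query_norm_scale) {
  CallUpcasted(&query_norm_scale, [&](const auto* scale_t) {
    // scale_t->Row(0) yields the scale vector in its stored type;
    // upcast elements to float as needed before applying RMSNorm.
  });
}
```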
@@ -175,9 +175,14 @@ class GenerateCandidates {

   // The number of A and B columns to read between updating `C`.
   SizeVec KC(size_t mr, MMOrder order) const {
-    // Must return the actual value: although ignored by `RangesOfKC`, this will
-    // be used in MC() and NC().
-    if (IsOneKC(order)) return SizeVec(1, K_);
+    if (IsOneKC(order)) {
+      // A single KC range is infeasible when K exceeds the max. The caller
+      // will skip all configs with `order`.
+      if (K_ > kMaxKC) return SizeVec();
+      // Must return the actual value: although ignored by `RangesOfKC`, this
+      // will be used in MC() and NC().
+      return SizeVec(1, K_);
+    }
     // `LoopKC` handles up to `mr` rows of A.
     const size_t rows_a = HWY_MIN(max_M_, mr);
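The new early-out changes `KC`'s contract: an empty `SizeVec` now means the single-KC split is infeasible for this `order`, instead of returning a `kc` that exceeds `kMaxKC`. A sketch of the caller side under that contract (the loop structure and `kAllOrders` are assumed for illustration):

```cpp
// Hypothetical candidate-generation loop: an empty result from KC() rejects
// the order entirely instead of producing an oversized kc.
for (const MMOrder order : kAllOrders) {
  const SizeVec all_kc = KC(mr, order);
  if (all_kc.empty()) continue;  // infeasible: K_ > kMaxKC with one KC range
  for (const size_t kc : all_kc) {
    const SizeVec all_mc = MC(mr, kc, order);
    if (all_mc.empty()) continue;  // likewise for max_M_ > kMaxMC
    // ... emit candidate configs for (order, kc, mc) ...
  }
}
```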
@@ -227,13 +232,21 @@ class GenerateCandidates {

   // The number of (L2 resident) A rows for `A2C0` to loop over.
   SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
-    // Must return the actual value: although ignored by `RangesOfMC`, this will
-    // be used in NC().
-    if (IsOneMC(order) || max_M_ <= mr) return SizeVec(1, max_M_);
+    if (max_M_ <= mr) return SizeVec(1, max_M_);
+    if (IsOneMC(order)) {
+      // A single MC range is infeasible when M exceeds the max. The caller
+      // will skip all configs with `order`.
+      if (max_M_ > kMaxMC) return SizeVec();
+      // Must return the actual value: although ignored by `RangesOfMC`, this
+      // will be used in NC().
+      return SizeVec(1, max_M_);
+    }

     // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
     // it is typically inclusive.
     const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16));
+    // `kc` was chosen to fit in L1, hence this should not exceed L2.
+    HWY_ASSERT(bytes_b <= cache_.L2Bytes());

     // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
     // packed B. We want `mc * kc` elements of A to fit in L2, alongside
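A quick check of the "typically 12-24K" figure and the new assert, assuming `kNR` = 4 and the 1-byte `SfpStream` plus 2-byte `BF16` element sizes from the expression above:

```cpp
#include <cstddef>
#include <cstdio>
#include <initializer_list>

int main() {
  constexpr size_t kNR = 4;              // assumed B tile rows
  constexpr size_t kSfp = 1, kBf16 = 2;  // sizeof(SfpStream), sizeof(BF16)
  for (const size_t kc : {size_t{1024}, size_t{2048}}) {
    // Prints 12288 and 24576 bytes: well under a typical >= 256 KiB L2, so
    // HWY_ASSERT(bytes_b <= cache_.L2Bytes()) only fires for an excessive kc.
    printf("kc=%zu -> bytes_b=%zu\n", kc, kNR * kc * (kSfp + kBf16));
  }
  return 0;
}
```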
@@ -242,7 +255,7 @@ class GenerateCandidates {
     size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
     mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
     mc_max = HWY_MIN(mc_max, max_M_);
-    HWY_DASSERT(mc_max != 0);
+    HWY_ASSERT(mc_max != 0);

     SizeVec all_mc;
     all_mc.reserve(6);
@@ -497,7 +497,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
                     size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
-  HWY_DASSERT(activations.SameShape(out));
+  activations.DebugCheckSameShape(out);

   CallUpcasted(&weights, [&](const auto* weights_t) {
     ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
util/mat.h (15 changed lines)
@@ -181,7 +181,15 @@ class MatPtr : public IFields {
   Extents2D Extents() const { return Extents2D(Rows(), cols_); }
   bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
   bool SameShape(const MatPtr& other) const {
-    return Rows() == other.Rows() && cols_ == other.cols_;
+    return Rows() == other.Rows() && Cols() == other.Cols();
+  }
+  void DebugCheckSameShape(const MatPtr& other) const {
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (!SameShape(other)) {
+        HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
+                  Rows(), Cols(), other.Rows(), other.Cols());
+      }
+    }
   }
   // Future calls to `Rows()` during this class' lifetime (not serialized)
   // will return this value. Used to set the actual number of rows for
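`DebugCheckSameShape` replaces a bare `HWY_DASSERT` with a message that names the tensor and both shapes, while the `if constexpr` guard still compiles the check away outside debug builds. The same pattern in miniature, with stand-in types (`kDebugBuild` approximates `HWY_IS_DEBUG_BUILD`; not the library's code):

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>

#ifndef NDEBUG
constexpr bool kDebugBuild = true;
#else
constexpr bool kDebugBuild = false;
#endif

struct Shape { size_t rows, cols; };

void DebugCheckSameShape(const Shape& a, const Shape& b, const char* name) {
  if constexpr (kDebugBuild) {  // discarded entirely in release builds
    if (a.rows != b.rows || a.cols != b.cols) {
      fprintf(stderr, "%s: shape mismatch %zu x %zu vs %zu x %zu\n", name,
              a.rows, a.cols, b.rows, b.cols);
      abort();
    }
  }
}

int main() {
  DebugCheckSameShape({32, 2048}, {32, 2048}, "att_sums");  // passes
  return 0;
}
```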
@@ -299,7 +307,10 @@ class MatPtrT : public MatPtr {
     if (GetType() == Type::kUnknown) {
       SetType(TypeEnum<MatT>());
     } else {
-      HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
+      if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
+        HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
+                  TypeName<MatT>(), TypeName(other.GetType()));
+      }
     }
   }
   MatPtrT& operator=(const MatPtr& other) {
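The `MatPtrT` constructor change follows the same philosophy as `DebugCheckSameShape`: replace a bare `HWY_ASSERT` on the type enum with an abort that prints both type names. The pattern in miniature, with stand-in types (not the library's `Type`/`TypeName`):

```cpp
#include <cstdio>
#include <cstdlib>

enum class Type { kF32, kBF16, kSFP };

static const char* TypeName(Type t) {
  switch (t) {
    case Type::kF32: return "f32";
    case Type::kBF16: return "bf16";
    default: return "sfp";
  }
}

// Before: a bare assert(stored == expected) fails with no context.
// After: the abort message names both types, as in the MatPtrT change above.
void CheckType(Type stored, Type expected) {
  if (stored != expected) {
    fprintf(stderr, "Type mismatch: MatT %s, constructing from %s\n",
            TypeName(expected), TypeName(stored));
    abort();
  }
}

int main() {
  CheckType(Type::kSFP, Type::kSFP);  // ok; mismatched types would abort
  return 0;
}
```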