mirror of https://github.com/google/gemma.cpp.git

Change (old) attention behavior to disallow wraparound, enforced via assertion.
Shared kU64PerLine constant.

PiperOrigin-RevId: 828072451

This commit is contained in:
parent 3a63a12624
commit a344a70c59
@@ -159,6 +159,10 @@ struct AttentionActivationsPtrs {
   // `inv_timescale*` are not batched.
   }
 
+  size_t SeqLen() const {
+    return static_cast<size_t>(div_seq_len.GetDivisor());
+  }
+
   const ModelConfig& config;
   MatPtrT<float> q;
   MatPtrT<BF16> q_bf;
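The new SeqLen() accessor simply exposes the divisor's value, while div_seq_len itself remains available for fast division/modulo elsewhere. A minimal sketch of that pattern, assuming only the hwy::Divisor methods visible in this diff (GetDivisor, Remainder); the holder struct and the header path are illustrative assumptions, not the real code:

#include <cstddef>
#include <cstdint>
#include <cstdio>

#include "hwy/base.h"  // hwy::Divisor (header location is an assumption)

// Hypothetical holder mirroring the accessor added above.
struct SeqLenHolder {
  explicit SeqLenHolder(uint32_t seq_len) : div_seq_len(seq_len) {}

  // Same shape as the accessor added in this commit.
  size_t SeqLen() const { return static_cast<size_t>(div_seq_len.GetDivisor()); }

  hwy::Divisor div_seq_len;
};

int main() {
  SeqLenHolder h(4096);
  // GetDivisor() recovers the configured length; Remainder() is the fast modulo
  // that the removed wraparound code paths relied on.
  printf("seq_len=%zu, 5000 mod seq_len=%u\n", h.SeqLen(), h.div_seq_len.Remainder(5000));
  return 0;
}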
@@ -66,20 +66,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
                                  0);
 
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound.
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
-      att[pos] = score;
-    }
-  } else {
-    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score =
-          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
-      att[pos_modulo] = score;
-    }
-  }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
+    const float score =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
+    att[pos] = score;
+  }
 }
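The behavioral change above, as a minimal self-contained sketch: KV indices are no longer reduced modulo seq_len, so positions past seq_len are a programming error rather than silently wrapping. Plain float arrays and a naive dot product stand in for the real MatPtrT/Dot machinery; all names here are illustrative only.

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative stand-in for QDotK: scores q against each cached K row.
// `k` is row-major with `seq_len` rows of `qkv_dim` floats.
void NaiveQDotK(const float* q, const std::vector<float>& k, size_t qkv_dim,
                size_t seq_len, size_t start_pos, size_t last_pos,
                float* att) {
  // New behavior: --seq_len must be large enough; no wraparound.
  assert(last_pos < seq_len);
  for (size_t pos = start_pos; pos <= last_pos; ++pos) {
    // The old code path would have indexed k and att with `pos % seq_len` here.
    float score = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) score += q[d] * k[pos * qkv_dim + d];
    att[pos] = score;
  }
}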
@@ -114,25 +106,13 @@ static HWY_INLINE void WeightedSumV(
     const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
     const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx,
     const size_t worker) {
-  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
-    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
-    // we supported non-transposed B.
-    // TODO: 2..4x unroll
-    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
-                 worker);
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
-    }
-  } else {
-    {
-      const size_t pos_mod = div_seq_len.Remainder(start_pos);
-      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx,
-                   worker);
-    }
-    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
-      const size_t pos_mod = div_seq_len.Remainder(pos);
-      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
-    }
-  }
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()));
+  // TODO: replace with MatMul(att, v) after it supports non-transposed B.
+  MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx,
+               worker);
+  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+    MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
+  }
 }
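What MulByConstTo/MulByConstAndAdd compute here, as a plain-array sketch with illustrative names (not the gemma.cpp API): the attention output is the probability-weighted sum of the cached V rows, with the first row initializing the accumulator and the rest accumulating into it.

#include <cassert>
#include <cstddef>
#include <vector>

// att_out[d] = sum over pos in [start_pos, last_pos] of att[pos] * v[pos][d].
void NaiveWeightedSumV(const float* att, const std::vector<float>& v,
                       size_t v_cols, size_t seq_len, size_t start_pos,
                       size_t last_pos, float* att_out) {
  assert(last_pos < seq_len);  // no wraparound, as asserted above
  // First row initializes att_out (MulByConstTo); later rows accumulate
  // (MulByConstAndAdd).
  for (size_t d = 0; d < v_cols; ++d) {
    att_out[d] = att[start_pos] * v[start_pos * v_cols + d];
  }
  for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
    for (size_t d = 0; d < v_cols; ++d) {
      att_out[d] += att[pos] * v[pos * v_cols + d];
    }
  }
}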
@@ -146,9 +126,10 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // --seq_len must be large enough to avoid wraparound.
+  HWY_DASSERT(last_pos < activations.SeqLen());
   const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
 
   // Apply rope and scaling to Q.
   if (query_norm_scale.HasPtr()) {
     CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
@@ -163,8 +144,7 @@ void SingleDotSoftmaxWeightedSum(
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
-  const Logits logits(att, att_len);
+  const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
   Softmax(logits, ctx, worker, /*temperature=*/1.0f);
 
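With last_pos now asserted to be below seq_len, the logits span is simply last_pos + 1 and the former HWY_MIN clamp is dead code. A hedged sketch of the cap-then-softmax step over that span, in plain C++ with a tanh-based soft cap (a common formulation; this is not the gemma.cpp MaybeLogitsSoftCap/Softmax implementation):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Soft-caps (if att_cap > 0) and then softmaxes att[0..last_pos], in place.
void CapAndSoftmaxSketch(float* att, size_t last_pos, float att_cap) {
  const size_t len = last_pos + 1;  // valid because last_pos < seq_len is asserted
  if (att_cap > 0.0f) {
    for (size_t i = 0; i < len; ++i) att[i] = att_cap * std::tanh(att[i] / att_cap);
  }
  // Standard max-subtracted softmax for numerical stability.
  float max_val = att[0];
  for (size_t i = 1; i < len; ++i) max_val = std::max(max_val, att[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < len; ++i) {
    att[i] = std::exp(att[i] - max_val);
    sum += att[i];
  }
  for (size_t i = 0; i < len; ++i) att[i] /= sum;
}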
@@ -194,8 +174,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
 
   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t seq_len =
-      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const size_t seq_len = activations.SeqLen();
   // All layers should have the same number of heads.
   HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
 
@@ -284,8 +263,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
        ++interleaved_idx) {
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t cache_pos =
-        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
+
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
         qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
   }
@@ -304,8 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     const size_t interleaved_idx = task / kv_heads;
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
-    const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // --seq_len must be large enough to avoid wraparound.
+    HWY_DASSERT(cache_pos < activations.SeqLen());
     auto& kv_cache = qbatch.KV(qi).kv_cache;
     KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                             layer_idx * cache_layer_size +
@@ -325,7 +307,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     }
 
     PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
-                         pos, /*mul=*/1.0f);
+                         cache_pos, /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });
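Net effect in ComputeQKV: the KV-cache row index is now the absolute token position (asserted to fit within seq_len), and that same value feeds the positional encoding, so RoPE and cache addressing can no longer disagree after a wrap. A self-contained sketch of the addressing with hypothetical names (the real code uses MatPtrT::Row and per-layer offsets inside each cache row):

#include <cassert>
#include <cstddef>

// Returns a pointer to the KV slot for `pos` within `layer_idx`, assuming the
// cache holds `seq_len` rows of `row_size` floats and each layer occupies
// `cache_layer_size` floats within a row.
float* KVSlotSketch(float* kv_cache, size_t row_size, size_t seq_len,
                    size_t cache_layer_size, size_t layer_idx, size_t pos) {
  // New behavior: no Remainder(pos); --seq_len must cover the whole sequence.
  assert(pos < seq_len);
  return kv_cache + pos * row_size + layer_idx * cache_layer_size;
}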
@@ -716,7 +716,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       size_t last = pos;
       const size_t prefix_end = qbatch.PrefixEnd(qi);
       if (prefix_end > 0 && prefix_end - 1 > last) {
-        // last_pos in QDotK and WeightedSumV is inclusive.
+        // last_pos in `TileFlashAttention` is inclusive.
         last = prefix_end - 1;
       }
       last_pos[offset] = last;
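Only the comment changes here; the bound itself is unchanged. Restated as a standalone sketch (illustrative helper, not part of the API): each query may attend up to its own position, or to the end of a shared prefix if that extends further, and the result is an inclusive index.

#include <cstddef>

// Inclusive last attendable position for a query at `pos`, given an exclusive
// `prefix_end` (0 if there is no prefix region).
size_t InclusiveLastPos(size_t pos, size_t prefix_end) {
  size_t last = pos;
  if (prefix_end > 0 && prefix_end - 1 > last) {
    last = prefix_end - 1;
  }
  return last;
}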
@@ -33,6 +33,9 @@ namespace gcpp {
 // For hwy::BitSet4096. Note that KVs are extremely large for such batches.
 HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
 
+// Multiplier so a u64 occupies an entire cache line; avoids false sharing.
+HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);
+
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
 
 static inline const char* ToString(Tristate t) {
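Why a per-line multiplier: when several workers each write their own counter, packing the counters contiguously puts many of them in one cache line and they ping-pong between cores (false sharing). Striding indices by kU64PerLine gives each worker its own line. A minimal sketch of the indexing, with a plain constant standing in for HWY_ALIGNMENT:

#include <cstddef>
#include <cstdint>

// Stand-in for HWY_ALIGNMENT (typically the cache-line size in bytes).
constexpr size_t kLineBytes = 64;
constexpr size_t kU64PerLineSketch = kLineBytes / sizeof(uint64_t);  // 8 here

// counters[w * kU64PerLineSketch] is worker w's slot; the trailing u64 entries
// are padding, so two workers never write to the same cache line.
inline uint64_t& WorkerSlot(uint64_t* counters, size_t worker) {
  return counters[worker * kU64PerLineSketch];
}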
@@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
   const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
                                num_workers * 5, num_workers * 20};
 
-  // Count tasks executed to ensure workers aren't optimized out. One per
-  // cache line to avoid false sharing.
-  const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);
-  std::vector<size_t> counters(num_workers * kSizePerLine);
-  size_t prev_total = 0;  // avoids having to reset counters.
+  // Count tasks executed to ensure workers aren't optimized out.
+  std::vector<uint64_t> counters(num_workers * kU64PerLine);
+  uint64_t prev_total = 0;  // avoids having to reset counters.
 
   hwy::RandomState rng;
   for (size_t rep = 0; rep < 500; ++rep) {
@@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
     pool.Run(begin, end, [&](uint64_t task, size_t thread) {
       HWY_ASSERT(begin <= task && task < end);
       HWY_ASSERT(thread < num_workers);
-      counters[thread * kSizePerLine]++;
+      counters[thread * kU64PerLine]++;
     });
 
     // Reduce count and ensure it matches the expected number of tasks.
-    size_t total = 0;
+    uint64_t total = 0;
     for (size_t i = 0; i < num_workers; ++i) {
-      total += counters[i * kSizePerLine];
+      total += counters[i * kU64PerLine];
     }
     const size_t expected = end - begin;
     HWY_ASSERT(total == prev_total + expected);
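The test's invariant, restated as a hedged standalone sketch using only the hwy::ThreadPool calls visible in this diff (Run, NumWorkers): each task in [begin, end) runs exactly once, and per-worker counters strided by one cache line make that countable without false sharing. The header path is an assumption, plain assert stands in for HWY_ASSERT, and the local constant mirrors gcpp::kU64PerLine.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"  // path is an assumption

void CountTasksSketch(hwy::ThreadPool& pool, uint64_t begin, uint64_t end) {
  constexpr size_t kU64PerLineSketch = 64 / sizeof(uint64_t);  // ~ gcpp::kU64PerLine
  const size_t num_workers = pool.NumWorkers();
  std::vector<uint64_t> counters(num_workers * kU64PerLineSketch, 0);

  pool.Run(begin, end, [&](uint64_t task, size_t thread) {
    assert(begin <= task && task < end);
    counters[thread * kU64PerLineSketch]++;  // one counter per cache line
  });

  // Reduce and verify: every task ran exactly once across all workers.
  uint64_t total = 0;
  for (size_t i = 0; i < num_workers; ++i) total += counters[i * kU64PerLineSketch];
  assert(total == end - begin);
}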
@@ -202,8 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) {
   }
 }
 
-static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
-static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
+static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine];
 
 std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   // Governs duration of test; avoid timeout in debug builds.
@@ -217,7 +216,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
   const double t0 = hwy::platform::Now();
   for (size_t reps = 0; reps < 1200; ++reps) {
     pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) {
-      outputs[thread * kU64PerThread] = base + thread;
+      outputs[thread * kU64PerLine] = base + thread;
     });
     hwy::PreventElision(outputs[base]);
     if (pool.AutoTuneComplete()) break;
@@ -258,7 +257,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
     const uint64_t t0 = hwy::timer::Start();
     pool.Run(0, pool.NumWorkers(), kCaller,
              [&](uint64_t task, size_t thread) {
-               outputs[thread * kU64PerThread] = base + thread;
+               outputs[thread * kU64PerLine] = base + thread;
              });
     const uint64_t t1 = hwy::timer::Stop();
     times.push_back(t1 - t0);
@@ -268,7 +267,7 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
     const uint64_t t0 = hwy::timer::Start();
     pool.Run(0, pool.NumWorkers(), kCaller,
              [&](uint64_t task, size_t thread) {
-               outputs[thread * kU64PerThread] = base + thread;
+               outputs[thread * kU64PerLine] = base + thread;
              });
     const uint64_t t1 = hwy::timer::Start();
     times.push_back(t1 - t0);
@@ -315,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) {
 
     // Verify outputs to ensure the measured code is not a no-op.
     for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) {
-      HWY_ASSERT(outputs[lp * kU64PerThread] >= 1);
-      HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers());
-      for (size_t i = 1; i < kU64PerThread; ++i) {
-        HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0);
+      HWY_ASSERT(outputs[lp * kU64PerLine] >= 1);
+      HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers());
+      for (size_t i = 1; i < kU64PerLine; ++i) {
+        HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0);
       }
     }
   };