Added access to softmax attention internals to regular attention

PiperOrigin-RevId: 833353546
Martin Stolle 2025-11-17 08:35:09 -08:00 committed by Copybara-Service
parent 5a500872b8
commit 210ebab346
5 changed files with 100 additions and 5 deletions

View File

@@ -73,6 +73,10 @@ struct AttentionActivations {
         att_out(MatFactory("att_out", batch_size,
                            layer_config.heads * layer_config.qkv_dim,
                            allocator)),
+        softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
+                               allocator)),
+        softmax_d(
+            MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
         att_sums(
             MatFactory("att_sums", batch_size, config.model_dim, allocator)),
@@ -108,6 +112,8 @@ struct AttentionActivations {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
@@ -120,6 +126,8 @@ struct AttentionActivations {
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
+  MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;
@@ -145,6 +153,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
+    softmax_max = activations.softmax_max;
+    softmax_d = activations.softmax_d;
     att_sums = activations.att_sums;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
@@ -157,6 +167,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
   }
@@ -180,6 +192,14 @@ struct AttentionActivationsPtrs {
   // Attention output computed from att * V, size batch_size x (q_heads *
   // qkv_dim).
   MatPtrT<float> att_out;
+  // The maximum logit value encountered when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_max;
+  // The sum of scaled exponentials when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_d;
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<BF16> att_sums;
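For orientation: each of the two new activations stores one value per query row and head, namely the two scalars that define a numerically stable softmax over that head's attention logits. A minimal standalone sketch of those quantities (plain C++, illustrative only; not part of this change):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative: for one (query row, head), softmax_max corresponds to m and
// softmax_d to d below, given that head's attention logits.
int main() {
  std::vector<float> logits = {0.5f, 2.0f, -1.0f, 0.0f};
  const float m = *std::max_element(logits.begin(), logits.end());
  float d = 0.0f;
  for (float x : logits) d += std::exp(x - m);
  // After Softmax(), each slot holds the probability exp(x - m) / d.
  for (float& x : logits) x = std::exp(x - m) / d;
  std::printf("m=%f d=%f p0=%f\n", m, d, logits[0]);
  return 0;
}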

View File

@@ -123,7 +123,8 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
+    float* HWY_RESTRICT att_out, const SMOptions& sm_options,
+    ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
@@ -146,7 +147,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
-  Softmax(logits, ctx, worker, /*temperature=*/1.0f);
+  Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                ctx, worker);
@@ -203,6 +204,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
     float* HWY_RESTRICT att_out =
         activations.att_out.Row(tq_idx) + head * qkv_dim;
+    SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head,
+                         .d_out = activations.softmax_d.Row(tq_idx) + head};
     // Make strided read-only views into the kv cache for
     // this query and head.
@@ -215,7 +218,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
-                                att_out, ctx, worker);
+                                att_out, sm_options, ctx, worker);
   };
  {
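Recording (max, d) per query and head is what makes it possible to combine attention results computed over disjoint ranges of the KV cache without a second pass, via the standard online-softmax merge (as in flash-attention-style implementations). This change only exposes the state; the sketch below illustrates that merge rule under that assumption, with all names hypothetical:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical partial result for one (query, head): softmax state (m, d) plus
// the already-normalized weighted sum of V over one range of positions.
struct Partial {
  float m;               // max logit within the range
  float d;               // sum of exp(logit - m) within the range
  std::vector<float> o;  // normalized output, length qkv_dim
};

// Online-softmax merge: yields what a single softmax over the union of both
// ranges would have produced.
Partial Merge(const Partial& a, const Partial& b) {
  const float m = std::max(a.m, b.m);
  const float wa = a.d * std::exp(a.m - m);
  const float wb = b.d * std::exp(b.m - m);
  Partial out{m, wa + wb, std::vector<float>(a.o.size())};
  for (size_t i = 0; i < a.o.size(); ++i) {
    out.o[i] = (wa * a.o[i] + wb * b.o[i]) / out.d;
  }
  return out;
}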

View File

@@ -28,6 +28,7 @@
 #include <vector>
 #include "ops/matmul.h"
+#include "ops/ops.h"
 #include "util/allocator.h"
 #include "util/basics.h"  // TokenAndProb, RngStream
 #include "util/mat.h"
@@ -1125,9 +1126,25 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
 // See below for a specialized version for top-1 sampling.
 // TODO: support bf16 logits using Decompress2.
+// Computes softmax probabilities for the given logits, normalizing in-place.
+// The calculation is numerically stable: the max-subtraction trick computes
+// exp(logits[i] - max(logits)) before normalizing by the sum.
+// If a temperature other than 1.0 is provided, each intermediate exp() result
+// is divided by the temperature before normalization; this division cancels
+// out in the final normalization step, so temperature currently has no effect
+// on the output probabilities.
+// @param logits In-out: contains logits on input; overwritten with
+//   probabilities on output.
+// @param ctx Input: threading context for parallelism and profiling.
+// @param worker Input: worker thread index.
+// @param temperature Input: softmax temperature.
+// @param sm_options Input: optional outputs. If max_out is not null, it
+//   receives the max logit value; d_out must then also be non-null and
+//   receives the sum of exp(logit - max).
 static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
-                                 const size_t worker,
-                                 float temperature = 1.0f) {
+                                 const size_t worker, float temperature = 1.0f,
+                                 const SMOptions& sm_options = {}) {
   GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
   HWY_DASSERT(logits.size() != 0);
@@ -1171,6 +1188,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
   MulByConst(mul, logits.data(), logits.size());
+  if (sm_options.max_out) {
+    *sm_options.max_out = hn::GetLane(vmax);
+    *sm_options.d_out = sum_exp;
+  }
 }
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
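A quick standalone check of the temperature remark in the new comment above: dividing every exp() term by the same constant cancels once the sum is normalized, so the probabilities are unchanged (plain C++, illustrative only):

#include <cassert>
#include <cmath>
#include <vector>

// Demonstrates why the temperature division described above has no effect on
// the normalized probabilities.
int main() {
  const std::vector<float> x = {1.0f, -0.5f, 3.0f};
  const float m = 3.0f;            // max of x
  const float temperature = 0.7f;  // any nonzero constant
  float sum_plain = 0.0f, sum_temp = 0.0f;
  for (float v : x) {
    sum_plain += std::exp(v - m);
    sum_temp += std::exp(v - m) / temperature;
  }
  for (float v : x) {
    const float p_plain = std::exp(v - m) / sum_plain;
    const float p_temp = (std::exp(v - m) / temperature) / sum_temp;
    assert(std::fabs(p_plain - p_temp) < 1e-6f);
  }
  return 0;
}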

View File

@@ -41,6 +41,11 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
   return inv_timescale;
 }
+struct SMOptions {
+  float* HWY_RESTRICT max_out = nullptr;
+  float* HWY_RESTRICT d_out = nullptr;
+};
+
 }  // namespace gcpp
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
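The intended call pattern for the new struct mirrors the call sites in this change: a default-constructed SMOptions (both pointers null) leaves Softmax's behavior unchanged, while a populated one captures the state. Sketch, assuming a Logits view logits, a ThreadingContext ctx and a worker index are in scope:

// Both pointers default to nullptr, in which case Softmax writes nothing extra.
float softmax_max = 0.0f;
float softmax_d = 0.0f;
SMOptions sm_options{.max_out = &softmax_max, .d_out = &softmax_d};
Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
// softmax_max now holds max(logits); softmax_d holds the sum of exp(logit - max).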

View File

@@ -346,6 +346,51 @@ void TestAllSoftmax() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
 }
+class TestSoftmaxState {
+ public:
+  template <class D>
+  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+                  hwy::RandomState& rng) {
+    if (count == 0) return;  // *Softmax would assert
+    if (misalign_b == 0) return;
+    using T = hn::TFromD<D>;
+
+    hwy::AlignedFreeUniquePtr<T[]> px =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    hwy::AlignedFreeUniquePtr<T[]> pe =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    HWY_ASSERT(px && pe);
+    T* x = px.get() + misalign_a;
+    T* initial_logits = pe.get() + misalign_a;
+
+    for (size_t i = 0; i < count; ++i) {
+      x[i] = Random<T>(rng);
+      initial_logits[i] = x[i];
+    }
+
+    float softmax_max;
+    float softmax_d;
+    Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
+            {.max_out = &softmax_max, .d_out = &softmax_d});
+
+    const float maxval =
+        *std::max_element(initial_logits, initial_logits + count);
+    float sum_exp = 0.0f;
+    for (size_t i = 0; i < count; ++i) {
+      sum_exp += std::exp(initial_logits[i] - maxval);
+    }
+    ASSERT_NEAR(softmax_max, maxval, 1e-6);
+    ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
+  }
+};
+
+void TestAllSoftmaxState() {
+  hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
+}
+
 template <size_t k>
 struct TestCreateDistribution {
   void operator()(hwy::RandomState& rng) {
@@ -769,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);