diff --git a/gemma/activations.h b/gemma/activations.h
index 60c26c2..bebe902 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -73,6 +73,10 @@ struct AttentionActivations {
         att_out(MatFactory("att_out", batch_size,
                            layer_config.heads * layer_config.qkv_dim,
                            allocator)),
+        softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
+                               allocator)),
+        softmax_d(
+            MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
         att_sums(
             MatFactory("att_sums", batch_size, config.model_dim, allocator)),
@@ -108,6 +112,8 @@ struct AttentionActivations {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
 
     // `inv_timescale*` are not batched.
@@ -120,6 +126,8 @@ struct AttentionActivations {
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
+  MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
 
   // Accumulation of attention outputs over heads
   MatStorageT<float> att_sums;
@@ -145,6 +153,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
+    softmax_max = activations.softmax_max;
+    softmax_d = activations.softmax_d;
     att_sums = activations.att_sums;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
@@ -157,6 +167,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
   }
@@ -180,6 +192,14 @@ struct AttentionActivationsPtrs {
   // Attention output computed from att * V, size batch_size x (q_heads *
   // qkv_dim).
   MatPtrT<float> att_out;
+  // The maximum logit value encountered when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_max;
+  // The sum of scaled exponentials when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_d;
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<float> att_sums;
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 854a489..67542ae 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -123,7 +123,8 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
+    float* HWY_RESTRICT att_out, const SMOptions& sm_options,
+    ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
@@ -146,7 +147,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
-  Softmax(logits, ctx, worker, /*temperature=*/1.0f);
+  Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);
 
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                ctx, worker);
@@ -203,6 +204,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
     float* HWY_RESTRICT att_out =
         activations.att_out.Row(tq_idx) + head * qkv_dim;
+    SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head,
+                         .d_out = activations.softmax_d.Row(tq_idx) + head};
 
     // Make strided read-only views into the kv cache for
     // this query and head.
@@ -215,7 +218,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 
     SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
-                                att_out, ctx, worker);
+                                att_out, sm_options, ctx, worker);
   };
 
   {
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 6eac06f..3b41ff3 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -28,6 +28,7 @@
 #include
 
 #include "ops/matmul.h"
+#include "ops/ops.h"
 #include "util/allocator.h"
 #include "util/basics.h"  // TokenAndProb, RngStream
 #include "util/mat.h"
@@ -1125,9 +1126,25 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
 
 // See below for a specialized version for top-1 sampling.
 // TODO: support bf16 logits using Decompress2.
+// Computes softmax probabilities for the given logits, normalizing in-place.
+// The calculation is numerically stable, using the max-subtraction trick to
+// compute exp(logits[i] - max(logits)) before normalizing by the sum.
+// If temperature is provided and not 1.0, each intermediate exp() result is
+// divided by temperature before normalization; however, this division by
+// temperature cancels out during the final normalization step, meaning
+// temperature currently has no effect on the output probabilities.
+// @param logits In-out: on input, contains logits; on output, overwritten with
+// probabilities.
+// @param ctx Input: threading context for parallelism and profiling.
+// @param worker Input: worker thread index.
+// @param temperature Input: softmax temperature.
+// @param sm_options Optional outputs: if sm_options.max_out is not null, it
+// receives the maximum logit value, and sm_options.d_out must then also be
+// non-null; it receives the sum of exp(logit - max), i.e. the softmax
+// denominator before normalization.
 static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
-                                 const size_t worker,
-                                 float temperature = 1.0f) {
+                                 const size_t worker, float temperature = 1.0f,
+                                 const SMOptions& sm_options = {}) {
   GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
 
   HWY_DASSERT(logits.size() != 0);
@@ -1171,6 +1188,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
   MulByConst(mul, logits.data(), logits.size());
+  if (sm_options.max_out) {
+    *sm_options.max_out = hn::GetLane(vmax);
+    *sm_options.d_out = sum_exp;
+  }
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
diff --git a/ops/ops.h b/ops/ops.h
index 03b023b..002cb97 100644
--- a/ops/ops.h
+++ b/ops/ops.h
@@ -41,6 +41,11 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
   return inv_timescale;
 }
 
+struct SMOptions {
+  float* HWY_RESTRICT max_out = nullptr;
+  float* HWY_RESTRICT d_out = nullptr;
+};
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index 8fa3625..0f83df1 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -346,6 +346,51 @@ void TestAllSoftmax() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
 }
 
+class TestSoftmaxState {
+ public:
+  template <class D>
+  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+                  hwy::RandomState& rng) {
+    if (count == 0) return;  // *Softmax would assert
+    if (misalign_b == 0) return;
+    using T = hn::TFromD<D>;
+
+    hwy::AlignedFreeUniquePtr<T[]> px =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    hwy::AlignedFreeUniquePtr<T[]> pe =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    HWY_ASSERT(px && pe);
+
+    T* x = px.get() + misalign_a;
+    T* initial_logits = pe.get() + misalign_a;
+
+    for (size_t i = 0; i < count; ++i) {
+      x[i] = Random<T>(rng);
+      initial_logits[i] = x[i];
+    }
+
+    float softmax_max;
+    float softmax_d;
+    Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
+            {.max_out = &softmax_max, .d_out = &softmax_d});
+
+    const float maxval =
+        *std::max_element(initial_logits, initial_logits + count);
+
+    float sum_exp = 0.0f;
+    for (size_t i = 0; i < count; ++i) {
+      sum_exp += std::exp(initial_logits[i] - maxval);
+    }
+
+    ASSERT_NEAR(softmax_max, maxval, 1e-6);
+    ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
+  }
+};
+
+void TestAllSoftmaxState() {
+  hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
+}
+
 template <size_t k>
 struct TestCreateDistribution {
   void operator()(hwy::RandomState& rng) {
@@ -769,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
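
Reviewer note (not part of the patch): the `(max, d)` pair that `Softmax` now reports through `SMOptions`, and that `AttentionActivations` stores per (token, head) in `softmax_max` / `softmax_d`, is the standard online-softmax state: `max` is the largest logit and `d` is the denominator `sum(exp(logit - max))`. The sketch below is a minimal standalone illustration of why these two scalars suffice to merge softmax results computed over disjoint chunks of logits; the names `PartialState`, `ChunkState`, and `MergeStates` are invented for this example and do not exist in the repository.

```cpp
// Standalone sketch (C++17) of the per-row softmax state exposed by this patch.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical types for illustration only.
struct PartialState {
  float max;  // maximum logit in the chunk
  float d;    // sum of exp(logit - max) over the chunk
};

// Mirrors what Softmax writes to SMOptions::max_out / d_out for one chunk.
PartialState ChunkState(const std::vector<float>& logits) {
  const float m = *std::max_element(logits.begin(), logits.end());
  float d = 0.0f;
  for (const float x : logits) d += std::exp(x - m);
  return {m, d};
}

// Combines two chunk states into the state of the concatenated chunk by
// rescaling each denominator to the shared maximum.
PartialState MergeStates(const PartialState& a, const PartialState& b) {
  const float m = std::max(a.max, b.max);
  return {m, a.d * std::exp(a.max - m) + b.d * std::exp(b.max - m)};
}

int main() {
  const std::vector<float> logits = {0.5f, -1.0f, 2.0f, 0.25f, 1.5f, -0.5f};
  const std::vector<float> first(logits.begin(), logits.begin() + 3);
  const std::vector<float> second(logits.begin() + 3, logits.end());

  const PartialState merged =
      MergeStates(ChunkState(first), ChunkState(second));
  const PartialState full = ChunkState(logits);  // reference over all logits

  // Both lines print the same (max, d), up to floating-point rounding.
  std::printf("merged: max=%.6f d=%.6f\n", merged.max, merged.d);
  std::printf("full:   max=%.6f d=%.6f\n", full.max, full.d);
  return 0;
}
```

Merging with `MergeStates` reproduces the state of the concatenated input, which is the property chunked or cached attention implementations rely on; keeping only `(max, d)` per head keeps that state O(1) regardless of sequence length.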