mirror of https://github.com/google/gemma.cpp.git
Added access to softmax attention internals to regular attention
PiperOrigin-RevId: 833353546
This commit is contained in:
parent 5a500872b8
commit 210ebab346
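In short: Softmax in ops gains an optional SMOptions output parameter, the regular (AttentionImpl::kOld) attention path in SingleDotSoftmaxWeightedSum / DotSoftmaxWeightedSum wires it up, and AttentionActivations grows two new batch_size x q_heads tensors, softmax_max and softmax_d, that capture the softmax internals. In the notation used in these notes (mine, not the commit's), with \ell_i the soft-capped attention logits of one (query, head) pair:

    m = \max_i \ell_i, \qquad d = \sum_i \exp(\ell_i - m), \qquad p_i = \frac{\exp(\ell_i - m)}{d}

softmax_max stores m and softmax_d stores d per query row and head; the probabilities p_i stay in the att activation as before.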
@@ -73,6 +73,10 @@ struct AttentionActivations {
         att_out(MatFactory("att_out", batch_size,
                            layer_config.heads * layer_config.qkv_dim,
                            allocator)),
+        softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
+                               allocator)),
+        softmax_d(
+            MatFactory("softmax_d", batch_size, layer_config.heads, allocator)),
         att_sums(
             MatFactory("att_sums", batch_size, config.model_dim, allocator)),
@@ -108,6 +112,8 @@ struct AttentionActivations {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);

     // `inv_timescale*` are not batched.
@@ -120,6 +126,8 @@ struct AttentionActivations {
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
+  MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
+  MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
   MatStorageT<BF16> att_sums;
@@ -145,6 +153,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
+    softmax_max = activations.softmax_max;
+    softmax_d = activations.softmax_d;
     att_sums = activations.att_sums;
     inv_timescale = activations.inv_timescale;
     inv_timescale_global = activations.inv_timescale_global;
@@ -157,6 +167,8 @@ struct AttentionActivationsPtrs {
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
+    softmax_max.OverrideRows(batch_size);
+    softmax_d.OverrideRows(batch_size);
     att_sums.OverrideRows(batch_size);
     // `inv_timescale*` are not batched.
   }
@@ -180,6 +192,14 @@ struct AttentionActivationsPtrs {
   // Attention output computed from att * V, size batch_size x (q_heads *
   // qkv_dim).
   MatPtrT<float> att_out;
+  // The maximum logit value encountered when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_max;
+  // The sum of scaled exponentials when computing att_out from att,
+  // size batch_size x q_heads. See OnlineSoftmaxState for details.
+  // WARNING: Only filled in for AttentionImpl::kOld.
+  MatPtrT<float> softmax_d;
   // Accumulation of attention outputs over heads, size batch_size x
   // model_dim.
   MatPtrT<BF16> att_sums;
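One reason a caller might want softmax_max and softmax_d together: they give the log of each head's softmax denominator, logsumexp = max + log(d), without touching the KV cache again. The helper below is a sketch of that idea, not part of the commit; it works on one row of the batch_size x q_heads tensors at a time (for example the pointer returned by softmax_max.Row(tq_idx)), so it makes no assumption about row stride or padding.

#include <cmath>
#include <cstddef>

// Hypothetical helper, not in gemma.cpp: given one row (num_heads entries) of
// softmax_max (m) and softmax_d (d = sum of exp(logit - m)), write
// logsumexp = m + log(d) per head. This equals log(sum_i exp(logit_i)), the
// log of that (query, head)'s attention softmax denominator.
void RowLogSumExp(const float* max_row, const float* d_row, size_t num_heads,
                  float* lse_out) {
  for (size_t h = 0; h < num_heads; ++h) {
    lse_out[h] = max_row[h] + std::log(d_row[h]);
  }
}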
@@ -123,7 +123,8 @@ void SingleDotSoftmaxWeightedSum(
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
     const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
+    float* HWY_RESTRICT att_out, const SMOptions& sm_options,
+    ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   // --seq_len must be large enough to avoid wraparound.
@@ -146,7 +147,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const Logits logits(att, last_pos + 1);
   MaybeLogitsSoftCap(att_cap, logits, ctx, worker);
-  Softmax(logits, ctx, worker, /*temperature=*/1.0f);
+  Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options);

   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                ctx, worker);
@@ -203,6 +204,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
     float* HWY_RESTRICT att_out =
         activations.att_out.Row(tq_idx) + head * qkv_dim;
+    SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head,
+                         .d_out = activations.softmax_d.Row(tq_idx) + head};

     // Make strided read-only views into the kv cache for
     // this query and head.
@@ -215,7 +218,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,

     SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
-                                att_out, ctx, worker);
+                                att_out, sm_options, ctx, worker);
   };

   {
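After DotSoftmaxWeightedSum returns, the value written through sm_options for query row tq_idx and head h ends up at softmax_max.Row(tq_idx)[h] and softmax_d.Row(tq_idx)[h], mirroring the pointers set up above. A sketch of reading them back through the same activations object follows; only the struct PerHeadSoftmaxState and the function name are hypothetical, while the members and Row() accessor come from this diff, and the code assumes the gemma.cpp header declaring AttentionActivationsPtrs is included and the gcpp namespace is in scope.

#include <cstddef>
#include <vector>
// plus the gemma.cpp header that declares AttentionActivationsPtrs

// Sketch, not part of the commit: gather the per-head softmax internals for
// one query row after attention has run. These are only filled in for
// AttentionImpl::kOld (see the header comment above).
struct PerHeadSoftmaxState {
  float max;  // maximum logit seen by Softmax
  float d;    // sum of exp(logit - max)
};

std::vector<PerHeadSoftmaxState> CollectSoftmaxState(
    AttentionActivationsPtrs& activations, size_t tq_idx, size_t num_heads) {
  std::vector<PerHeadSoftmaxState> out(num_heads);
  for (size_t h = 0; h < num_heads; ++h) {
    out[h].max = activations.softmax_max.Row(tq_idx)[h];
    out[h].d = activations.softmax_d.Row(tq_idx)[h];
  }
  return out;
}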
@@ -28,6 +28,7 @@
 #include <vector>

 #include "ops/matmul.h"
+#include "ops/ops.h"
 #include "util/allocator.h"
 #include "util/basics.h"  // TokenAndProb, RngStream
 #include "util/mat.h"
@@ -1125,9 +1126,25 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(

 // See below for a specialized version for top-1 sampling.
 // TODO: support bf16 logits using Decompress2.
+// Computes softmax probabilities for the given logits, normalizing in-place.
+// The calculation is numerically stable, using the max-subtraction trick to
+// compute exp(logits[i] - max(logits)) before normalizing by the sum.
+// If temperature is provided and not 1.0, each intermediate exp() result is
+// divided by temperature before normalization; however, this division by
+// temperature cancels out during the final normalization step, meaning
+// temperature currently has no effect on the output probabilities.
+// @param logits In-out: on input, contains logits; on output, overwritten with
+//   probabilities.
+// @param ctx Input: threading context for parallelism and profiling.
+// @param worker Input: worker thread index.
+// @param temperature Input: softmax temperature.
+// @param softmax_max_out Optional output: if not null, stores the max logit
+//   value.
+// @param softmax_d_out Optional output: if softmax_max is not null, this must
+//   not be null and stores the sum of exp(logit - max).
 static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
-                                 const size_t worker,
-                                 float temperature = 1.0f) {
+                                 const size_t worker, float temperature = 1.0f,
+                                 const SMOptions& sm_options = {}) {
   GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
   HWY_DASSERT(logits.size() != 0);

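The new documentation above notes that temperature currently has no effect on the output probabilities. A one-line check in my notation (m the max logit, T the temperature): dividing each exponential by T rescales numerator and denominator by the same factor,

    p_i = \frac{\exp(\ell_i - m)/T}{\sum_j \exp(\ell_j - m)/T} = \frac{\exp(\ell_i - m)}{\sum_j \exp(\ell_j - m)},

so T cancels in the normalization, exactly as the comment says.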
@@ -1171,6 +1188,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
   MulByConst(mul, logits.data(), logits.size());
+  if (sm_options.max_out) {
+    *sm_options.max_out = hn::GetLane(vmax);
+    *sm_options.d_out = sum_exp;
+  }
 }

 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
@@ -41,6 +41,11 @@ static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
   return inv_timescale;
 }

+struct SMOptions {
+  float* HWY_RESTRICT max_out = nullptr;
+  float* HWY_RESTRICT d_out = nullptr;
+};
+
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_
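SMOptions is the whole new surface on the ops side: two optional output pointers that Softmax fills when max_out is non-null. The standalone scalar sketch below mimics that contract outside of Highway, just to make the semantics concrete; it is not gemma.cpp's SIMD implementation, and the names SMOptionsDemo and ScalarSoftmax are hypothetical.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical scalar stand-in for gcpp::SMOptions, same contract:
// if max_out is set, d_out must be set too.
struct SMOptionsDemo {
  float* max_out = nullptr;
  float* d_out = nullptr;
};

// Scalar softmax sketch mirroring the documented behavior: in-place,
// max-subtraction for numerical stability, optional export of the internals.
void ScalarSoftmax(std::vector<float>& logits, const SMOptionsDemo& opts = {}) {
  assert(!logits.empty());
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum_exp = 0.0f;
  for (float& x : logits) {
    x = std::exp(x - max_logit);  // exp(logit - max)
    sum_exp += x;
  }
  for (float& x : logits) x /= sum_exp;  // normalize to probabilities
  if (opts.max_out) {
    *opts.max_out = max_logit;  // what the diff stores in softmax_max
    *opts.d_out = sum_exp;      // what the diff stores in softmax_d
  }
}

int main() {
  std::vector<float> logits = {1.0f, 2.0f, 3.0f};
  float m = 0.0f, d = 0.0f;
  ScalarSoftmax(logits, {.max_out = &m, .d_out = &d});
  // p_i * d == exp(logit_i - m), so (m, d) let a caller undo normalization.
  std::printf("m=%f d=%f p0=%f\n", m, d, logits[0]);
  return 0;
}

The TestSoftmaxState test added below performs essentially the same check, but against the real SIMD Softmax.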
@@ -346,6 +346,51 @@ void TestAllSoftmax() {
   hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
 }

+class TestSoftmaxState {
+ public:
+  template <class D>
+  void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+                  hwy::RandomState& rng) {
+    if (count == 0) return;  // *Softmax would assert
+    if (misalign_b == 0) return;
+    using T = hn::TFromD<D>;
+
+    hwy::AlignedFreeUniquePtr<T[]> px =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    hwy::AlignedFreeUniquePtr<T[]> pe =
+        hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+    HWY_ASSERT(px && pe);
+
+    T* x = px.get() + misalign_a;
+    T* initial_logits = pe.get() + misalign_a;
+
+    for (size_t i = 0; i < count; ++i) {
+      x[i] = Random<T>(rng);
+      initial_logits[i] = x[i];
+    }
+
+    float softmax_max;
+    float softmax_d;
+    Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f,
+            {.max_out = &softmax_max, .d_out = &softmax_d});
+
+    const float maxval =
+        *std::max_element(initial_logits, initial_logits + count);
+
+    float sum_exp = 0.0f;
+    for (size_t i = 0; i < count; ++i) {
+      sum_exp += std::exp(initial_logits[i] - maxval);
+    }
+
+    ASSERT_NEAR(softmax_max, maxval, 1e-6);
+    ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
+  }
+};
+
+void TestAllSoftmaxState() {
+  hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmaxState>>()(float());
+}
+
 template <size_t k>
 struct TestCreateDistribution {
   void operator()(hwy::RandomState& rng) {
@@ -769,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);