Fixes to activations and tensor params

PiperOrigin-RevId: 824820179
Phil Culliton authored 2025-10-27 21:19:09 -07:00; committed by Copybara-Service
parent 5a05857deb
commit 267dbe00cb
5 changed files with 14 additions and 10 deletions

@@ -203,6 +203,12 @@ struct Activations {
     ffw_out.OverrideRows(batch_size);
     attention_storage.SetBatchSize(batch_size);
+    attention.q = attention_storage.q;
+    attention.q_T = attention_storage.q_T;
+    attention.pre_att_rms_out = attention_storage.pre_att_rms_out;
+    attention.att = attention_storage.att;
+    attention.att_out = attention_storage.att_out;
+    attention.att_sums = attention_storage.att_sums;
   }
 
   const LayerConfig& layer_config;

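The six added assignments re-bind the non-owning views in attention to the tensors owned by attention_storage after the storage is resized for the new batch size; without them, the view struct would keep describing the old shapes. Below is a minimal, self-contained sketch of that storage-plus-views pattern; MatView, MatStorage, AttentionStorageSketch, and AttentionPtrsSketch are simplified stand-ins, not the real gemma.cpp classes.

#include <cstddef>
#include <vector>

// Simplified stand-in for a non-owning tensor view (MatPtrT<float> in spirit).
struct MatView {
  float* data = nullptr;
  size_t rows = 0, cols = 0;
  void OverrideRows(size_t r) { rows = r; }
};

// Simplified stand-in for the owning side: a buffer plus its view.
struct MatStorage {
  std::vector<float> buffer;
  MatView view;
  MatStorage(size_t rows, size_t cols) : buffer(rows * cols) {
    view = MatView{buffer.data(), rows, cols};
  }
};

// Owns the attention tensors; their logical row count tracks the batch size.
struct AttentionStorageSketch {
  MatStorage q{64, 256}, att{64, 512}, att_out{64, 256};
  void SetBatchSize(size_t batch_size) {
    q.view.OverrideRows(batch_size);
    att.view.OverrideRows(batch_size);
    att_out.view.OverrideRows(batch_size);
  }
};

// Non-owning views handed to the attention kernels.
struct AttentionPtrsSketch {
  MatView q, att, att_out;
};

struct ActivationsSketch {
  AttentionStorageSketch attention_storage;
  AttentionPtrsSketch attention;

  void SetBatchSize(size_t batch_size) {
    attention_storage.SetBatchSize(batch_size);
    // Re-bind the views after resizing, mirroring the assignments added above.
    attention.q = attention_storage.q.view;
    attention.att = attention_storage.att.view;
    attention.att_out = attention_storage.att_out.view;
  }
};
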
@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const MatPtrT<float>& query_norm_scale,
+                           const MatPtr& query_norm_scale,
                            AttentionActivationsPtrs& activations,
                            QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);

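Both functions widen query_norm_scale from the typed MatPtrT<float>& to the type-erased base MatPtr&. A typed view converts to a const reference to its base implicitly, so existing call sites keep working, while the signature no longer hard-codes the element type of the norm-scale tensor. A minimal sketch of that relationship, using simplified stand-in classes rather than the real gemma.cpp definitions:

#include <cstddef>

// Type-erased tensor view: shape plus an untyped pointer (stand-in for MatPtr).
struct MatPtrSketch {
  const void* data = nullptr;
  size_t rows = 0, cols = 0;
};

// Typed view layered on the erased base (stand-in for MatPtrT<T>).
template <typename T>
struct MatPtrTSketch : MatPtrSketch {
  const T* Row(size_t r) const {
    return static_cast<const T*>(data) + r * cols;
  }
};

// Taking the erased base keeps the signature independent of the stored
// element type; only shape/metadata (or a later typed dispatch) is needed.
void DotSoftmaxWeightedSumSketch(const MatPtrSketch& query_norm_scale) {
  (void)query_norm_scale.rows;
}

int main() {
  float scale_data[8] = {};
  MatPtrTSketch<float> scale;
  scale.data = scale_data;
  scale.rows = 1;
  scale.cols = 8;
  // The typed view binds to const MatPtrSketch& via the implicit
  // derived-to-base conversion; no wrapper or cast is needed at call sites.
  DotSoftmaxWeightedSumSketch(scale);
  return 0;
}
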
@@ -38,12 +38,12 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum( \
       const size_t pos, const size_t start_pos, const size_t last_pos, \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const MatPtr& query_norm_scale, size_t layer_idx, \
       const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
   \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
-                             const MatPtrT<float>& query_norm_scale, \
+                             const MatPtr& query_norm_scale, \
                              AttentionActivationsPtrs& activations, \
                              QBatch& qbatch, ThreadingContext& ctx); \
   \

@@ -91,7 +91,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
                                   MatPtrT<float>& q,
-                                  const MatPtrT<float>& query_norm_scale,
+                                  const MatPtr& query_norm_scale,
                                   const size_t layer_idx,
                                   const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {
@@ -592,8 +592,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
-                    const size_t layer_idx,
-                    const MatPtrT<float>& query_norm_scale,
+                    const size_t layer_idx, const MatPtr& query_norm_scale,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
                     ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);

@@ -30,7 +30,7 @@ namespace gcpp {
   namespace NAMESPACE { \
   void RMSNormAndPositionalEncoding( \
       size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const MatPtr& query_norm_scale, size_t layer_idx, \
       const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
   \
   void SingleFlashAttention(size_t start_pos, size_t last_pos, \
@@ -45,8 +45,7 @@ namespace gcpp {
                             size_t total_tasks, size_t target_parallelism); \
   \
   void FlashAttention(size_t num_tokens, size_t target_parallelism, \
-                      size_t layer_idx, \
-                      const MatPtrT<float>& query_norm_scale, \
+                      size_t layer_idx, const MatPtr& query_norm_scale, \
                       AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       ThreadingContext& ctx); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
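The two header changes above live inside declaration macros (note the trailing backslashes and the NAMESPACE parameter), which stamp out the same prototypes once per SIMD target namespace, so the parameter-type change has to be made in the macro body to keep every per-target declaration in sync with the definitions. A self-contained sketch of that pattern with hypothetical names (the real macro and target namespaces differ):

#include <cstddef>

struct MatPtrSketch { const void* data = nullptr; size_t rows = 0, cols = 0; };

// One macro body declares the attention entry points; instantiating it per
// target namespace keeps all per-target declarations identical.
#define DECLARE_ATTENTION_SKETCH(NAMESPACE)                          \
  namespace NAMESPACE {                                              \
  void FlashAttention(size_t num_tokens, size_t target_parallelism,  \
                      size_t layer_idx,                              \
                      const MatPtrSketch& query_norm_scale);         \
  } /* namespace NAMESPACE */

// Hypothetical target namespaces; the real build emits one per SIMD level.
DECLARE_ATTENTION_SKETCH(target_baseline)
DECLARE_ATTENTION_SKETCH(target_avx2)
DECLARE_ATTENTION_SKETCH(target_avx3)

int main() { return 0; }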