Back to f32 kv_cache, but via typedef

PiperOrigin-RevId: 785422614
This commit is contained in:
Jan Wassenberg 2025-07-21 07:04:55 -07:00 committed by Copybara-Service
parent 56c9196eb6
commit 5474146129
3 changed files with 11 additions and 9 deletions

View File

@ -52,7 +52,7 @@ namespace HWY_NAMESPACE {
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
@ -100,7 +100,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B.
@ -125,7 +125,7 @@ static HWY_INLINE void WeightedSumV(
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, const size_t worker) {
@ -218,9 +218,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// this query and head.
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
@ -259,7 +259,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
// because rows are computed modulo seq_len.
MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
@ -287,7 +287,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = qbatch.KV(qi).kv_cache;
BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;

View File

@ -30,7 +30,7 @@ namespace gcpp {
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, size_t worker); \

View File

@ -25,6 +25,8 @@
namespace gcpp {
using KV_t = float;
struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
@ -42,7 +44,7 @@ struct KVCache {
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
MatStorageT<BF16> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
private:
// For use by other ctor and Copy()