mirror of https://github.com/google/gemma.cpp.git
Back to f32 kv_cache, but via typedef
PiperOrigin-RevId: 785422614
This commit is contained in:
parent
56c9196eb6
commit
5474146129
|
|
@ -52,7 +52,7 @@ namespace HWY_NAMESPACE {
|
|||
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const float* HWY_RESTRICT q,
|
||||
const MatPtrT<BF16>& k, float* HWY_RESTRICT att,
|
||||
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
||||
const size_t worker) {
|
||||
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
|
|
@ -100,7 +100,7 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
|||
static HWY_INLINE void WeightedSumV(
|
||||
const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
|
||||
const MatPtrT<BF16>& v, float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
|
||||
// we supported non-transposed B.
|
||||
|
|
@ -125,7 +125,7 @@ static HWY_INLINE void WeightedSumV(
|
|||
// in place for RMSNorm.
|
||||
void SingleDotSoftmaxWeightedSum(
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos,
|
||||
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
|
||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
||||
float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
|
|
@ -218,9 +218,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
// this query and head.
|
||||
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
|
||||
const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
|
||||
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
|
||||
MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
|
||||
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
|
||||
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
|
||||
|
|
@ -259,7 +259,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
// Set up MatMul row pointers for writing to KV, which consists of
|
||||
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
|
||||
// because rows are computed modulo seq_len.
|
||||
MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
|
||||
MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
|
||||
layer.qkv_einsum_w2.Rows()));
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
|
|
@ -287,7 +287,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
|
||||
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||
BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||
layer_idx * cache_layer_size +
|
||||
head * qkv_dim * 2;
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ namespace gcpp {
|
|||
namespace NAMESPACE { \
|
||||
void SingleDotSoftmaxWeightedSum( \
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
||||
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
|
||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att, \
|
||||
float* HWY_RESTRICT att_out, size_t worker); \
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
using KV_t = float;
|
||||
|
||||
struct KVCache {
|
||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
|
||||
|
||||
|
|
@ -42,7 +44,7 @@ struct KVCache {
|
|||
MatStorageT<float> conv1d_cache;
|
||||
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
|
||||
|
||||
MatStorageT<BF16> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
|
||||
private:
|
||||
// For use by other ctor and Copy()
|
||||
|
|
|
|||
Loading…
Reference in New Issue