mirror of https://github.com/google/gemma.cpp.git
parent 64d700cab5
commit 9689fc82f9
@@ -218,7 +218,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
     v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

-    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+    constexpr size_t offset = 0;  // placeholder, do not remove
+    SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
                                 att_out, sm_options, ctx, worker);
   };
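Editor's note: the hunk above points a V view into an interleaved per-position KV cache row (the V block sits qkv_dim elements after the K block for the same head) and forwards a strided view into the weighted-sum call. The following is a minimal, self-contained sketch of that strided-view idea only; the View2D type, the layout assumptions, and all names here are illustrative and are not gemma.cpp's MatPtrT API.

#include <cstddef>
#include <vector>

// A bare-bones 2D view: logical rows/cols over a wider, strided buffer.
struct View2D {
  const float* ptr;   // first element of the view
  size_t rows, cols;  // logical extents (seq_len x qkv_dim)
  size_t stride;      // elements between consecutive cache rows
  const float* Row(size_t r) const { return ptr + r * stride; }
};

int main() {
  const size_t seq_len = 4, qkv_dim = 8, kv_heads = 2;
  // Assumed layout per cache row: [K_h0 | V_h0 | K_h1 | V_h1], each qkv_dim wide.
  const size_t stride = kv_heads * 2 * qkv_dim;
  std::vector<float> kv_cache(seq_len * stride, 0.0f);

  const size_t head = 1;
  const size_t kv_head_offset = head * 2 * qkv_dim;
  // K view starts at the head's K block; V view starts qkv_dim elements later,
  // mirroring SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()).
  View2D k{kv_cache.data() + kv_head_offset, seq_len, qkv_dim, stride};
  View2D v{kv_cache.data() + kv_head_offset + qkv_dim, seq_len, qkv_dim, stride};
  (void)k.Row(0);
  (void)v.Row(0);
  return 0;
}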
@@ -313,8 +314,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
       });
     }

+    constexpr size_t offset = 0;  // placeholder, do not remove
     PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
-                         cache_pos, /*mul=*/1.0f);
+                         cache_pos + offset,
+                         /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });
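Editor's note: this hunk applies positional encoding to the freshly computed K/V at cache position cache_pos (plus the zero placeholder offset) before compressing it into the cache. Below is a minimal sketch of rotary positional encoding applied in place to one vector at a given position, assuming the conventional split-half RoPE pairing and base 10000; it is not the body of gemma.cpp's PositionalEncodingQK, whose exact convention may differ.

#include <cmath>
#include <cstddef>
#include <vector>

// Rotate pairs (x[i], x[i + dim/2]) by an angle that depends on `pos` and the
// pair index; `mul` scales the result, matching the /*mul=*/ parameter above.
void ApplyRope(float* x, size_t dim, size_t pos, float mul) {
  const size_t half = dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float freq = std::pow(10000.0f, -2.0f * static_cast<float>(i) / dim);
    const float theta = static_cast<float>(pos) * freq;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = mul * (x0 * c - x1 * s);
    x[i + half] = mul * (x0 * s + x1 * c);
  }
}

int main() {
  std::vector<float> k(64, 1.0f);   // stand-in for one qkv_dim-wide K vector
  const size_t cache_pos = 17;      // position in the KV cache
  ApplyRope(k.data(), k.size(), cache_pos, /*mul=*/1.0f);
  return 0;
}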
@@ -111,7 +111,9 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
       const size_t tq_idx = qbatch.Size() * batch_idx + qi;
       // Find the token position in the query and calculate
      // the range of cache positions to attend to.
-      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      constexpr size_t offset = 0;  // placeholder, do not remove
+      const size_t pos =
+          qbatch.Pos(qi) + batch_idx + offset;
       float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
       // Apply rope and scaling to Q.
       if (query_norm_scale.HasPtr()) {
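Editor's note: the hunk's comment says the token position is derived and then a range of cache positions to attend to is computed. The sketch below illustrates that bookkeeping only: each batched token gets pos = query_start + batch_idx, and the attended range is causal, optionally clipped to a sliding window. The window logic and all names are assumptions for illustration, not code copied from gemma.cpp.

#include <cstddef>
#include <cstdio>

int main() {
  const size_t query_start = 100;  // analogue of qbatch.Pos(qi)
  const size_t window = 64;        // hypothetical attention window size
  for (size_t batch_idx = 0; batch_idx < 3; ++batch_idx) {
    const size_t pos = query_start + batch_idx;
    const size_t last_pos = pos;  // causal: attend up to and including self
    const size_t start_pos = pos >= window ? pos - window + 1 : 0;  // window clip
    std::printf("pos=%zu attends to cache positions [%zu, %zu]\n",
                pos, start_pos, last_pos);
  }
  return 0;
}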