mirror of https://github.com/google/gemma.cpp.git
parent
64d700cab5
commit
9689fc82f9
|
|
@ -218,7 +218,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||||
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
|
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||||
|
|
||||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
|
constexpr size_t offset = 0; // placeholder, do not remove
|
||||||
|
SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v,
|
||||||
query_norm_scale, layer_idx, activations, att,
|
query_norm_scale, layer_idx, activations, att,
|
||||||
att_out, sm_options, ctx, worker);
|
att_out, sm_options, ctx, worker);
|
||||||
};
|
};
|
||||||
|
|
@ -313,8 +314,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr size_t offset = 0; // placeholder, do not remove
|
||||||
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
|
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
|
||||||
cache_pos, /*mul=*/1.0f);
|
cache_pos + offset,
|
||||||
|
/*mul=*/1.0f);
|
||||||
CompressPerThread tls;
|
CompressPerThread tls;
|
||||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,9 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||||
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
|
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
|
||||||
// Find the token position in the query and calculate
|
// Find the token position in the query and calculate
|
||||||
// the range of cache positions to attend to.
|
// the range of cache positions to attend to.
|
||||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
constexpr size_t offset = 0; // placeholder, do not remove
|
||||||
|
const size_t pos =
|
||||||
|
qbatch.Pos(qi) + batch_idx + offset;
|
||||||
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
if (query_norm_scale.HasPtr()) {
|
if (query_norm_scale.HasPtr()) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue