mirror of https://github.com/google/gemma.cpp.git
Also update attention.h to type-erased query_norm_scale
PiperOrigin-RevId: 825014334
parent 3cc0139ebb
commit 4bd465ffd3
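The change below replaces the typed parameter const MatPtrT<float>& query_norm_scale with the type-erased base const MatPtr&, so the attention entry points no longer encode the scale tensor's element type in their signatures. The following is a minimal sketch of that pattern, assuming MatPtrT<T> derives from a type-erased MatPtr base as in gemma.cpp; the class bodies, the ApplyQueryNormScale helper, and the downcast shown here are illustrative stand-ins, not the library's actual definitions.

// Sketch only: hypothetical stand-ins illustrating the type-erasure pattern.
#include <cstddef>

class MatPtr {  // type-erased tensor handle: shape only, no element type
 public:
  virtual ~MatPtr() = default;
  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }
 protected:
  size_t rows_ = 0;
  size_t cols_ = 0;
};

template <typename T>
class MatPtrT : public MatPtr {  // typed view layered on the erased base
 public:
  const T* Row(size_t r) const { return data_ + r * cols_; }
 private:
  const T* data_ = nullptr;
};

// Before: void SingleDotSoftmaxWeightedSum(..., const MatPtrT<float>& query_norm_scale, ...);
// After:  the signature takes const MatPtr&, and a float view is recovered only
//         where the values are read. (Hypothetical helper; the real code applies
//         the scale via RMSNorm rather than the plain multiply shown here.)
void ApplyQueryNormScale(float* q, size_t qkv_dim, const MatPtr& query_norm_scale) {
  // Valid only if the underlying object really is a MatPtrT<float>; gemma.cpp
  // may recover the typed view differently.
  const auto& scale = static_cast<const MatPtrT<float>&>(query_norm_scale);
  for (size_t i = 0; i < qkv_dim; ++i) q[i] *= scale.Row(0)[i];
}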
@@ -130,7 +130,7 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const MatPtr& query_norm_scale, const size_t layer_idx,
     const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
@@ -169,7 +169,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const MatPtrT<float>& query_norm_scale,
+                           const MatPtr& query_norm_scale,
                            AttentionActivationsPtrs& activations,
                            QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
@@ -38,12 +38,12 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum(                                          \
       const size_t pos, const size_t start_pos, const size_t last_pos,       \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      const MatPtrT<float>& query_norm_scale, size_t layer_idx,              \
+      const MatPtr& query_norm_scale, size_t layer_idx,                      \
       const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,  \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker);    \
                                                                              \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,      \
-                             const MatPtrT<float>& query_norm_scale,         \
+                             const MatPtr& query_norm_scale,                 \
                              AttentionActivationsPtrs& activations,          \
                              QBatch& qbatch, ThreadingContext& ctx);         \
                                                                              \
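Call sites should not need to change for this signature update: a MatPtrT<float> argument binds to a const MatPtr& parameter through the ordinary derived-to-base conversion. A hypothetical caller, reusing the sketch types above:

void ExampleCaller(const MatPtrT<float>& w_query_norm, float* q, size_t qkv_dim) {
  // Implicit upcast MatPtrT<float>& -> MatPtr&; no copy, no extra template plumbing.
  ApplyQueryNormScale(q, qkv_dim, w_query_norm);
}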