mirror of https://github.com/google/gemma.cpp.git
Split out common parts (embedder and transformer block) from Prefill() and Transformer() into separate functions.
PiperOrigin-RevId: 644455520
parent d7d9d14f0e
commit 0e612d9a20

gemma/gemma.cc (163 lines changed)
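For orientation: after this change, both Prefill() and Transformer() reduce to "embed every token, then run every layer". The skeleton below is condensed from the added lines in the diff (not a verbatim excerpt; kBatchSize and the weight/config types are deduced from the arguments, so the call sites carry no explicit template arguments):

  // Prefill: embed tokens in parallel across the pool, then apply each layer
  // to the whole batch.
  pool.Run(
      0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
        EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
      });
  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
    const auto* layer_weights = weights.GetLayer(layer);
    TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
                     kv_cache, pool);
  }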
@@ -602,70 +602,86 @@ static void AddFromBatched(size_t num_tokens, const float* other, float* x,
 // Placeholder for internal test3, do not remove
+
+template <size_t kBatchSize, typename WeightArrayT, class TConfig>
+HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
+                             const WeightArrayT& weights,
+                             Activations<TConfig, kBatchSize>& activations) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
+      EmbeddingScaling<TConfig>();
+  HWY_DASSERT(token >= 0);
+  HWY_DASSERT(token < TConfig::kVocabSize);
+  Decompress(weights.embedder_input_embedding, token * kModelDim,
+             activations.x.data() + token_idx * kModelDim, kModelDim);
+  MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
+             kModelDim);
+  if constexpr (TConfig::kAbsolutePE) {
+    AddAbsolutePositionalEmbeddings(
+        activations.x.data() + token_idx * kModelDim, kModelDim,
+        pos + token_idx);
+  };
+}
+
+template <size_t kBatchSize, typename LayerWeightArrayT, class TConfig>
+HWY_NOINLINE void TransformerLayer(
+    size_t num_tokens, size_t pos, size_t layer,
+    const LayerWeightArrayT* layer_weights,
+    Activations<TConfig, kBatchSize>& activations, KVCache& kv_cache,
+    hwy::ThreadPool& pool) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  auto type = TConfig::kLayerConfig[layer];
+  size_t layer_of_type =
+      NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
+  RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                             layer_weights->pre_attention_norm_scale.data(),
+                             activations.pre_att_rms_out.data(), kModelDim);
+  if (type == LayerAttentionType::kGemma) {
+    Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                          layer_weights, kv_cache, pool);
+  } else {
+    GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                                 layer_weights, kv_cache, pool);
+  }
+  if (TConfig::kPostNormScale) {
+    RMSNormInplaceBatched<kBatchSize>(
+        num_tokens, layer_weights->post_attention_norm_scale.data(),
+        activations.att_post2.data(), kModelDim);
+  }
+  AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
+                             activations.x.data(), kModelDim);
+  RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                             layer_weights->pre_ffw_norm_scale.data(),
+                             activations.bf_pre_ffw_rms_out.data(), kModelDim);
+  FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+  if (TConfig::kPostNormScale) {
+    RMSNormInplaceBatched<kBatchSize>(num_tokens,
+                                      layer_weights->post_ffw_norm_scale.data(),
+                                      activations.ffw_out.data(), kModelDim);
+  }
+  AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
+                             activations.x.data(), kModelDim);
+}
+
 template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
-      EmbeddingScaling<TConfig>();
-
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
-        const int token = tokens[token_idx];
-        HWY_DASSERT(token >= 0);
-        HWY_DASSERT(token < TConfig::kVocabSize);
-        Decompress(weights.embedder_input_embedding, token * kModelDim,
-                   activations.x.data() + token_idx * kModelDim, kModelDim);
-        MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
-                   kModelDim);
-        if constexpr (TConfig::kAbsolutePE) {
-          AddAbsolutePositionalEmbeddings(
-              activations.x.data() + token_idx * kModelDim, kModelDim, pos);
-        };
+        EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
       });
 
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
-    size_t layer_of_type =
-        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-
-    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
-                               layer_weights->pre_attention_norm_scale.data(),
-                               activations.pre_att_rms_out.data(), kModelDim);
-    if (type == LayerAttentionType::kGemma) {
-      Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
-                            layer_weights, kv_cache, pool);
-    } else {
-      GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
-                                   layer_weights, kv_cache, pool);
-    }
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched<kBatchSize>(
-          num_tokens, layer_weights->post_attention_norm_scale.data(),
-          activations.att_post2.data(), kModelDim);
-    }
-    AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
-                               activations.x.data(), kModelDim);
-    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
-                               layer_weights->pre_ffw_norm_scale.data(),
-                               activations.bf_pre_ffw_rms_out.data(),
-                               kModelDim);
-    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched<kBatchSize>(
-          num_tokens, layer_weights->post_ffw_norm_scale.data(),
-          activations.ffw_out.data(), kModelDim);
-    }
-    AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
-                               activations.x.data(), kModelDim);
-  }  // foreach layer
+    TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
+                     kv_cache, pool);
+  }
 
   RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
-                                    activations.x.data(), kModelDim);
+                                    activations.x.data(), TConfig::kModelDim);
 }
 
 // Compute the transformer for a batch of input tokens. During generation,
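A note on the new call sites above: EmbedToken and TransformerLayer are invoked without explicit template arguments because kBatchSize, the weight-array type, and TConfig are all deducible from the parameters, in particular from the Activations<TConfig, kBatchSize>& argument. A minimal standalone sketch of the same deduction pattern (toy types, not the gemma.cpp definitions):

#include <cstddef>
#include <cstdio>

// Toy stand-ins for the real config/activation types.
struct ToyConfig {
  static constexpr std::size_t kModelDim = 4;
};

template <typename TConfig, std::size_t kBatchSize>
struct Activations {
  float x[kBatchSize][TConfig::kModelDim];
};

// As with EmbedToken in the diff, kBatchSize and TConfig are deduced from the
// activations reference, keeping call sites free of template arguments.
template <std::size_t kBatchSize, typename TConfig>
void EmbedTokenLike(int token, Activations<TConfig, kBatchSize>& activations) {
  activations.x[0][0] = static_cast<float>(token);
  std::printf("deduced batch=%zu, dim=%zu\n", kBatchSize, TConfig::kModelDim);
}

int main() {
  Activations<ToyConfig, 2> activations{};
  EmbedTokenLike(42, activations);  // no explicit <2, ToyConfig> needed
  return 0;
}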
@@ -684,57 +700,14 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
     }
   }
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
-      EmbeddingScaling<TConfig>();
   for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    const int token = tokens[token_idx];
-    HWY_DASSERT(token >= 0);
-    HWY_DASSERT(token < TConfig::kVocabSize);
-    Decompress(weights.embedder_input_embedding, token * kModelDim,
-               activations.x.data() + token_idx * kModelDim, kModelDim);
-    MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
-               kModelDim);
-    if constexpr (TConfig::kAbsolutePE) {
-      AddAbsolutePositionalEmbeddings(
-          activations.x.data() + token_idx * kModelDim, kModelDim,
-          pos + token_idx);
-    };
+    EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
   }
 
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
-    size_t layer_of_type =
-        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
-                               layer_weights->pre_attention_norm_scale.data(),
-                               activations.pre_att_rms_out.data(), kModelDim);
-    if (type == LayerAttentionType::kGemma) {
-      Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
-                            layer_weights, kv_cache, pool);
-    } else {
-      GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
-                                   layer_weights, kv_cache, pool);
-    }
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched<kBatchSize>(
-          num_tokens, layer_weights->post_attention_norm_scale.data(),
-          activations.att_post2.data(), kModelDim);
-    }
-    AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
-                               activations.x.data(), kModelDim);
-    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
-                               layer_weights->pre_ffw_norm_scale.data(),
-                               activations.bf_pre_ffw_rms_out.data(),
-                               kModelDim);
-    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched<kBatchSize>(
-          num_tokens, layer_weights->post_ffw_norm_scale.data(),
-          activations.ffw_out.data(), kModelDim);
-    }
-    AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
-                               activations.x.data(), kModelDim);
+    TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
+                     kv_cache, pool);
     if (layers_output) {
       std::string block_name = "blocks." + std::to_string(layer);
       for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
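One behavioral detail visible in the diff: the removed Prefill code passed pos to AddAbsolutePositionalEmbeddings for every token in the batch, whereas the shared EmbedToken passes pos + token_idx, matching the removed Transformer code. With kAbsolutePE enabled, each prefilled token is therefore now embedded at its own absolute position. A minimal sketch of that indexing (AddPosition is a hypothetical stand-in, not a gemma.cpp function):

#include <cstddef>
#include <vector>

// Hypothetical stand-in for AddAbsolutePositionalEmbeddings: folds a token's
// absolute position into its embedding vector.
void AddPosition(float* x, std::size_t dim, std::size_t absolute_pos) {
  for (std::size_t d = 0; d < dim; ++d) {
    x[d] += static_cast<float>(absolute_pos);
  }
}

// Embeds a batch starting at absolute position `pos`: token i of the batch
// sits at position pos + i, which is what EmbedToken now applies in both
// Prefill and Transformer.
void AddPositionsToBatch(std::vector<float>& x, std::size_t dim,
                         std::size_t num_tokens, std::size_t pos) {
  for (std::size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    AddPosition(x.data() + token_idx * dim, dim, pos + token_idx);
  }
}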