mirror of https://github.com/google/gemma.cpp.git
Extends Transformer() to prepare for batched processing.
PiperOrigin-RevId: 642603025
This commit is contained in:
parent
2a0e6ee976
commit
1ac9857014
147
gemma/gemma.cc
147
gemma/gemma.cc
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||||
// which we pass the filename via macro 'argument'.
|
// which we pass the filename via macro 'argument'.
|
||||||
|
#include <cstdio>
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
|
@ -592,16 +593,15 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
pool.Run(
|
pool.Run(
|
||||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||||
const int token = tokens[token_idx];
|
const int token = tokens[token_idx];
|
||||||
HWY_ASSERT(token >= 0);
|
HWY_DASSERT(token >= 0);
|
||||||
HWY_ASSERT(token < TConfig::kVocabSize);
|
HWY_DASSERT(token < TConfig::kVocabSize);
|
||||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
||||||
kModelDim);
|
kModelDim);
|
||||||
if constexpr (TConfig::kAbsolutePE) {
|
if constexpr (TConfig::kAbsolutePE) {
|
||||||
AddAbsolutePositionalEmbeddings(
|
AddAbsolutePositionalEmbeddings(
|
||||||
activations.x.data() + token_idx * kModelDim, TConfig::kModelDim,
|
activations.x.data() + token_idx * kModelDim, kModelDim, pos);
|
||||||
pos);
|
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -646,72 +646,92 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
activations.x.data(), kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
}
|
}
|
||||||
|
|
||||||
// n = 1 specialization
|
// Compute the transformer for a batch of input tokens. During generation,
|
||||||
template <typename WeightArrayT, class TConfig>
|
// we usually have num_tokens == 1 (and also kBatchSize == 1).
|
||||||
HWY_NOINLINE void Transformer(int token, size_t pos,
|
template <size_t kBatchSize, typename WeightArrayT, class TConfig>
|
||||||
|
HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
|
||||||
const WeightArrayT& weights,
|
const WeightArrayT& weights,
|
||||||
Activations<TConfig, 1>& activations,
|
Activations<TConfig, kBatchSize>& activations,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
LayersOutputT* layers_output) {
|
LayersOutputT* layers_output) {
|
||||||
|
HWY_ASSERT(num_tokens <= kBatchSize);
|
||||||
if (layers_output != nullptr) {
|
if (layers_output != nullptr) {
|
||||||
float token_f = token;
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
(*layers_output)(pos, "Tokens", &token_f, 1);
|
float token_f = tokens[token_idx];
|
||||||
|
(*layers_output)(pos + token_idx, "Tokens", &token_f, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
|
||||||
activations.x.data(), kModelDim);
|
|
||||||
|
|
||||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||||
EmbeddingScaling<TConfig>();
|
EmbeddingScaling<TConfig>();
|
||||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
if constexpr (TConfig::kAbsolutePE) {
|
const int token = tokens[token_idx];
|
||||||
AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim,
|
HWY_DASSERT(token >= 0);
|
||||||
pos);
|
HWY_DASSERT(token < TConfig::kVocabSize);
|
||||||
};
|
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||||
|
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||||
|
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
||||||
|
kModelDim);
|
||||||
|
if constexpr (TConfig::kAbsolutePE) {
|
||||||
|
AddAbsolutePositionalEmbeddings(
|
||||||
|
activations.x.data() + token_idx * kModelDim, kModelDim,
|
||||||
|
pos + token_idx);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
auto type = TConfig::kLayerConfig[layer];
|
auto type = TConfig::kLayerConfig[layer];
|
||||||
const auto* layer_weights = weights.GetLayer(layer);
|
const auto* layer_weights = weights.GetLayer(layer);
|
||||||
size_t layer_of_type =
|
size_t layer_of_type =
|
||||||
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
|
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
|
||||||
RMSNormBatched<1>(1, activations.x.data(),
|
RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
|
||||||
layer_weights->pre_attention_norm_scale.data(),
|
layer_weights->pre_attention_norm_scale.data(),
|
||||||
activations.pre_att_rms_out.data(), kModelDim);
|
activations.pre_att_rms_out.data(), kModelDim);
|
||||||
if (type == LayerAttentionType::kGemma) {
|
if (type == LayerAttentionType::kGemma) {
|
||||||
Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
|
Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
|
||||||
pool);
|
layer_weights, kv_cache, pool);
|
||||||
} else {
|
} else {
|
||||||
GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
|
GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
|
||||||
kv_cache, pool);
|
layer_weights, kv_cache, pool);
|
||||||
}
|
}
|
||||||
if (TConfig::kPostNormScale) {
|
if (TConfig::kPostNormScale) {
|
||||||
RMSNormInplaceBatched<1>(1,
|
RMSNormInplaceBatched<kBatchSize>(
|
||||||
layer_weights->post_attention_norm_scale.data(),
|
num_tokens, layer_weights->post_attention_norm_scale.data(),
|
||||||
activations.att_post2.data(), kModelDim);
|
activations.att_post2.data(), kModelDim);
|
||||||
}
|
}
|
||||||
AddFromBatched<1>(1, activations.att_post2.data(), activations.x.data(),
|
AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
|
||||||
kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
RMSNormBatched<1>(1, activations.x.data(),
|
RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
|
||||||
layer_weights->pre_ffw_norm_scale.data(),
|
layer_weights->pre_ffw_norm_scale.data(),
|
||||||
activations.bf_pre_ffw_rms_out.data(), kModelDim);
|
activations.bf_pre_ffw_rms_out.data(),
|
||||||
FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
|
kModelDim);
|
||||||
|
FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
|
||||||
if (TConfig::kPostNormScale) {
|
if (TConfig::kPostNormScale) {
|
||||||
RMSNormInplaceBatched<1>(1, layer_weights->post_ffw_norm_scale.data(),
|
RMSNormInplaceBatched<kBatchSize>(
|
||||||
activations.ffw_out.data(), kModelDim);
|
num_tokens, layer_weights->post_ffw_norm_scale.data(),
|
||||||
|
activations.ffw_out.data(), kModelDim);
|
||||||
}
|
}
|
||||||
AddFromBatched<1>(1, activations.ffw_out.data(), activations.x.data(),
|
AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
|
||||||
kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
if (layers_output != nullptr) {
|
if (layers_output != nullptr) {
|
||||||
std::string block_name = "blocks." + std::to_string(layer);
|
std::string block_name = "blocks." + std::to_string(layer);
|
||||||
(*layers_output)(pos, block_name, activations.x.data(), kModelDim);
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
(*layers_output)(pos + token_idx, block_name,
|
||||||
|
activations.x.data() + token_idx * kModelDim,
|
||||||
|
kModelDim);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Placeholder for internal test4, do not remove
|
// Placeholder for internal test4, do not remove
|
||||||
|
|
||||||
RMSNormInplaceBatched<1>(1, weights.final_norm_scale.data(),
|
RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
|
||||||
activations.x.data(), kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
if (layers_output != nullptr) {
|
if (layers_output != nullptr) {
|
||||||
(*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
(*layers_output)(pos + token_idx, "final_norm",
|
||||||
|
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -781,6 +801,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
}
|
}
|
||||||
HWY_ASSERT(prompt_size > 0);
|
HWY_ASSERT(prompt_size > 0);
|
||||||
|
|
||||||
|
// If no sample_func is provided, we use top-k sampling.
|
||||||
const SampleFunc sample_token =
|
const SampleFunc sample_token =
|
||||||
runtime_config.sample_func
|
runtime_config.sample_func
|
||||||
? runtime_config.sample_func
|
? runtime_config.sample_func
|
||||||
|
|
@ -827,36 +848,35 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start generation.
|
||||||
const double gen_start = hwy::platform::Now();
|
const double gen_start = hwy::platform::Now();
|
||||||
|
|
||||||
HWY_DASSERT(pos_offset == prompt_size - 1);
|
HWY_DASSERT(pos_offset == prompt_size - 1);
|
||||||
|
|
||||||
size_t pos_gen_start = pos_offset;
|
size_t pos_gen_start = pos_offset;
|
||||||
int token = prompt.at(pos_offset);
|
int token = prompt.at(pos_offset);
|
||||||
runtime_config.stream_token(token, 0);
|
// The loop below is not yet prepared for batch size > 1.
|
||||||
|
HWY_ASSERT(kDecodeBatchSize == 1);
|
||||||
|
if (!runtime_config.stream_token(token, 0.0f)) return;
|
||||||
for (size_t generate_pos = 0;
|
for (size_t generate_pos = 0;
|
||||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||||
++pos, ++pos_offset, ++generate_pos) {
|
++pos, ++pos_offset, ++generate_pos) {
|
||||||
const bool is_generating_phase = pos_offset >= prompt_size - 1;
|
Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
|
||||||
Transformer(token, pos, weights, activations, kv_cache, pool,
|
activations, kv_cache, pool, layers_output);
|
||||||
layers_output);
|
float token_logit = 0.0f;
|
||||||
float* final_activation = activations.x.data();
|
|
||||||
// The condition below is always true if we are doing Prefill above.
|
// The condition below is always true if we are doing Prefill above.
|
||||||
// We keep it here for clarity so that the code is correct even if Prefill
|
// We keep it here for clarity so that the code is correct even if Prefill
|
||||||
// is disabled.
|
// is disabled.
|
||||||
|
const bool is_generating_phase = pos_offset >= prompt_size - 1;
|
||||||
if (is_generating_phase) {
|
if (is_generating_phase) {
|
||||||
PROFILER_ZONE("Gen.Embedding");
|
PROFILER_ZONE("Gen.Embedding");
|
||||||
// Generation phase
|
// Compute logits from last layer activations.
|
||||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||||
weights.embedder_input_embedding, 0, final_activation,
|
weights.embedder_input_embedding, 0, activations.x.data(),
|
||||||
activations.even_odd.data(), activations.logits.data(), pool);
|
activations.even_odd.data(), activations.logits.data(), pool);
|
||||||
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
|
||||||
// Barrier: must have all logits so we can subtract max.
|
// Barrier: must have all logits so we can subtract max.
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
token = sample_token(activations.logits.data(), kVocabSize);
|
token = sample_token(activations.logits.data(), kVocabSize);
|
||||||
if (!runtime_config.stream_token(token, activations.logits[token])) {
|
token_logit = activations.logits[token];
|
||||||
token = runtime_config.eos_id;
|
|
||||||
}
|
|
||||||
if (generate_pos == 0) {
|
if (generate_pos == 0) {
|
||||||
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
|
||||||
}
|
}
|
||||||
|
|
@ -864,20 +884,19 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
|
||||||
// We would take this branch if we were not doing Prefill but would
|
// We would take this branch if we were not doing Prefill but would
|
||||||
// process the tokens of the prompt one at a time.
|
// process the tokens of the prompt one at a time.
|
||||||
token = prompt.at(pos_offset + 1);
|
token = prompt.at(pos_offset + 1);
|
||||||
if (!runtime_config.stream_token(token, 0)) {
|
}
|
||||||
token = runtime_config.eos_id;
|
if (!runtime_config.stream_token(token, token_logit)) {
|
||||||
}
|
token = runtime_config.eos_id;
|
||||||
}
|
}
|
||||||
if (token == runtime_config.eos_id) {
|
if (token == runtime_config.eos_id) {
|
||||||
if (runtime_config.verbosity >= 2) {
|
|
||||||
const double gen_end = hwy::platform::Now();
|
|
||||||
timing_info.gen_tok_sec =
|
|
||||||
static_cast<double>(pos_offset - pos_gen_start) /
|
|
||||||
(gen_end - gen_start);
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (runtime_config.verbosity >= 2) {
|
||||||
|
const double gen_end = hwy::platform::Now();
|
||||||
|
timing_info.gen_tok_sec =
|
||||||
|
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
|
|
@ -898,7 +917,7 @@ struct AllocatePrefill {
|
||||||
template <typename TConfig>
|
template <typename TConfig>
|
||||||
struct AllocateDecode {
|
struct AllocateDecode {
|
||||||
ByteStorageT operator()() const {
|
ByteStorageT operator()() const {
|
||||||
return AllocateSizeof<Activations<TConfig, 1>>();
|
return AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
constexpr size_t kPrefillBatchSize = 16;
|
constexpr size_t kPrefillBatchSize = 16;
|
||||||
|
constexpr size_t kDecodeBatchSize = 1;
|
||||||
constexpr bool kSystemPrompt = false;
|
constexpr bool kSystemPrompt = false;
|
||||||
|
|
||||||
struct KVCache {
|
struct KVCache {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue