mirror of https://github.com/google/gemma.cpp.git
Code cleanup

- Simplify template arg list, enable deduction
- Add missing hn:: on Lanes
- Add 1.0f suffix
- Move RMSNormBatched into ops.h
- static constexpr -> constexpr
- Concrete type instead of LayerT, WeightArrayT
- Inline GetWeights
- Remove if (runtime_config.verbosity >= 2) guards around timing
- Merge AllocatePrefill and AllocateDecode
- Remove bf_ffw_hidden

PiperOrigin-RevId: 644931277
commit 48ebba8b7a
parent 658fb3e506
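
Most of the churn below is the reordering of template parameters plus the replacement of the generic LayerT / WeightArrayT parameters with the concrete CompressedLayer<TConfig> / CompressedWeights<TConfig> types; once every remaining template parameter appears in the argument list, call sites can drop the explicit <kBatchSize>. A minimal standalone sketch of that effect, using hypothetical ToyConfig / Attention stand-ins rather than the real gemma.cpp definitions:

#include <array>
#include <cstddef>

// Hypothetical stand-ins; the real TConfig/Activations/Attention live in gemma.cc.
struct ToyConfig {
  static constexpr size_t kModelDim = 8;
};

template <class TConfig, size_t kBatchSize>
struct Activations {
  std::array<float, kBatchSize * TConfig::kModelDim> x;
};

// Every template parameter now appears in the argument list, so both TConfig
// and kBatchSize are deduced; callers no longer write Attention<kBatchSize>(...).
template <class TConfig, size_t kBatchSize>
void Attention(Activations<TConfig, kBatchSize>& activations) {
  constexpr size_t kModelDim = TConfig::kModelDim;  // constexpr suffices, no static needed
  activations.x[0] = static_cast<float>(kModelDim);
}

int main() {
  Activations<ToyConfig, 2> activations{};
  Attention(activations);  // deduces TConfig = ToyConfig, kBatchSize = 2
  return 0;
}
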
gemma/gemma.cc | 175

@@ -78,6 +78,7 @@ struct Activations {
   static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

   std::array<float, kBatchSize * kModelDim> x;  // input

   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
   std::array<float, kBatchSize * kHeads * kQStride> q;  // query vector
   std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
@@ -87,17 +88,13 @@ struct Activations {
       att_post1;  // attention output after linear transformation, per head
   std::array<float, kBatchSize * kModelDim>
       att_post2;  // accumulation of attention outputs over heads

   std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
   std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
-  // For FFW MatMul.
-  std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
+  std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;  // MatMul output
   std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;

-  // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
-  // std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
-  //     bf_ffw_hidden;
   std::array<float, kBatchSize * kModelDim> ffw_out;

   std::array<float, kBatchSize * TConfig::kVocabSize> logits;

   // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
@@ -234,19 +231,19 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace {

-template <size_t kBatchSize, typename LayerT, class TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void GriffinRecurrent(
     size_t batch_start, size_t num_tokens, size_t layer,
-    Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
-    KVCache& kv_cache, hwy::ThreadPool& pool) {
+    Activations<TConfig, kBatchSize>& activations,
+    const CompressedLayer<TConfig>* layer_weights, KVCache& kv_cache,
+    hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Griffin");
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   HWY_DASSERT(num_tokens <= kBatchSize);
-  static constexpr size_t kModelDim =
-      Activations<TConfig, kBatchSize>::kModelDim;
-  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
-  static constexpr size_t kHeads = TConfig::kHeads;
+  constexpr size_t kModelDim = Activations<TConfig, kBatchSize>::kModelDim;
+  constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
+  constexpr size_t kHeads = TConfig::kHeads;

   // X / Y linear layers.
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
@@ -268,7 +265,7 @@ HWY_NOINLINE void GriffinRecurrent(
     const size_t pos = batch_start + batch_idx;
     float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
     HWY_FULL(float) df;
-    HWY_DASSERT(kModelDim % Lanes(df) == 0);
+    HWY_DASSERT(kModelDim % hn::Lanes(df) == 0);
     const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);

     // cache[i] = input at time t-i.
@@ -279,7 +276,7 @@ HWY_NOINLINE void GriffinRecurrent(
           kv_cache.conv1d_cache.get() + layer_offset +
           ((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
     }
-    for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
+    for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) {
       auto xv = hn::Load(df, x + i);
       auto accum0 =
           hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
@@ -332,15 +329,15 @@ HWY_NOINLINE void GriffinRecurrent(
           fn_mul);
     // RNN scan
     HWY_FULL(float) df;
-    HWY_DASSERT(kHeadDim % Lanes(df) == 0);
-    for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
+    HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
+    for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
       auto log_a = hn::Load(df, a + head_offset + i);
       auto gated_x = hn::Load(df, x + head_offset + i);
       auto rnn = hn::Load(df, rnn_state + head_offset + i);
       auto a = hn::Exp(df, log_a);
-      auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
+      auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
       if (pos == 0) {
-        x_multiplier = hn::Set(df, 1.0);
+        x_multiplier = hn::Set(df, 1.0f);
       }
       auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
       hn::Store(new_x, df, rnn_state + head_offset + i);
@@ -365,11 +362,11 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 }

-template <size_t kBatchSize, typename LayerT, class TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             Activations<TConfig, kBatchSize>& activations,
-                            const LayerT* layer_weights, KVCache& kv_cache,
-                            hwy::ThreadPool& pool) {
+                            const CompressedLayer<TConfig>* layer_weights,
+                            KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   HWY_DASSERT(num_tokens <= kBatchSize);
   using TActivations = Activations<TConfig, kBatchSize>;
@@ -429,7 +426,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,

   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
-  static constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  constexpr size_t kGroupHeads = kHeads / kKVHeads;
   pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
     const size_t batch_idx = task / kHeads;
@@ -494,13 +491,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
   }
 }

-template <size_t kBatchSize, typename LayerT, typename TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
-                      size_t num_tokens, const LayerT* layer_weights,
+                      size_t num_tokens,
+                      const CompressedLayer<TConfig>* layer_weights,
                       hwy::ThreadPool& pool) {
   HWY_DASSERT(num_tokens <= kBatchSize);
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  constexpr size_t kModelDim = TConfig::kModelDim;
+  constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();

   // TODO: MatMul does not yet support adding another matrix to the result.
@@ -570,42 +568,11 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   }
 }

-// The below "batched" versions are just simple loops for now.
-template <size_t kBatchSize, typename WeightT, typename OutT>
-static void RMSNormBatched(size_t num_tokens, const float* activations,
-                           const WeightT* weights, OutT* out,
-                           const size_t model_dim) {
-  HWY_DASSERT(num_tokens <= kBatchSize);
-  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    RMSNorm(activations + token_idx * model_dim, weights,
-            out + token_idx * model_dim, model_dim);
-  }
-}
-
-template <size_t kBatchSize, typename WeightT, typename InOutT>
-static void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
-                                  InOutT* inout, const size_t model_dim) {
-  HWY_DASSERT(num_tokens <= kBatchSize);
-  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
-  }
-}
-
-template <size_t kBatchSize>
-static void AddFromBatched(size_t num_tokens, const float* other, float* x,
-                           const size_t model_dim) {
-  HWY_DASSERT(num_tokens <= kBatchSize);
-  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
-            model_dim);
-  }
-}
-
-template <size_t kBatchSize, typename WeightArrayT, class TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
-                             const WeightArrayT& weights,
+                             const CompressedWeights<TConfig>& weights,
                              Activations<TConfig, kBatchSize>& activations) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
+  constexpr size_t kModelDim = TConfig::kModelDim;
   GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
       EmbeddingScaling<TConfig>();
   HWY_DASSERT(token >= 0);
@@ -621,13 +588,13 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
   };
 }

-template <size_t kBatchSize, typename LayerWeightArrayT, class TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void TransformerLayer(
     size_t num_tokens, size_t pos, size_t layer,
-    const LayerWeightArrayT* layer_weights,
+    const CompressedLayer<TConfig>* layer_weights,
     Activations<TConfig, kBatchSize>& activations, KVCache& kv_cache,
     hwy::ThreadPool& pool) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
+  constexpr size_t kModelDim = TConfig::kModelDim;
   auto type = TConfig::kLayerConfig[layer];
   size_t layer_of_type =
       NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
@@ -635,11 +602,11 @@ HWY_NOINLINE void TransformerLayer(
       layer_weights->pre_attention_norm_scale.data(),
       activations.pre_att_rms_out.data(), kModelDim);
   if (type == LayerAttentionType::kGemma) {
-    Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
-                          layer_weights, kv_cache, pool);
+    Attention(pos, num_tokens, layer_of_type, activations, layer_weights,
+              kv_cache, pool);
   } else {
-    GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
-                                 layer_weights, kv_cache, pool);
+    GriffinRecurrent(pos, num_tokens, layer_of_type, activations, layer_weights,
+                     kv_cache, pool);
   }
   if (TConfig::kPostNormScale) {
     RMSNormInplaceBatched<kBatchSize>(
@@ -651,7 +618,7 @@ HWY_NOINLINE void TransformerLayer(
   RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
                              layer_weights->pre_ffw_norm_scale.data(),
                              activations.bf_pre_ffw_rms_out.data(), kModelDim);
-  FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+  FFW(activations, num_tokens, layer_weights, pool);
   if (TConfig::kPostNormScale) {
     RMSNormInplaceBatched<kBatchSize>(num_tokens,
                                       layer_weights->post_ffw_norm_scale.data(),
@@ -661,9 +628,9 @@ HWY_NOINLINE void TransformerLayer(
                              activations.x.data(), kModelDim);
 }

-template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
-                          const WeightArrayT& weights,
+                          const CompressedWeights<TConfig>& weights,
                           Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
@@ -685,9 +652,9 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,

 // Compute the transformer for a batch of input tokens. During generation,
 // we usually have num_tokens == 1 (and also kBatchSize == 1).
-template <size_t kBatchSize, typename WeightArrayT, class TConfig>
+template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
-                              const WeightArrayT& weights,
+                              const CompressedWeights<TConfig>& weights,
                               Activations<TConfig, kBatchSize>& activations,
                               KVCache& kv_cache, hwy::ThreadPool& pool,
                               const LayersOutputFunc& layers_output) {
@@ -698,17 +665,18 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
       layers_output(pos + token_idx, "Tokens", &token_f, 1);
     }
   }
-  static constexpr size_t kModelDim = TConfig::kModelDim;
+  constexpr size_t kModelDim = TConfig::kModelDim;
   for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
     EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
   }

   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    const auto* layer_weights = weights.GetLayer(layer);
+    const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
     TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
                      kv_cache, pool);

     if (layers_output) {
-      std::string block_name = "blocks." + std::to_string(layer);
+      const std::string block_name = "blocks." + std::to_string(layer);
       for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
         layers_output(pos + token_idx, block_name,
                       activations.x.data() + token_idx * kModelDim, kModelDim);
@@ -754,11 +722,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
       fprintf(stderr, "%zu\n", prompt_size);
     }
   }
-}
-
-template <class TConfig>
-const CompressedWeights<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
-  return *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
+  HWY_ASSERT(prompt_size > 0);
 }

 template <class TConfig, size_t kBatchSize>
@@ -776,12 +741,13 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
                const RuntimeConfig& runtime_config,
                const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                hwy::ThreadPool& pool, TimingInfo& timing_info) {
-  const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
+  const CompressedWeights<TConfig>& weights =
+      *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
   auto& prefill_activations =
       GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
   auto& activations = GetActivations<TConfig, 1>(decode_u8);

-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
+  constexpr size_t kVocabSize = TConfig::kVocabSize;
   size_t prompt_size = prompt.size();
   size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
@@ -791,7 +757,6 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
             max_tokens);
     return;
   }
-  HWY_ASSERT(prompt_size > 0);

   // If no sample_func is provided, we use top-k sampling.
   const SampleFunc sample_token =
@@ -825,8 +790,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
     HWY_DASSERT(batch_size <= kPrefillBatchSize);
     HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
     const int* batch_tokens = prompt.data() + pos_offset;
-    Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
-                               prefill_activations, kv_cache, pool);
+    Prefill(batch_tokens, batch_size, pos, weights, prefill_activations,
+            kv_cache, pool);
     for (size_t idx = 0; idx < batch_size; ++idx) {
       if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
     }
@@ -834,11 +799,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
     pos_offset += batch_size;
   }

-  if (runtime_config.verbosity >= 2) {
-    const double prefill_end = hwy::platform::Now();
-    timing_info.prefill_tok_sec =
-        static_cast<double>(pos_offset) / (prefill_end - prefill_start);
-  }
+  timing_info.prefill_tok_sec =
+      static_cast<double>(pos_offset) / (hwy::platform::Now() - prefill_start);

   // Start generation.
   const double gen_start = hwy::platform::Now();
@@ -851,9 +813,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
-    Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
-                                  activations, kv_cache, pool,
-                                  runtime_config.layers_output);
+    Transformer(&token, kDecodeBatchSize, pos, weights, activations, kv_cache,
+                pool, runtime_config.layers_output);
     float token_logit = 0.0f;
     // The condition below is always true if we are doing Prefill above.
     // We keep it here for clarity so that the code is correct even if Prefill
@@ -885,11 +846,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
       break;
     }
   }
-  if (runtime_config.verbosity >= 2) {
-    const double gen_end = hwy::platform::Now();
-    timing_info.gen_tok_sec =
-        static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
-  }
+  timing_info.gen_tok_sec = static_cast<double>(pos_offset - pos_gen_start) /
+                            (hwy::platform::Now() - gen_start);
 }

 }  // namespace HWY_NAMESPACE
@@ -901,18 +859,13 @@ namespace gcpp {

 namespace {
 template <typename TConfig>
-struct AllocatePrefill {
-  ByteStorageT operator()() const {
-    return AllocateSizeof<Activations<TConfig, kPrefillBatchSize>>();
+struct AllocateState {
+  void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
+    prefill = AllocateSizeof<Activations<TConfig, kPrefillBatchSize>>();
+    decode = AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
   }
 };

-template <typename TConfig>
-struct AllocateDecode {
-  ByteStorageT operator()() const {
-    return AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
-  }
-};
 }  // namespace

 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -922,8 +875,8 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
       model_type_(model_type),
       weight_type_(weight_type) {
   weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
-  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
-  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
+  CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
+                                       decode_u8_);
 }

 Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
@@ -935,8 +888,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
   HWY_ASSERT(weight_type == Type::kF32);
   weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
       model_type, pool);
-  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
-  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
+  CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
+                                       decode_u8_);
 }

 Gemma::~Gemma() {
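
On the AllocatePrefill / AllocateDecode merge above: the two single-result functors become one functor that fills both byte buffers, so the per-model/per-weight dispatch only runs once. A rough standalone sketch of the pattern, where ToyConfig, DispatchForModel, and the simplified AllocateSizeof are hypothetical stand-ins (not the real gemma.cpp CallForModelAndWeight machinery):

#include <cstddef>
#include <memory>
#include <utility>

using ByteStorageT = std::unique_ptr<unsigned char[]>;

// Simplified stand-in for the real AllocateSizeof helper.
template <typename T>
ByteStorageT AllocateSizeof() {
  return std::make_unique<unsigned char[]>(sizeof(T));
}

constexpr size_t kPrefillBatchSize = 16;  // illustrative values only
constexpr size_t kDecodeBatchSize = 1;

struct ToyConfig { static constexpr size_t kModelDim = 8; };  // hypothetical config

template <typename TConfig, size_t kBatchSize>
struct Activations {
  float x[kBatchSize * TConfig::kModelDim];
};

// Merged functor: fills both buffers in one call; previously AllocatePrefill
// and AllocateDecode each required their own dispatch over model/weight type.
template <typename TConfig>
struct AllocateState {
  void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
    prefill = AllocateSizeof<Activations<TConfig, kPrefillBatchSize>>();
    decode = AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
  }
};

// Hypothetical stand-in for CallForModelAndWeight: resolves TConfig (here
// hard-wired to ToyConfig) and forwards the remaining arguments.
template <template <typename> class FuncT, typename... Args>
void DispatchForModel(Args&&... args) {
  FuncT<ToyConfig>()(std::forward<Args>(args)...);
}

int main() {
  ByteStorageT prefill_u8, decode_u8;
  DispatchForModel<AllocateState>(prefill_u8, decode_u8);  // one dispatch, two buffers
  return 0;
}
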
gemma/ops.h | 30

@@ -1629,6 +1629,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
       HWY_ATTR { return hn::Add(x, other); });
 }

+// Simple loops unless/until batch sizes are large enough to parallelize.
+template <size_t kBatchSize, typename WeightT, typename OutT>
+void RMSNormBatched(size_t num_tokens, const float* activations,
+                    const WeightT* weights, OutT* out, const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNorm(activations + token_idx * model_dim, weights,
+            out + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize, typename WeightT, typename InOutT>
+void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
+                           InOutT* inout, const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize>
+void AddFromBatched(size_t num_tokens, const float* other, float* x,
+                    const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
+            model_dim);
+  }
+}
+
 static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
                                float* HWY_RESTRICT x, const size_t size,
                                const size_t max_pos) {