mirror of https://github.com/google/gemma.cpp.git
Adds simple-loop versions of missing batched functions.
PiperOrigin-RevId: 642189741
This commit is contained in:
parent c7f5e93136
commit c557ad23a8

gemma/gemma.cc (119 changed lines)
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -546,6 +546,37 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   }
 }
 
+// The below "batched" versions are just simple loops for now.
+template <size_t kBatchSize, typename WeightT, typename OutT>
+static void RMSNormBatched(size_t num_tokens, const float* activations,
+                           const WeightT* weights, OutT* out,
+                           const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNorm(activations + token_idx * model_dim, weights,
+            out + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize, typename WeightT, typename InOutT>
+static void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
+                                  InOutT* inout, const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize>
+static void AddFromBatched(size_t num_tokens, const float* other, float* x,
+                           const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
+            model_dim);
+  }
+}
+
 // Placeholder for internal test3, do not remove
 
 template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
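Note on the hunk above: all three wrappers assume the batch is stored row-major as num_tokens rows of model_dim floats and simply call the existing single-row primitive once per row. The standalone sketch below only illustrates that layout and call pattern; the scalar RMSNorm stub, main(), kModelDim and the sample values are made up for the example (the real implementations are the Highway-accelerated ones in gemma/ops.h).

// Standalone sketch only: scalar RMSNorm stub plus the same batched wrapper
// shape as above, showing the [num_tokens x model_dim] row-major layout.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Stub of the single-row primitive (the real one is in gemma/ops.h).
static void RMSNorm(const float* x, const float* weight, float* out,
                    size_t size) {
  constexpr float kEps = 1e-6f;
  float ss = 0.0f;
  for (size_t j = 0; j < size; ++j) ss += x[j] * x[j];
  ss = 1.0f / std::sqrt(ss / static_cast<float>(size) + kEps);
  for (size_t j = 0; j < size; ++j) {
    out[j] = (1.0f + weight[j]) * (ss * x[j]);  // note the 1.0f centering
  }
}

// Same pattern as the new RMSNormBatched: one call per token row.
template <size_t kBatchSize>
static void RMSNormBatched(size_t num_tokens, const float* activations,
                           const float* weights, float* out,
                           const size_t model_dim) {
  assert(num_tokens <= kBatchSize);
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    RMSNorm(activations + token_idx * model_dim, weights,
            out + token_idx * model_dim, model_dim);
  }
}

int main() {
  constexpr size_t kModelDim = 4;  // toy dimension for the example
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f,   // token 0
                          4.f, 3.f, 2.f, 1.f};  // token 1
  std::vector<float> w(kModelDim, 0.0f);        // zero scale -> plain RMS norm
  std::vector<float> out(x.size());
  RMSNormBatched<2>(/*num_tokens=*/2, x.data(), w.data(), out.data(),
                    kModelDim);
  for (float v : out) std::printf("%.3f ", v);
  std::printf("\n");
  return 0;
}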
@@ -580,12 +611,9 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
     size_t layer_of_type =
         NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
 
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      RMSNorm(activations.x.data() + token_idx * kModelDim,
-              layer_weights->pre_attention_norm_scale.data(),
-              activations.pre_att_rms_out.data() + token_idx * kModelDim,
-              kModelDim);
-    }
+    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                               layer_weights->pre_attention_norm_scale.data(),
+                               activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
       Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                             layer_weights, kv_cache, pool);
@@ -593,38 +621,29 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
       GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                                    layer_weights, kv_cache, pool);
     }
-
-    pool.Run(0, num_tokens, [&](const uint64_t token_idx,
-                                size_t /*thread*/) HWY_ATTR {
-      if (TConfig::kPostNormScale) {
-        RMSNormInplace(layer_weights->post_attention_norm_scale.data(),
-                       activations.att_post2.data() + token_idx * kModelDim,
-                       kModelDim);
-      }
-      AddFrom(activations.att_post2.data() + token_idx * kModelDim,
-              activations.x.data() + token_idx * kModelDim, kModelDim);
-      RMSNorm(activations.x.data() + token_idx * kModelDim,
-              layer_weights->pre_ffw_norm_scale.data(),
-              activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
-              kModelDim);
-    });
-    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      if (TConfig::kPostNormScale) {
-        RMSNormInplace(layer_weights->post_ffw_norm_scale.data(),
-                       activations.ffw_out.data() + token_idx * kModelDim,
-                       kModelDim);
-      }
-      AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
-              activations.x.data() + token_idx * kModelDim, kModelDim);
-    }
+    if (TConfig::kPostNormScale) {
+      RMSNormInplaceBatched<kBatchSize>(
+          num_tokens, layer_weights->post_attention_norm_scale.data(),
+          activations.att_post2.data(), kModelDim);
+    }
+    AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
+                               activations.x.data(), kModelDim);
+    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                               layer_weights->pre_ffw_norm_scale.data(),
+                               activations.bf_pre_ffw_rms_out.data(),
+                               kModelDim);
+    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
+    if (TConfig::kPostNormScale) {
+      RMSNormInplaceBatched<kBatchSize>(
+          num_tokens, layer_weights->post_ffw_norm_scale.data(),
+          activations.ffw_out.data(), kModelDim);
+    }
+    AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
+                               activations.x.data(), kModelDim);
   }  // foreach layer
 
-  pool.Run(
-      0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
-        RMSNormInplace(weights.final_norm_scale.data(),
-                       activations.x.data() + token_idx * kModelDim, kModelDim);
-      });
+  RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
+                                    activations.x.data(), kModelDim);
 }
 
 // n = 1 specialization
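One behavioral note on the hunk above: the removed Prefill code fanned the per-token norm and residual work out over the thread pool via pool.Run, while the new batched helpers (per the "just simple loops for now" comment) process tokens serially. If that per-token parallelism were ever wanted back, a hypothetical variant of one wrapper could reuse the same pool.Run pattern as the removed code. The sketch below is not part of this commit and assumes the same context as gemma.cc (hwy::ThreadPool, HWY_DASSERT, HWY_ATTR, and the RMSNorm overloads from gemma/ops.h).

// Hypothetical, NOT in this commit: a pool-parallel RMSNormBatched reusing the
// pool.Run(0, num_tokens, ...) pattern of the Prefill code removed above.
template <size_t kBatchSize, typename WeightT, typename OutT>
static void RMSNormBatchedParallel(size_t num_tokens, const float* activations,
                                   const WeightT* weights, OutT* out,
                                   const size_t model_dim,
                                   hwy::ThreadPool& pool) {
  HWY_DASSERT(num_tokens <= kBatchSize);
  pool.Run(0, num_tokens,
           [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
             RMSNorm(activations + token_idx * model_dim, weights,
                     out + token_idx * model_dim, model_dim);
           });
}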
@@ -654,9 +673,9 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
     const auto* layer_weights = weights.GetLayer(layer);
     size_t layer_of_type =
         NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-    RMSNorm(activations.x.data(),
-            layer_weights->pre_attention_norm_scale.data(),
-            activations.pre_att_rms_out.data(), kModelDim);
+    RMSNormBatched<1>(1, activations.x.data(),
+                      layer_weights->pre_attention_norm_scale.data(),
+                      activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
       Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
@@ -665,18 +684,22 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
                           kv_cache, pool);
     }
     if (TConfig::kPostNormScale) {
-      RMSNormInplace(layer_weights->post_attention_norm_scale.data(),
-                     activations.att_post2.data(), kModelDim);
+      RMSNormInplaceBatched<1>(1,
+                               layer_weights->post_attention_norm_scale.data(),
+                               activations.att_post2.data(), kModelDim);
     }
-    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
-    RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
-            activations.bf_pre_ffw_rms_out.data(), kModelDim);
+    AddFromBatched<1>(1, activations.att_post2.data(), activations.x.data(),
+                      kModelDim);
+    RMSNormBatched<1>(1, activations.x.data(),
+                      layer_weights->pre_ffw_norm_scale.data(),
+                      activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
     if (TConfig::kPostNormScale) {
-      RMSNormInplace(layer_weights->post_ffw_norm_scale.data(),
-                     activations.ffw_out.data(), kModelDim);
+      RMSNormInplaceBatched<1>(1, layer_weights->post_ffw_norm_scale.data(),
+                               activations.ffw_out.data(), kModelDim);
     }
-    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
+    AddFromBatched<1>(1, activations.ffw_out.data(), activations.x.data(),
+                      kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);
       (*layers_output)(pos, block_name, activations.x.data(), kModelDim);
@@ -685,8 +708,8 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
 
   // Placeholder for internal test4, do not remove
 
-  RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
-                 kModelDim);
+  RMSNormInplaceBatched<1>(1, weights.final_norm_scale.data(),
+                           activations.x.data(), kModelDim);
   if (layers_output != nullptr) {
     (*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
   }

gemma/ops.h (23 changed lines)
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -942,18 +942,20 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
   return hn::ReduceSum(d, hn::Add(sum0, sum1));
 }
 
+// float, float -> float; simple loop.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
     const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
     float* HWY_RESTRICT out, size_t size) {
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   float ss = SquaredL2(x, size);
-  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
+  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
   for (size_t j = 0; j < size; j++) {
     // Note 1.0f centering here
     out[j] = (1.0f + weight[j]) * (ss * x[j]);
   }
 }
 
+// x=f, w=bf16 -> out=f
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
     const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
     float* HWY_RESTRICT out, size_t size) {
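For reference, the scalar loop above (and the vectorized overloads further down) compute the RMSNorm variant with the "1.0f centering" the comment mentions; written out, with d = size and eps = 1e-6 as in kEps:

\mathrm{out}_j = (1 + w_j)\,\frac{x_j}{\sqrt{\tfrac{1}{d}\sum_{k=1}^{d} x_k^2 + \varepsilon}},
\qquad \varepsilon = 10^{-6}.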
@@ -984,11 +986,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
   }
 }
 
+// float -> float; simple loop.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
     const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   float ss = SquaredL2(inout, size);
-  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
+  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
   for (size_t j = 0; j < size; j++) {
     // Note 1.0f centering here
     inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@@ -1005,10 +1008,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
   using VF = hn::Vec<decltype(df32)>;
   const size_t N32 = hn::Lanes(df32);
 
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   const float ss = SquaredL2(inout, size);
   const VF vss =
-      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
+      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
 
   HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
   for (size_t i = 0; i < size; i += 2 * N32) {
@@ -1034,10 +1037,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
   using VF = hn::Vec<decltype(df32)>;
   const size_t N32 = hn::Lanes(df32);
 
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   const float ss = SquaredL2(x, size);
   const VF vss =
-      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
+      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
 
   HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
   for (size_t i = 0; i < size; i += 2 * N32) {
@@ -1062,10 +1065,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
   using VF = hn::Vec<decltype(df32)>;
   const size_t N32 = hn::Lanes(df32);
 
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   const float ss = SquaredL2(x, size);
   const VF vss =
-      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
+      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
 
   HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
   for (size_t i = 0; i < size; i += 2 * N32) {
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue