Use bf16-rounded sqrt for scaling embeddings to match Gemma

Thanks Daniel & Michael Han for pointing this out.
https://unsloth.ai/blog/gemma-bugs

PiperOrigin-RevId: 615250003
This commit is contained in:
Jan Wassenberg 2024-03-12 19:14:31 -07:00 committed by Copybara-Service
parent 0221956b2e
commit 5fa2eb1a86
1 changed files with 22 additions and 5 deletions

View File

@ -35,6 +35,7 @@
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdio.h>
@ -426,6 +427,25 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \
HWY_HAVE_SCALAR_BF16_OPERATORS
#define GEMMA_CONSTEXPR_SQRT constexpr
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
return __builtin_sqrt(x);
}
#else
#define GEMMA_CONSTEXPR_SQRT
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
template <typename TConfig, size_t kBatchSize>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const CompressedWeights<TConfig>& c_weights,
@ -434,8 +454,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
hwy::ThreadPool& inner_pool) {
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
static constexpr size_t kModelDim = TConfig::kModelDim;
static const float kEmbScaling =
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@ -490,12 +509,10 @@ void Transformer(int token, size_t pos,
static constexpr size_t kLayers = TConfig::kLayers;
static constexpr size_t kModelDim = TConfig::kModelDim;
static const float kEmbScaling =
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {