mirror of https://github.com/google/gemma.cpp.git
Use bf16-rounded sqrt for scaling embeddings to match Gemma
Thanks Daniel & Michael Han for pointing this out. https://unsloth.ai/blog/gemma-bugs PiperOrigin-RevId: 615250003
This commit is contained in:
parent
0221956b2e
commit
5fa2eb1a86
27
gemma.cc
27
gemma.cc
|
|
@ -35,6 +35,7 @@
|
|||
#ifndef GEMMA_ONCE
|
||||
#define GEMMA_ONCE
|
||||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
|
|
@ -426,6 +427,25 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
||||
}
|
||||
|
||||
// __builtin_sqrt is not constexpr as of Clang 17.
|
||||
#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \
|
||||
HWY_HAVE_SCALAR_BF16_OPERATORS
|
||||
#define GEMMA_CONSTEXPR_SQRT constexpr
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
|
||||
return __builtin_sqrt(x);
|
||||
}
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_SQRT
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||
#endif
|
||||
|
||||
template <typename TConfig>
|
||||
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
||||
}
|
||||
|
||||
template <typename TConfig, size_t kBatchSize>
|
||||
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||
const CompressedWeights<TConfig>& c_weights,
|
||||
|
|
@ -434,8 +454,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
hwy::ThreadPool& inner_pool) {
|
||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static const float kEmbScaling =
|
||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||
|
||||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
|
|
@ -490,12 +509,10 @@ void Transformer(int token, size_t pos,
|
|||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
||||
static const float kEmbScaling =
|
||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
||||
|
||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||
activations.x.data(), kModelDim);
|
||||
|
||||
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
||||
|
||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue