mirror of https://github.com/google/gemma.cpp.git
Use bf16-rounded sqrt for scaling embeddings to match Gemma
Thanks Daniel & Michael Han for pointing this out. https://unsloth.ai/blog/gemma-bugs PiperOrigin-RevId: 615250003
This commit is contained in:
parent
0221956b2e
commit
5fa2eb1a86
27
gemma.cc
27
gemma.cc
|
|
@ -35,6 +35,7 @@
|
||||||
#ifndef GEMMA_ONCE
|
#ifndef GEMMA_ONCE
|
||||||
#define GEMMA_ONCE
|
#define GEMMA_ONCE
|
||||||
|
|
||||||
|
#include <math.h> // sqrtf
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
|
@ -426,6 +427,25 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
||||||
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// __builtin_sqrt is not constexpr as of Clang 17.
|
||||||
|
#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \
|
||||||
|
HWY_HAVE_SCALAR_BF16_OPERATORS
|
||||||
|
#define GEMMA_CONSTEXPR_SQRT constexpr
|
||||||
|
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
|
||||||
|
return __builtin_sqrt(x);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
#define GEMMA_CONSTEXPR_SQRT
|
||||||
|
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename TConfig>
|
||||||
|
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
|
||||||
|
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||||
|
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||||
|
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename TConfig, size_t kBatchSize>
|
template <typename TConfig, size_t kBatchSize>
|
||||||
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
const CompressedWeights<TConfig>& c_weights,
|
const CompressedWeights<TConfig>& c_weights,
|
||||||
|
|
@ -434,8 +454,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
hwy::ThreadPool& inner_pool) {
|
hwy::ThreadPool& inner_pool) {
|
||||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static const float kEmbScaling =
|
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
|
||||||
|
|
||||||
pool.Run(
|
pool.Run(
|
||||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||||
|
|
@ -490,12 +509,10 @@ void Transformer(int token, size_t pos,
|
||||||
static constexpr size_t kLayers = TConfig::kLayers;
|
static constexpr size_t kLayers = TConfig::kLayers;
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
||||||
static const float kEmbScaling =
|
|
||||||
static_cast<float>(sqrt(static_cast<double>(kModelDim)));
|
|
||||||
|
|
||||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.data(), kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
|
|
||||||
|
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
||||||
|
|
||||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue