// Patch summary (unified diff against gemma.cc, shown here collapsed onto two
// physical lines; '+' marks added lines, '-' marks removed lines).
//
// What the patch does, hunk by hunk:
//   1. Adds one #include, annotated "sqrtf", directly after the GEMMA_ONCE
//      include guard.
//   2. Adds the GEMMA_CONSTEXPR_SQRT macro plus two helpers:
//        - GEMMA_CONSTEXPR_SQRT expands to `constexpr` only when
//          HWY_COMPILER_GCC_ACTUAL is nonzero and
//          HWY_HAVE_SCALAR_BF16_OPERATORS is defined and nonzero; otherwise it
//          expands to nothing. The accompanying comment gives the reason:
//          __builtin_sqrt is not constexpr as of Clang 17.
//        - Sqrt(float) returns __builtin_sqrt(x) in the constexpr branch and
//          sqrtf(x) in the fallback branch.
//        - EmbeddingScaling() takes Sqrt of TConfig::kModelDim and passes it
//          through two nested hwy::ConvertScalarTo calls; per its own comment
//          this rounds the value to bf16 to match Gemma's Embedder, which
//          casts before multiplying.
//   3. Prefill: removes the function-local `static const float kEmbScaling`
//      computed with sqrt(kModelDim) and replaces it with
//      `const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling();`.
//   4. Transformer: removes the same static and adds the same replacement,
//      placed after the Decompress call and before MulByConst.
//
// NOTE(review): template argument lists and include targets look stripped in
// this copy (bare `template`, `static_cast(`, `ConvertScalarTo(`, `#include`
// with no header name) -- presumably lost during extraction. Confirm against
// the upstream file before applying this patch.
// NOTE(review): Sqrt takes a float but the constexpr branch calls
// __builtin_sqrt, which is presumably the double-precision builtin (the float
// form would be __builtin_sqrtf) -- confirm the implicit round trip through
// double is intended.
diff --git a/gemma.cc b/gemma.cc index 4fe2782..8613846 100644 --- a/gemma.cc +++ b/gemma.cc @@ -35,6 +35,7 @@ #ifndef GEMMA_ONCE #define GEMMA_ONCE +#include // sqrtf #include #include @@ -426,6 +427,25 @@ HWY_NOINLINE void FFW(Activations& activations, activations.ffw_out.data() + batch_idx * kModelDim, pool); } +// __builtin_sqrt is not constexpr as of Clang 17. +#if HWY_COMPILER_GCC_ACTUAL && defined(HWY_HAVE_SCALAR_BF16_OPERATORS) && \ + HWY_HAVE_SCALAR_BF16_OPERATORS +#define GEMMA_CONSTEXPR_SQRT constexpr +static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { + return __builtin_sqrt(x); +} +#else +#define GEMMA_CONSTEXPR_SQRT +static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); } +#endif + +template +GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() { + // Round to bf16 to match Gemma's Embedder, which casts before mul. + return hwy::ConvertScalarTo(hwy::ConvertScalarTo( + Sqrt(static_cast(TConfig::kModelDim)))); +} + template HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, const CompressedWeights& c_weights, @@ -434,8 +454,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, hwy::ThreadPool& inner_pool) { PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); static constexpr size_t kModelDim = TConfig::kModelDim; - static const float kEmbScaling = - static_cast(sqrt(static_cast(kModelDim))); + const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling(); pool.Run( 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { @@ -490,12 +509,10 @@ void Transformer(int token, size_t pos, static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kModelDim = TConfig::kModelDim; - static const float kEmbScaling = - static_cast(sqrt(static_cast(kModelDim))); - Decompress(c_weights.c_embedder_input_embedding, token * kModelDim, activations.x.data(), kModelDim); + const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling(); MulByConst(kEmbScaling, 
activations.x.data(), kModelDim); for (size_t layer = 0; layer < kLayers; ++layer) {