diff --git a/gemma.cc b/gemma.cc
index 23fc31f..ff58a82 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -464,8 +464,16 @@ HWY_NOINLINE void FFW(Activations& activations,
       activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
 
+// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
+// are both constexpr; without HWY_COMPILER_GCC_ACTUAL it expands to nothing.
+#if HWY_COMPILER_GCC_ACTUAL
+#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
+#else
+#define GEMMA_CONSTEXPR_EMBSCALING
+#endif
+
 template <class TConfig>
-GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
+GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
   // Round to bf16 to match Gemma's Embedder, which casts before mul.
   return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
       Sqrt(static_cast<float>(TConfig::kModelDim))));
@@ -479,7 +487,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
+  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
+      EmbeddingScaling<TConfig>();
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -537,7 +546,8 @@ void Transformer(int token, size_t pos,
   Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
 
-  const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
+  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
+      EmbeddingScaling<TConfig>();
   MulByConst(kEmbScaling, activations.x.data(), kModelDim);
 
   for (size_t layer = 0; layer < kLayers; ++layer) {