Fix compilation error when `HWY_COMPILER_GCC_ACTUAL < 1300`

This commit is contained in:
RangerUFO 2024-03-28 14:54:37 +08:00
parent bb767d788d
commit 1c03d7446d
1 changed files with 13 additions and 3 deletions

View File

@ -464,8 +464,16 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
@ -479,7 +487,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
hwy::ThreadPool& inner_pool) {
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
static constexpr size_t kModelDim = TConfig::kModelDim;
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>();
pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@ -537,7 +546,8 @@ void Transformer(int token, size_t pos,
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {