mirror of https://github.com/google/gemma.cpp.git
Fix compilation error when `HWY_COMPILER_GCC_ACTUAL < 1300`
This commit is contained in:
parent
bb767d788d
commit
1c03d7446d
16
gemma.cc
16
gemma.cc
|
|
@ -464,8 +464,16 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
||||||
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
activations.ffw_out.data() + batch_idx * kModelDim, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
|
||||||
|
// are both constexpr
|
||||||
|
#if HWY_COMPILER_GCC_ACTUAL
|
||||||
|
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
|
||||||
|
#else
|
||||||
|
#define GEMMA_CONSTEXPR_EMBSCALING
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename TConfig>
|
template <typename TConfig>
|
||||||
GEMMA_CONSTEXPR_SQRT float EmbeddingScaling() {
|
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
|
||||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||||
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
||||||
|
|
@ -479,7 +487,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
hwy::ThreadPool& inner_pool) {
|
hwy::ThreadPool& inner_pool) {
|
||||||
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
|
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||||
|
EmbeddingScaling<TConfig>();
|
||||||
|
|
||||||
pool.Run(
|
pool.Run(
|
||||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||||
|
|
@ -537,7 +546,8 @@ void Transformer(int token, size_t pos,
|
||||||
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.data(), kModelDim);
|
activations.x.data(), kModelDim);
|
||||||
|
|
||||||
const GEMMA_CONSTEXPR_SQRT float kEmbScaling = EmbeddingScaling<TConfig>();
|
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||||
|
EmbeddingScaling<TConfig>();
|
||||||
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
MulByConst(kEmbScaling, activations.x.data(), kModelDim);
|
||||||
|
|
||||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue