diff --git a/gemma/ops.h b/gemma/ops.h index 691bb2c..702efc5 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -29,6 +29,7 @@ #include "compression/sfp.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/detect_targets.h" #include "hwy/profiler.h" #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ @@ -1247,8 +1248,15 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, // Subtract max (avoid precision loss for large exponents) and exponentiate. hn::Transform(d, x, mask_pos, - [&vmax](const auto d, const auto value) - HWY_ATTR { return hn::Exp(d, hn::Sub(value, vmax)); }); + [&vmax](const auto d, const auto value) HWY_ATTR { +#if HWY_TARGET & HWY_ALL_SVE + // Temporary workaround for buggy SVE codegen: avoid inlined + // Exp(). + return hn::CallExp(d, hn::Sub(value, vmax)); +#else + return hn::Exp(d, hn::Sub(value, vmax)); +#endif + }); auto sum = hn::Zero(d); Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR { diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index a7f47f4..5f3c45f 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -109,12 +109,12 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size, HWY_DASSERT(size != 0); HWY_DASSERT(mask_pos <= size); float sum = 0.0; - float maxval = *std::max_element(x, x + mask_pos); + const float maxval = *std::max_element(x, x + mask_pos); for (size_t i = 0; i < mask_pos; ++i) { x[i] = std::exp(x[i] - maxval); sum += x[i]; } - float scale = 1.0f / sum; + const float scale = 1.0f / sum; for (size_t i = 0; i < mask_pos; ++i) { x[i] *= scale; } @@ -237,6 +237,7 @@ struct TestMulByConst { template void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { + if (misalign_b == 0) return; using T = hn::TFromD; hwy::AlignedFreeUniquePtr px = @@ -267,6 +268,7 @@ struct TestSoftmax { void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, hwy::RandomState& rng) { if (count == 0) return; // *Softmax would assert + if (misalign_b == 0) return; using T = hn::TFromD; hwy::AlignedFreeUniquePtr px =