diff --git a/gemma/ops.h b/gemma/ops.h index ed12ef4..90b0d13 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -833,14 +833,16 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, vmax = hn::MaxOfLanes(d, vmax); // Subtract max (avoid precision loss for large exponents) and exponentiate. - auto sum = hn::Zero(d); hn::Transform(d, x, mask_pos, - [&sum, &vmax](const auto d, const auto value) HWY_ATTR { - const auto out = hn::Exp(d, hn::Sub(value, vmax)); - sum = hn::Add(sum, out); - return out; + [&vmax](const auto d, const auto value) HWY_ATTR { + return hn::Exp(d, hn::Sub(value, vmax)); }); + auto sum = hn::Zero(d); + Foreach(d, x, mask_pos, sum, + [&sum](const auto d, const auto value) + HWY_ATTR { sum = hn::Add(sum, value); }); + // Normalize to probability distribution const float mul = 1.0f / hn::ReduceSum(d, sum); MulByConst(mul, x, size, mask_pos); diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 9ad185a..5d59f63 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -107,42 +107,16 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size, size_t mask_pos) { HWY_DASSERT(size != 0); HWY_DASSERT(mask_pos <= size); - - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D d; - const size_t N = hn::Lanes(d); - - const hn::Vec vmin = hn::Set(d, hwy::LowestValue()); - hn::Vec vmax = vmin; - size_t idx = 0; - if (mask_pos >= N) { - for (; idx <= mask_pos - N; idx += N) { - vmax = hn::Max(vmax, LoadU(d, x + idx)); - } + float sum = 0.0; + float maxval = *std::max_element(x, x + mask_pos); + for (size_t i = 0; i < mask_pos; ++i) { + x[i] = std::exp(x[i] - maxval); + sum += x[i]; } - vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx)); - vmax = hn::MaxOfLanes(d, vmax); // broadcast - - hn::Vec sum = hn::Zero(d); - idx = 0; - if (mask_pos >= N) { - for (; idx <= mask_pos - N; idx += N) { - const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax)); - sum = hn::Add(sum, out); - hn::StoreU(out, d, x + idx); - } + float scale = 1.0f / sum; + for (size_t i = 0; i < mask_pos; ++i) { + x[i] *= scale; } - if (mask_pos > idx) { - const size_t remaining = mask_pos - idx; - const hn::Vec out = - hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax)); - sum = hn::Add(sum, out); - hn::StoreN(out, d, x + idx, remaining); - } - - const float mul = 1.0f / hn::ReduceSum(d, sum); - SourceMulByConst(mul, x, size, mask_pos); } template @@ -311,8 +285,14 @@ struct TestSoftmax { SourceSoftmax(e, count, count); Softmax(x, count, count); - hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, - __LINE__); + T sum = 0.0f; + for (size_t i = 0; i < count; ++i) { + sum += x[i]; + double rel = std::abs(x[i] - e[i]) / e[i]; + ASSERT_LT(rel, 1e-6) + << "Mismatch on coordinate " << i << " out of " << count; + } + ASSERT_NEAR(sum, 1.0, 1e-6); } };