diff --git a/gemma/ops.h b/gemma/ops.h index 3415c4c..ef82f64 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -1132,10 +1132,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - const D d; + using V = hn::Vec; - hn::Transform1(d, x, size, other, - [](const auto d, const auto x, const auto other) + hn::Transform1(D(), x, size, other, + [](const auto d, const V x, const V other) HWY_ATTR { return hn::Add(x, other); }); } @@ -1175,10 +1175,10 @@ static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other, HWY_DASSERT(max_pos <= size); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - const D d; + using V = hn::Vec; - hn::Transform1(d, x, max_pos, other, - [](const auto d, const auto x, const auto other) + hn::Transform1(D(), x, max_pos, other, + [](const auto d, const V x, const V other) HWY_ATTR { return hn::Mul(x, other); }); } @@ -1193,11 +1193,10 @@ static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x, HWY_DASSERT(max_pos <= size); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - const D d; - const auto constant = hn::Set(d, c); - hn::Transform(d, x, max_pos, - [&constant](const auto d, const auto x) - HWY_ATTR { return hn::Mul(x, constant); }); + using V = hn::Vec; + hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR { + return hn::Mul(x, hn::Set(d, c)); + }); } static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c, @@ -1213,12 +1212,11 @@ static HWY_NOINLINE void MulByConstAndAdd(const float c, const size_t max_pos) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - const D d; - const auto constant = hn::Set(d, c); - hn::Transform1( - d, out, max_pos, x, - [&constant](const auto d, const auto out_element, const auto x_element) - HWY_ATTR { return hn::MulAdd(x_element, constant, out_element); }); + using V = hn::Vec; + hn::Transform1(D(), out, max_pos, x, + [c](const auto d, const V v_out, const V v_x) HWY_ATTR { + return hn::MulAdd(v_x, hn::Set(d, c), v_out); + }); } static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( @@ -1234,30 +1232,32 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; + using V = hn::Vec; const D d; - const auto vmin = hn::Set(d, hwy::LowestValue()); - auto vmax = vmin; - Foreach(d, x, mask_pos, vmin, - [&vmax](const auto d, const auto value) - HWY_ATTR { vmax = hn::Max(vmax, value); }); + const V vmin = hn::Set(d, hwy::LowestValue()); + V vmax = vmin; + V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly + Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR { + *pmax = hn::Max(*pmax, value); + }); vmax = hn::MaxOfLanes(d, vmax); // Subtract max (avoid precision loss for large exponents) and exponentiate. - hn::Transform(d, x, mask_pos, - [&vmax](const auto d, const auto value) HWY_ATTR { + hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR { #if HWY_TARGET & HWY_ALL_SVE - // Temporary workaround for buggy SVE codegen: avoid inlined - // Exp(). - return hn::CallExp(d, hn::Sub(value, vmax)); + // Temporary workaround for buggy SVE codegen: avoid inlined + // Exp(). + return hn::CallExp(d, hn::Sub(value, *pmax)); #else - return hn::Exp(d, hn::Sub(value, vmax)); + return hn::Exp(d, hn::Sub(value, *pmax)); #endif - }); + }); - auto sum = hn::Zero(d); - Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR { - sum = hn::Add(sum, value); + V sum = hn::Zero(d); + V* psum = ∑ + Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR { + *psum = hn::Add(*psum, value); }); // Normalize to probability distribution @@ -1277,14 +1277,13 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - const D d; using V = hn::Vec; - const V vcap = hn::Set(d, cap); - const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); + const float inv_cap = 1.0f / cap; - hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec v) HWY_ATTR { - return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap))); + hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR { + return hn::Mul(hn::Set(d, cap), + hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); }); } @@ -1310,17 +1309,15 @@ SampleArgmax(const float* probabilities, size_t vocab_size) { template static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution create_distribution(std::array& top_k, float temperature) { - // re-normalize distribution namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - const D d; - const auto temperature_inv = - hn::Div(hn::Set(d, 1.0f), hn::Set(d, temperature)); - - hn::Transform(d, top_k.data(), top_k.size(), - [&temperature_inv](D d, hn::Vec v) HWY_ATTR { - return hn::Exp(d, hn::Mul(hn::Log(d, v), temperature_inv)); + // re-normalize distribution + const float temperature_inv = 1.0f / temperature; + hn::Transform(D(), top_k.data(), top_k.size(), + [temperature_inv](D d, hn::Vec v) HWY_ATTR { + return hn::Exp( + d, hn::Mul(hn::Log(d, v), hn::Set(d, temperature_inv))); }); return std::discrete_distribution(std::begin(top_k), std::end(top_k));