mirror of https://github.com/google/gemma.cpp.git
SVE build fix: avoid capturing vectors directly.
Also use more V typedef instead of auto. PiperOrigin-RevId: 651423685
This commit is contained in:
parent
be765afce2
commit
edaf61b983
81
gemma/ops.h
81
gemma/ops.h
|
|
@ -1132,10 +1132,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||||
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
hn::Transform1(d, x, size, other,
|
hn::Transform1(D(), x, size, other,
|
||||||
[](const auto d, const auto x, const auto other)
|
[](const auto d, const V x, const V other)
|
||||||
HWY_ATTR { return hn::Add(x, other); });
|
HWY_ATTR { return hn::Add(x, other); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1175,10 +1175,10 @@ static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
||||||
HWY_DASSERT(max_pos <= size);
|
HWY_DASSERT(max_pos <= size);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
hn::Transform1(d, x, max_pos, other,
|
hn::Transform1(D(), x, max_pos, other,
|
||||||
[](const auto d, const auto x, const auto other)
|
[](const auto d, const V x, const V other)
|
||||||
HWY_ATTR { return hn::Mul(x, other); });
|
HWY_ATTR { return hn::Mul(x, other); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1193,11 +1193,10 @@ static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
|
||||||
HWY_DASSERT(max_pos <= size);
|
HWY_DASSERT(max_pos <= size);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
using V = hn::Vec<D>;
|
||||||
const auto constant = hn::Set(d, c);
|
hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
|
||||||
hn::Transform(d, x, max_pos,
|
return hn::Mul(x, hn::Set(d, c));
|
||||||
[&constant](const auto d, const auto x)
|
});
|
||||||
HWY_ATTR { return hn::Mul(x, constant); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
|
||||||
|
|
@ -1213,12 +1212,11 @@ static HWY_NOINLINE void MulByConstAndAdd(const float c,
|
||||||
const size_t max_pos) {
|
const size_t max_pos) {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
using V = hn::Vec<D>;
|
||||||
const auto constant = hn::Set(d, c);
|
hn::Transform1(D(), out, max_pos, x,
|
||||||
hn::Transform1(
|
[c](const auto d, const V v_out, const V v_x) HWY_ATTR {
|
||||||
d, out, max_pos, x,
|
return hn::MulAdd(v_x, hn::Set(d, c), v_out);
|
||||||
[&constant](const auto d, const auto out_element, const auto x_element)
|
});
|
||||||
HWY_ATTR { return hn::MulAdd(x_element, constant, out_element); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
|
|
@ -1234,30 +1232,32 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
const D d;
|
const D d;
|
||||||
|
|
||||||
const auto vmin = hn::Set(d, hwy::LowestValue<float>());
|
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||||
auto vmax = vmin;
|
V vmax = vmin;
|
||||||
Foreach(d, x, mask_pos, vmin,
|
V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly
|
||||||
[&vmax](const auto d, const auto value)
|
Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR {
|
||||||
HWY_ATTR { vmax = hn::Max(vmax, value); });
|
*pmax = hn::Max(*pmax, value);
|
||||||
|
});
|
||||||
vmax = hn::MaxOfLanes(d, vmax);
|
vmax = hn::MaxOfLanes(d, vmax);
|
||||||
|
|
||||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||||
hn::Transform(d, x, mask_pos,
|
hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
|
||||||
[&vmax](const auto d, const auto value) HWY_ATTR {
|
|
||||||
#if HWY_TARGET & HWY_ALL_SVE
|
#if HWY_TARGET & HWY_ALL_SVE
|
||||||
// Temporary workaround for buggy SVE codegen: avoid inlined
|
// Temporary workaround for buggy SVE codegen: avoid inlined
|
||||||
// Exp().
|
// Exp().
|
||||||
return hn::CallExp(d, hn::Sub(value, vmax));
|
return hn::CallExp(d, hn::Sub(value, *pmax));
|
||||||
#else
|
#else
|
||||||
return hn::Exp(d, hn::Sub(value, vmax));
|
return hn::Exp(d, hn::Sub(value, *pmax));
|
||||||
#endif
|
#endif
|
||||||
});
|
});
|
||||||
|
|
||||||
auto sum = hn::Zero(d);
|
V sum = hn::Zero(d);
|
||||||
Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {
|
V* psum = ∑
|
||||||
sum = hn::Add(sum, value);
|
Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR {
|
||||||
|
*psum = hn::Add(*psum, value);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Normalize to probability distribution
|
// Normalize to probability distribution
|
||||||
|
|
@ -1277,14 +1277,13 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
|
||||||
using V = hn::Vec<D>;
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
const V vcap = hn::Set(d, cap);
|
const float inv_cap = 1.0f / cap;
|
||||||
const V vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
|
|
||||||
|
|
||||||
hn::Transform(d, x, max_pos, [&vcap, &vinv_cap](D d, hn::Vec<D> v) HWY_ATTR {
|
hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR {
|
||||||
return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap)));
|
return hn::Mul(hn::Set(d, cap),
|
||||||
|
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1310,17 +1309,15 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||||
template <size_t k>
|
template <size_t k>
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
||||||
create_distribution(std::array<float, k>& top_k, float temperature) {
|
create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||||
// re-normalize distribution
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const D d;
|
|
||||||
|
|
||||||
const auto temperature_inv =
|
// re-normalize distribution
|
||||||
hn::Div(hn::Set(d, 1.0f), hn::Set(d, temperature));
|
const float temperature_inv = 1.0f / temperature;
|
||||||
|
hn::Transform(D(), top_k.data(), top_k.size(),
|
||||||
hn::Transform(d, top_k.data(), top_k.size(),
|
[temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
|
||||||
[&temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
|
return hn::Exp(
|
||||||
return hn::Exp(d, hn::Mul(hn::Log(d, v), temperature_inv));
|
d, hn::Mul(hn::Log(d, v), hn::Set(d, temperature_inv)));
|
||||||
});
|
});
|
||||||
|
|
||||||
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue