mirror of https://github.com/google/gemma.cpp.git
Use highway in AddFrom, MulBy, MulByConst, MulByConstAndAdd, create_distribution
This commit is contained in:
parent
8fb44ed6dd
commit
858d5b08c2
86
ops.h
86
ops.h
|
|
@ -551,47 +551,66 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||||
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) {
|
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
||||||
for (size_t i = 0; i < size; ++i) {
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
x[i] += other[i];
|
using D = hn::ScalableTag<float>;
|
||||||
}
|
const D d;
|
||||||
|
|
||||||
|
hn::Transform1(d, x, size, other,
|
||||||
|
[](const auto d, const auto x, const auto other)
|
||||||
|
HWY_ATTR { return hn::Add(x, other); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
||||||
float* HWY_RESTRICT x, size_t size,
|
float* HWY_RESTRICT x, const size_t size,
|
||||||
size_t max_pos) {
|
const size_t max_pos) {
|
||||||
HWY_DASSERT(max_pos <= size);
|
HWY_DASSERT(max_pos <= size);
|
||||||
for (size_t i = 0; i < max_pos; ++i) {
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
x[i] *= other[i];
|
using D = hn::ScalableTag<float>;
|
||||||
}
|
const D d;
|
||||||
|
|
||||||
|
hn::Transform1(d, x, max_pos, other,
|
||||||
|
[](const auto d, const auto x, const auto other)
|
||||||
|
HWY_ATTR { return hn::Mul(x, other); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
|
||||||
float* HWY_RESTRICT x,
|
float* HWY_RESTRICT x,
|
||||||
size_t size) {
|
const size_t size) {
|
||||||
return MulBy(other, x, size, size);
|
return MulBy(other, x, size, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void MulByConst(float c, float* HWY_RESTRICT x, size_t size,
|
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
|
||||||
size_t max_pos) {
|
const size_t size, const size_t max_pos) {
|
||||||
HWY_DASSERT(max_pos <= size);
|
HWY_DASSERT(max_pos <= size);
|
||||||
for (size_t i = 0; i < max_pos; ++i) {
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
x[i] *= c;
|
using D = hn::ScalableTag<float>;
|
||||||
}
|
const D d;
|
||||||
|
const auto constant = hn::Set(d, c);
|
||||||
|
hn::Transform(d, x, max_pos,
|
||||||
|
[&constant](const auto d, const auto x)
|
||||||
|
HWY_ATTR { return hn::Mul(x, constant); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(float c,
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
|
||||||
float* HWY_RESTRICT x,
|
float* HWY_RESTRICT x,
|
||||||
size_t size) {
|
const size_t size) {
|
||||||
MulByConst(c, x, size, size);
|
MulByConst(c, x, size, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void MulByConstAndAdd(float c, const float* HWY_RESTRICT x,
|
static HWY_NOINLINE void MulByConstAndAdd(const float c,
|
||||||
float* HWY_RESTRICT out, size_t size,
|
const float* HWY_RESTRICT x,
|
||||||
size_t max_pos) {
|
float* HWY_RESTRICT out,
|
||||||
for (size_t i = 0; i < max_pos; ++i) {
|
const size_t size,
|
||||||
out[i] += x[i] * c;
|
const size_t max_pos) {
|
||||||
}
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
const D d;
|
||||||
|
const auto constant = hn::Set(d, c);
|
||||||
|
hn::Transform1(
|
||||||
|
d, out, max_pos, x,
|
||||||
|
[&constant](const auto d, const auto out_element, const auto x_element)
|
||||||
|
HWY_ATTR { return hn::MulAdd(x_element, constant, out_element); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
|
|
@ -693,15 +712,18 @@ template <size_t k>
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
||||||
create_distribution(std::array<float, k>& top_k, float temperature) {
|
create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||||
// re-normalize distribution
|
// re-normalize distribution
|
||||||
for (size_t i = 0; i < k; ++i) {
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
top_k[i] = exp(log(top_k[i]) / temperature);
|
using D = hn::ScalableTag<float>;
|
||||||
}
|
const D d;
|
||||||
float denominator = 0.0f;
|
|
||||||
for (size_t i = 0; i < k; ++i) {
|
const auto one = hn::Set(d, 1.0f);
|
||||||
denominator += top_k[i];
|
const auto temperature_inv = hn::Div(one, hn::Set(d, temperature));
|
||||||
}
|
|
||||||
denominator = 1.0f / denominator;
|
hn::Transform(d, top_k.data(), top_k.size(),
|
||||||
MulByConst(denominator, top_k.data(), k);
|
[&temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
|
||||||
|
return hn::Mul(hn::Exp(d, hn::Log(d, v)), temperature_inv);
|
||||||
|
});
|
||||||
|
|
||||||
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue