mirror of https://github.com/google/gemma.cpp.git
parent
39d4115717
commit
6c0be20fa6
12
gemma/ops.h
12
gemma/ops.h
|
|
@ -29,6 +29,7 @@
|
||||||
#include "compression/sfp.h"
|
#include "compression/sfp.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/detect_targets.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||||
|
|
@ -1247,8 +1248,15 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
|
|
||||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||||
hn::Transform(d, x, mask_pos,
|
hn::Transform(d, x, mask_pos,
|
||||||
[&vmax](const auto d, const auto value)
|
[&vmax](const auto d, const auto value) HWY_ATTR {
|
||||||
HWY_ATTR { return hn::Exp(d, hn::Sub(value, vmax)); });
|
#if HWY_TARGET & HWY_ALL_SVE
|
||||||
|
// Temporary workaround for buggy SVE codegen: avoid inlined
|
||||||
|
// Exp().
|
||||||
|
return hn::CallExp(d, hn::Sub(value, vmax));
|
||||||
|
#else
|
||||||
|
return hn::Exp(d, hn::Sub(value, vmax));
|
||||||
|
#endif
|
||||||
|
});
|
||||||
|
|
||||||
auto sum = hn::Zero(d);
|
auto sum = hn::Zero(d);
|
||||||
Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {
|
Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {
|
||||||
|
|
|
||||||
|
|
@ -109,12 +109,12 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
||||||
HWY_DASSERT(size != 0);
|
HWY_DASSERT(size != 0);
|
||||||
HWY_DASSERT(mask_pos <= size);
|
HWY_DASSERT(mask_pos <= size);
|
||||||
float sum = 0.0;
|
float sum = 0.0;
|
||||||
float maxval = *std::max_element(x, x + mask_pos);
|
const float maxval = *std::max_element(x, x + mask_pos);
|
||||||
for (size_t i = 0; i < mask_pos; ++i) {
|
for (size_t i = 0; i < mask_pos; ++i) {
|
||||||
x[i] = std::exp(x[i] - maxval);
|
x[i] = std::exp(x[i] - maxval);
|
||||||
sum += x[i];
|
sum += x[i];
|
||||||
}
|
}
|
||||||
float scale = 1.0f / sum;
|
const float scale = 1.0f / sum;
|
||||||
for (size_t i = 0; i < mask_pos; ++i) {
|
for (size_t i = 0; i < mask_pos; ++i) {
|
||||||
x[i] *= scale;
|
x[i] *= scale;
|
||||||
}
|
}
|
||||||
|
|
@ -237,6 +237,7 @@ struct TestMulByConst {
|
||||||
template <class D>
|
template <class D>
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
hwy::RandomState& rng) {
|
hwy::RandomState& rng) {
|
||||||
|
if (misalign_b == 0) return;
|
||||||
using T = hn::TFromD<D>;
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
|
@ -267,6 +268,7 @@ struct TestSoftmax {
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
hwy::RandomState& rng) {
|
hwy::RandomState& rng) {
|
||||||
if (count == 0) return; // *Softmax would assert
|
if (count == 0) return; // *Softmax would assert
|
||||||
|
if (misalign_b == 0) return;
|
||||||
using T = hn::TFromD<D>;
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue