mirror of https://github.com/google/gemma.cpp.git
parent
39d4115717
commit
6c0be20fa6
12
gemma/ops.h
12
gemma/ops.h
|
|
@@ -29,6 +29,7 @@
|
|||
#include "compression/sfp.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/detect_targets.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||
|
|
@@ -1247,8 +1248,15 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|||
|
||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||
hn::Transform(d, x, mask_pos,
|
||||
[&vmax](const auto d, const auto value)
|
||||
HWY_ATTR { return hn::Exp(d, hn::Sub(value, vmax)); });
|
||||
[&vmax](const auto d, const auto value) HWY_ATTR {
|
||||
#if HWY_TARGET & HWY_ALL_SVE
|
||||
// Temporary workaround for buggy SVE codegen: avoid inlined
|
||||
// Exp().
|
||||
return hn::CallExp(d, hn::Sub(value, vmax));
|
||||
#else
|
||||
return hn::Exp(d, hn::Sub(value, vmax));
|
||||
#endif
|
||||
});
|
||||
|
||||
auto sum = hn::Zero(d);
|
||||
Foreach(d, x, mask_pos, sum, [&sum](const auto d, const auto value) HWY_ATTR {
|
||||
|
|
|
|||
|
|
@@ -109,12 +109,12 @@ HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
|||
HWY_DASSERT(size != 0);
|
||||
HWY_DASSERT(mask_pos <= size);
|
||||
float sum = 0.0;
|
||||
float maxval = *std::max_element(x, x + mask_pos);
|
||||
const float maxval = *std::max_element(x, x + mask_pos);
|
||||
for (size_t i = 0; i < mask_pos; ++i) {
|
||||
x[i] = std::exp(x[i] - maxval);
|
||||
sum += x[i];
|
||||
}
|
||||
float scale = 1.0f / sum;
|
||||
const float scale = 1.0f / sum;
|
||||
for (size_t i = 0; i < mask_pos; ++i) {
|
||||
x[i] *= scale;
|
||||
}
|
||||
|
|
@@ -237,6 +237,7 @@ struct TestMulByConst {
|
|||
template <class D>
|
||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||
hwy::RandomState& rng) {
|
||||
if (misalign_b == 0) return;
|
||||
using T = hn::TFromD<D>;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||
|
|
@@ -267,6 +268,7 @@ struct TestSoftmax {
|
|||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||
hwy::RandomState& rng) {
|
||||
if (count == 0) return; // *Softmax would assert
|
||||
if (misalign_b == 0) return;
|
||||
using T = hn::TFromD<D>;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||
|
|
|
|||
Loading…
Reference in New Issue