mirror of https://github.com/google/gemma.cpp.git
Use Lookup8 and detail::IsFull(d) in FastSigmoid
Fix targeted for scalable architectures PiperOrigin-RevId: 888633434
This commit is contained in:
parent
8a5e37eeb7
commit
259b757aef
|
|
@ -76,7 +76,8 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
||||||
hn::Vec<D> b, c, d_coef;
|
hn::Vec<D> b, c, d_coef;
|
||||||
|
|
||||||
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
|
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
|
||||||
(HWY_HAVE_SCALABLE && sizeof(T) == 4)) {
|
(HWY_HAVE_SCALABLE && sizeof(T) == 4 &&
|
||||||
|
hn::detail::IsFull(d))) {
|
||||||
// Coefficients for P(y/2) ~ index using CF algo
|
// Coefficients for P(y/2) ~ index using CF algo
|
||||||
const auto k0 = hn::Set(d, static_cast<T>(-0.1145426548151546));
|
const auto k0 = hn::Set(d, static_cast<T>(-0.1145426548151546));
|
||||||
const auto k1 = hn::Set(d, static_cast<T>(3.4556654973457404));
|
const auto k1 = hn::Set(d, static_cast<T>(3.4556654973457404));
|
||||||
|
|
@ -118,23 +119,12 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
||||||
static_cast<T>(1.1565092321921886), static_cast<T>(-0.7540678688218365),
|
static_cast<T>(1.1565092321921886), static_cast<T>(-0.7540678688218365),
|
||||||
static_cast<T>(-3.767209600467866), static_cast<T>(-9.357047249878605)};
|
static_cast<T>(-3.767209600467866), static_cast<T>(-9.357047249878605)};
|
||||||
|
|
||||||
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
|
// Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this
|
||||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
// condition covers all cases we encounter inside the top level if block
|
||||||
hn::CappedTag<T, 8> d8;
|
// inside FastSigmoid
|
||||||
b = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_b)), idx);
|
b = hn::Lookup8(d, arr_b, idx_i);
|
||||||
c = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_c)), idx);
|
c = hn::Lookup8(d, arr_c, idx_i);
|
||||||
d_coef =
|
d_coef = hn::Lookup8(d, arr_d, idx_i);
|
||||||
hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_d)), idx);
|
|
||||||
} else {
|
|
||||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
|
||||||
hn::FixedTag<T, 4> d4;
|
|
||||||
b = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_b),
|
|
||||||
hn::Load(d4, arr_b + 4), idx);
|
|
||||||
c = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_c),
|
|
||||||
hn::Load(d4, arr_c + 4), idx);
|
|
||||||
d_coef = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_d),
|
|
||||||
hn::Load(d4, arr_d + 4), idx);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// --- FALLBACK PATH: Blend Chain ---
|
// --- FALLBACK PATH: Blend Chain ---
|
||||||
// Thresholds for intervals
|
// Thresholds for intervals
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue