mirror of https://github.com/google/gemma.cpp.git
Use Lookup8 and detail::IsFull(d) in FastSigmoid
Fix targeted for scalable architectures PiperOrigin-RevId: 888633434
This commit is contained in:
parent
8a5e37eeb7
commit
259b757aef
|
|
@ -76,7 +76,8 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
|||
hn::Vec<D> b, c, d_coef;
|
||||
|
||||
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
|
||||
(HWY_HAVE_SCALABLE && sizeof(T) == 4)) {
|
||||
(HWY_HAVE_SCALABLE && sizeof(T) == 4 &&
|
||||
hn::detail::IsFull(d))) {
|
||||
// Coefficients for P(y/2) ~ index using CF algo
|
||||
const auto k0 = hn::Set(d, static_cast<T>(-0.1145426548151546));
|
||||
const auto k1 = hn::Set(d, static_cast<T>(3.4556654973457404));
|
||||
|
|
@ -118,23 +119,12 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
|||
static_cast<T>(1.1565092321921886), static_cast<T>(-0.7540678688218365),
|
||||
static_cast<T>(-3.767209600467866), static_cast<T>(-9.357047249878605)};
|
||||
|
||||
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
|
||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
||||
hn::CappedTag<T, 8> d8;
|
||||
b = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_b)), idx);
|
||||
c = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_c)), idx);
|
||||
d_coef =
|
||||
hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_d)), idx);
|
||||
} else {
|
||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
||||
hn::FixedTag<T, 4> d4;
|
||||
b = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_b),
|
||||
hn::Load(d4, arr_b + 4), idx);
|
||||
c = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_c),
|
||||
hn::Load(d4, arr_c + 4), idx);
|
||||
d_coef = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_d),
|
||||
hn::Load(d4, arr_d + 4), idx);
|
||||
}
|
||||
// Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this
|
||||
// condition covers all cases we encounter inside the top level if block
|
||||
// inside FastSigmoid
|
||||
b = hn::Lookup8(d, arr_b, idx_i);
|
||||
c = hn::Lookup8(d, arr_c, idx_i);
|
||||
d_coef = hn::Lookup8(d, arr_d, idx_i);
|
||||
} else {
|
||||
// --- FALLBACK PATH: Blend Chain ---
|
||||
// Thresholds for intervals
|
||||
|
|
|
|||
Loading…
Reference in New Issue