mirror of https://github.com/google/gemma.cpp.git
Change the calculation from (ax + b)/(cx + d) to (x + b')/(c'x + d'). This replaces a MulAdd with an Add, reducing port contention on modern CPUs and thus increasing throughput.
It also removes the need for one register, which previously held b (the constant 1.0 here).
PiperOrigin-RevId: 886170146
This commit is contained in:
parent
ceb70203f0
commit
50144738f1
|
|
@ -73,7 +73,7 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
|||
auto y = hn::Abs(val);
|
||||
|
||||
constexpr size_t kLanes = HWY_MAX_LANES_D(D);
|
||||
hn::Vec<D> a, c, d_coef;
|
||||
hn::Vec<D> b, c, d_coef;
|
||||
|
||||
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
|
||||
(HWY_HAVE_SCALABLE && sizeof(T) == 4)) {
|
||||
|
|
@ -98,41 +98,38 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
|||
// Clamp index to 7
|
||||
idx_i = hn::Min(idx_i, hn::Set(DI(), 7));
|
||||
|
||||
HWY_ALIGN static constexpr T arr_a[8] = {
|
||||
static_cast<T>(-1435.326650329326),
|
||||
static_cast<T>(-96.9456723845743),
|
||||
static_cast<T>(-18.628915468855695),
|
||||
static_cast<T>(-5.90191111348809),
|
||||
static_cast<T>(-2.356433838423728),
|
||||
static_cast<T>(-1.0464246812594584),
|
||||
static_cast<T>(-0.4801959711368016),
|
||||
static_cast<T>(-0.2132727031175401)};
|
||||
HWY_ALIGN static constexpr T arr_c[8] = {static_cast<T>(-316.5640994591445),
|
||||
static_cast<T>(-49.14374182730444),
|
||||
static_cast<T>(-15.69264419046708),
|
||||
static_cast<T>(-6.949871926785674),
|
||||
static_cast<T>(-3.513259738716989),
|
||||
static_cast<T>(-1.839177585570145),
|
||||
static_cast<T>(-0.9298342163526662),
|
||||
static_cast<T>(-0.426230503963466)};
|
||||
HWY_ALIGN static constexpr T arr_b[8] = {
|
||||
static_cast<T>(-0.0006967055197996615),
|
||||
static_cast<T>(-0.010315055591476996),
|
||||
static_cast<T>(-0.05367999021047822),
|
||||
static_cast<T>(-0.16943664192343108),
|
||||
static_cast<T>(-0.42437007298661206),
|
||||
static_cast<T>(-0.9556349519550872),
|
||||
static_cast<T>(-2.0824831112860647),
|
||||
static_cast<T>(-4.688832585616333)};
|
||||
HWY_ALIGN static constexpr T arr_c[8] = {
|
||||
static_cast<T>(0.220551955463595), static_cast<T>(0.5069204289218385),
|
||||
static_cast<T>(0.8423809865207907), static_cast<T>(1.1775629610724903),
|
||||
static_cast<T>(1.4909222917402543), static_cast<T>(1.757582383623199),
|
||||
static_cast<T>(1.9363640518503402), static_cast<T>(1.9985234759675707)};
|
||||
HWY_ALIGN static constexpr T arr_d[8] = {
|
||||
static_cast<T>(-5676.517069241468), static_cast<T>(-363.0662559912978),
|
||||
static_cast<T>(-60.61589604370584), static_cast<T>(-14.306713103378062),
|
||||
static_cast<T>(-2.725237489187118), static_cast<T>(0.7890752292798894),
|
||||
static_cast<T>(1.8089988725725492), static_cast<T>(1.9956027601801545)};
|
||||
static_cast<T>(3.9548607753775276), static_cast<T>(3.7450486139396544),
|
||||
static_cast<T>(3.253860706225495), static_cast<T>(2.4240814251983283),
|
||||
static_cast<T>(1.1565092321921886), static_cast<T>(-0.7540678688218365),
|
||||
static_cast<T>(-3.767209600467866), static_cast<T>(-9.357047249878605)};
|
||||
|
||||
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
|
||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
||||
hn::CappedTag<T, 8> d8;
|
||||
a = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_a)), idx);
|
||||
b = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_b)), idx);
|
||||
c = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_c)), idx);
|
||||
d_coef =
|
||||
hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_d)), idx);
|
||||
} else {
|
||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
||||
hn::FixedTag<T, 4> d4;
|
||||
a = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_a),
|
||||
hn::Load(d4, arr_a + 4), idx);
|
||||
b = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_b),
|
||||
hn::Load(d4, arr_b + 4), idx);
|
||||
c = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_c),
|
||||
hn::Load(d4, arr_c + 4), idx);
|
||||
d_coef = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_d),
|
||||
|
|
@ -150,66 +147,69 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
|||
const auto t6 = hn::Set(d, static_cast<T>(5.271780018997146));
|
||||
|
||||
// Start with highest index (7)
|
||||
a = hn::Set(d, static_cast<T>(-0.2132727031175401));
|
||||
c = hn::Set(d, static_cast<T>(-0.426230503963466));
|
||||
d_coef = hn::Set(d, static_cast<T>(1.9956027601801545));
|
||||
b = hn::Set(d, static_cast<T>(-4.688832585616333));
|
||||
c = hn::Set(d, static_cast<T>(1.9985234759675707));
|
||||
d_coef = hn::Set(d, static_cast<T>(-9.357047249878605));
|
||||
|
||||
// If y < t6 (idx 6)
|
||||
auto mask = hn::Lt(y, t6);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.4801959711368016)),
|
||||
a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.9298342163526662)),
|
||||
c);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-2.0824831112860647)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.9363640518503402)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.8089988725725492)), d_coef);
|
||||
mask, hn::Set(d, static_cast<T>(-3.767209600467866)), d_coef);
|
||||
|
||||
// If y < t5 (idx 5)
|
||||
mask = hn::Lt(y, t5);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1.0464246812594584)),
|
||||
a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1.839177585570145)), c);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.9556349519550872)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.757582383623199)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(0.7890752292798894)), d_coef);
|
||||
mask, hn::Set(d, static_cast<T>(-0.7540678688218365)), d_coef);
|
||||
|
||||
// If y < t4 (idx 4)
|
||||
mask = hn::Lt(y, t4);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-2.356433838423728)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-3.513259738716989)), c);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.42437007298661206)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.4909222917402543)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-2.725237489187118)), d_coef);
|
||||
mask, hn::Set(d, static_cast<T>(1.1565092321921886)), d_coef);
|
||||
|
||||
// If y < t3 (idx 3)
|
||||
mask = hn::Lt(y, t3);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-5.90191111348809)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-6.949871926785674)), c);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.16943664192343108)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.1775629610724903)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-14.306713103378062)), d_coef);
|
||||
mask, hn::Set(d, static_cast<T>(2.4240814251983283)), d_coef);
|
||||
|
||||
// If y < t2 (idx 2)
|
||||
mask = hn::Lt(y, t2);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-18.628915468855695)),
|
||||
a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-15.69264419046708)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-60.61589604370584)), d_coef);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.05367999021047822)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.8423809865207907)), c);
|
||||
d_coef = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(3.253860706225495)),
|
||||
d_coef);
|
||||
|
||||
// If y < t1 (idx 1)
|
||||
mask = hn::Lt(y, t1);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-96.9456723845743)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-49.14374182730444)), c);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.010315055591476996)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.5069204289218385)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-363.0662559912978)), d_coef);
|
||||
mask, hn::Set(d, static_cast<T>(3.7450486139396544)), d_coef);
|
||||
|
||||
// If y < t0 (idx 0)
|
||||
mask = hn::Lt(y, t0);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1435.326650329326)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-316.5640994591445)), c);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.0006967055197996615)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.220551955463595)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-5676.517069241468)), d_coef);
|
||||
mask, hn::Set(d, static_cast<T>(3.9548607753775276)), d_coef);
|
||||
}
|
||||
|
||||
// Math: 0.5 * tanh(y/2) = (ay + 1.0)/(cy + d_coef)
|
||||
auto num = hn::MulAdd(a, y, hn::Set(d, static_cast<T>(1.0)));
|
||||
// Math: 0.5 * tanh(y/2) = (y + b)/(cy + d_coef)
|
||||
auto num = hn::Add(y, b);
|
||||
auto den = hn::MulAdd(c, y, d_coef);
|
||||
|
||||
auto approx = hn::Div(num, den);
|
||||
|
|
|
|||
Loading…
Reference in New Issue