mirror of https://github.com/google/gemma.cpp.git
Use paralell blend chain path in FastSigmoid on architectures having >=32 registers
PiperOrigin-RevId: 886178215
This commit is contained in:
parent
50144738f1
commit
90f3de7f15
|
|
@ -146,66 +146,141 @@ HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
|||
const auto t5 = hn::Set(d, static_cast<T>(3.288402547357102));
|
||||
const auto t6 = hn::Set(d, static_cast<T>(5.271780018997146));
|
||||
|
||||
// Start with highest index (7)
|
||||
b = hn::Set(d, static_cast<T>(-4.688832585616333));
|
||||
c = hn::Set(d, static_cast<T>(1.9985234759675707));
|
||||
d_coef = hn::Set(d, static_cast<T>(-9.357047249878605));
|
||||
if constexpr (HWY_REGISTERS >= 32) {
|
||||
// Split into two parallel chains to reduce dependency latency.
|
||||
|
||||
// If y < t6 (idx 6)
|
||||
auto mask = hn::Lt(y, t6);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-2.0824831112860647)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.9363640518503402)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-3.767209600467866)), d_coef);
|
||||
// -- Chain 1: Indices 0 to 3 (Evaluated starting from t3 down to t0)
|
||||
auto b_low = hn::Set(d, static_cast<T>(-0.16943664192343108)); // idx 3
|
||||
auto c_low = hn::Set(d, static_cast<T>(1.1775629610724903));
|
||||
auto d_low = hn::Set(d, static_cast<T>(2.4240814251983283));
|
||||
|
||||
// If y < t5 (idx 5)
|
||||
mask = hn::Lt(y, t5);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.9556349519550872)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.757582383623199)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.7540678688218365)), d_coef);
|
||||
auto mask = hn::Lt(y, t2);
|
||||
b_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.05367999021047822)), b_low);
|
||||
c_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(0.8423809865207907)), c_low);
|
||||
d_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.253860706225495)), d_low);
|
||||
|
||||
// If y < t4 (idx 4)
|
||||
mask = hn::Lt(y, t4);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.42437007298661206)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.4909222917402543)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.1565092321921886)), d_coef);
|
||||
mask = hn::Lt(y, t1);
|
||||
b_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.010315055591476996)), b_low);
|
||||
c_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(0.5069204289218385)), c_low);
|
||||
d_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.7450486139396544)), d_low);
|
||||
|
||||
// If y < t3 (idx 3)
|
||||
mask = hn::Lt(y, t3);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.16943664192343108)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.1775629610724903)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(2.4240814251983283)), d_coef);
|
||||
mask = hn::Lt(y, t0);
|
||||
b_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.0006967055197996615)), b_low);
|
||||
c_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(0.220551955463595)), c_low);
|
||||
d_low = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.9548607753775276)), d_low);
|
||||
|
||||
// If y < t2 (idx 2)
|
||||
mask = hn::Lt(y, t2);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.05367999021047822)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.8423809865207907)), c);
|
||||
d_coef = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(3.253860706225495)),
|
||||
d_coef);
|
||||
// -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4)
|
||||
auto b_high = hn::Set(d, static_cast<T>(-4.688832585616333)); // idx 7
|
||||
auto c_high = hn::Set(d, static_cast<T>(1.9985234759675707));
|
||||
auto d_high = hn::Set(d, static_cast<T>(-9.357047249878605));
|
||||
|
||||
// If y < t1 (idx 1)
|
||||
mask = hn::Lt(y, t1);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.010315055591476996)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.5069204289218385)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.7450486139396544)), d_coef);
|
||||
mask = hn::Lt(y, t6);
|
||||
b_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-2.0824831112860647)), b_high);
|
||||
c_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.9363640518503402)), c_high);
|
||||
d_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-3.767209600467866)), d_high);
|
||||
|
||||
// If y < t0 (idx 0)
|
||||
mask = hn::Lt(y, t0);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.0006967055197996615)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.220551955463595)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.9548607753775276)), d_coef);
|
||||
mask = hn::Lt(y, t5);
|
||||
b_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.9556349519550872)), b_high);
|
||||
c_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.757582383623199)), c_high);
|
||||
d_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.7540678688218365)), d_high);
|
||||
|
||||
mask = hn::Lt(y, t4);
|
||||
b_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.42437007298661206)), b_high);
|
||||
c_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.4909222917402543)), c_high);
|
||||
d_high = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.1565092321921886)), d_high);
|
||||
|
||||
// -- Merge the two chains
|
||||
auto merge_mask = hn::Lt(y, t3);
|
||||
b = hn::IfThenElse(merge_mask, b_low, b_high);
|
||||
c = hn::IfThenElse(merge_mask, c_low, c_high);
|
||||
d_coef = hn::IfThenElse(merge_mask, d_low, d_high);
|
||||
} else {
|
||||
// Start with highest index (7)
|
||||
b = hn::Set(d, static_cast<T>(-4.688832585616333));
|
||||
c = hn::Set(d, static_cast<T>(1.9985234759675707));
|
||||
d_coef = hn::Set(d, static_cast<T>(-9.357047249878605));
|
||||
|
||||
// If y < t6 (idx 6)
|
||||
auto mask = hn::Lt(y, t6);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-2.0824831112860647)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.9363640518503402)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-3.767209600467866)), d_coef);
|
||||
|
||||
// If y < t5 (idx 5)
|
||||
mask = hn::Lt(y, t5);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.9556349519550872)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.757582383623199)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-0.7540678688218365)), d_coef);
|
||||
|
||||
// If y < t4 (idx 4)
|
||||
mask = hn::Lt(y, t4);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.42437007298661206)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.4909222917402543)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.1565092321921886)), d_coef);
|
||||
|
||||
// If y < t3 (idx 3)
|
||||
mask = hn::Lt(y, t3);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.16943664192343108)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(1.1775629610724903)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(2.4240814251983283)), d_coef);
|
||||
|
||||
// If y < t2 (idx 2)
|
||||
mask = hn::Lt(y, t2);
|
||||
b = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.05367999021047822)),
|
||||
b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.8423809865207907)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.253860706225495)), d_coef);
|
||||
|
||||
// If y < t1 (idx 1)
|
||||
mask = hn::Lt(y, t1);
|
||||
b = hn::IfThenElse(mask,
|
||||
hn::Set(d, static_cast<T>(-0.010315055591476996)), b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.5069204289218385)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.7450486139396544)), d_coef);
|
||||
|
||||
// If y < t0 (idx 0)
|
||||
mask = hn::Lt(y, t0);
|
||||
b = hn::IfThenElse(mask,
|
||||
hn::Set(d, static_cast<T>(-0.0006967055197996615)), b);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(0.220551955463595)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(3.9548607753775276)), d_coef);
|
||||
}
|
||||
}
|
||||
|
||||
// Math: 0.5 * tanh(y/2) = (y + b)/(cy + d_coef)
|
||||
|
|
|
|||
Loading…
Reference in New Issue