From d15731d2019fef5c77a392027c99f91d486665d7 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 25 Sep 2025 09:41:30 -0700 Subject: [PATCH] Used hn::BroadcastLane instead of Set(..., x.raw) PiperOrigin-RevId: 811386295 --- ops/ops-inl.h | 215 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 129 insertions(+), 86 deletions(-) diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 1593aa4..ec73f66 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -621,52 +621,116 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( }); } -template > +template , HWY_IF_V_SIZE_GT_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, + VF& sum5, VF& sum6, VF& sum7, VF& sum8, + VF& sum9, VF& sum10, VF& sum11, + VF& sum12, VF& sum13, VF& sum14, + VF& sum15) { + sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); + sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); + sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); + sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); + sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); + sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); + sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); + sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); + sum8 = hn::Mul(sum8, hn::BroadcastLane<8>(scale)); + sum9 = hn::Mul(sum9, hn::BroadcastLane<9>(scale)); + sum10 = hn::Mul(sum10, hn::BroadcastLane<10>(scale)); + sum11 = hn::Mul(sum11, hn::BroadcastLane<11>(scale)); + sum12 = hn::Mul(sum12, hn::BroadcastLane<12>(scale)); + sum13 = hn::Mul(sum13, hn::BroadcastLane<13>(scale)); + sum14 = hn::Mul(sum14, hn::BroadcastLane<14>(scale)); + sum15 = hn::Mul(sum15, hn::BroadcastLane<15>(scale)); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul16(DF df, const VF scale, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, + VF& sum5, VF& sum6, VF& sum7, VF& sum8, + VF& sum9, VF& sum10, VF& sum11, + VF& sum12, VF& sum13, VF& sum14, + VF& sum15) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, + VF& sum6, VF& sum7) { + sum0 = hn::Mul(sum0, hn::BroadcastLane<0>(scale)); + sum1 = hn::Mul(sum1, hn::BroadcastLane<1>(scale)); + sum2 = hn::Mul(sum2, hn::BroadcastLane<2>(scale)); + sum3 = hn::Mul(sum3, hn::BroadcastLane<3>(scale)); + sum4 = hn::Mul(sum4, hn::BroadcastLane<4>(scale)); + sum5 = hn::Mul(sum5, hn::BroadcastLane<5>(scale)); + sum6 = hn::Mul(sum6, hn::BroadcastLane<6>(scale)); + sum7 = hn::Mul(sum7, hn::BroadcastLane<7>(scale)); +} + +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void Mul8(DF df, const VF scale, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, + VF& sum6, VF& sum7) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 63)> HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) { - sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); - sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); - sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); - sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); - sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); - sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); - sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); - sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); - sum8 = hn::MulAdd(common, hn::Set(df, split.raw[8]), sum8); - sum9 = hn::MulAdd(common, hn::Set(df, split.raw[9]), sum9); - sum10 = hn::MulAdd(common, hn::Set(df, split.raw[10]), sum10); - sum11 = hn::MulAdd(common, hn::Set(df, split.raw[11]), sum11); - sum12 = hn::MulAdd(common, hn::Set(df, split.raw[12]), sum12); - sum13 = hn::MulAdd(common, hn::Set(df, split.raw[13]), sum13); - sum14 = hn::MulAdd(common, hn::Set(df, split.raw[14]), sum14); - sum15 = hn::MulAdd(common, hn::Set(df, split.raw[15]), sum15); + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); + sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); + sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); + sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); + sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); + sum8 = hn::MulAdd(common, hn::BroadcastLane<8>(split), sum8); + sum9 = hn::MulAdd(common, hn::BroadcastLane<9>(split), sum9); + sum10 = hn::MulAdd(common, hn::BroadcastLane<10>(split), sum10); + sum11 = hn::MulAdd(common, hn::BroadcastLane<11>(split), sum11); + sum12 = hn::MulAdd(common, hn::BroadcastLane<12>(split), sum12); + sum13 = hn::MulAdd(common, hn::BroadcastLane<13>(split), sum13); + sum14 = hn::MulAdd(common, hn::BroadcastLane<14>(split), sum14); + sum15 = hn::MulAdd(common, hn::BroadcastLane<15>(split), sum15); } -template > +template , HWY_IF_V_SIZE_LE_D(DF, 63)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd16( + DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, + VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7, VF& sum8, VF& sum9, + VF& sum10, VF& sum11, VF& sum12, VF& sum13, VF& sum14, VF& sum15) {} + +template , HWY_IF_V_SIZE_GT_D(DF, 31)> HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { - sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); - sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); - sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); - sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); - sum4 = hn::MulAdd(common, hn::Set(df, split.raw[4]), sum4); - sum5 = hn::MulAdd(common, hn::Set(df, split.raw[5]), sum5); - sum6 = hn::MulAdd(common, hn::Set(df, split.raw[6]), sum6); - sum7 = hn::MulAdd(common, hn::Set(df, split.raw[7]), sum7); + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); + sum4 = hn::MulAdd(common, hn::BroadcastLane<4>(split), sum4); + sum5 = hn::MulAdd(common, hn::BroadcastLane<5>(split), sum5); + sum6 = hn::MulAdd(common, hn::BroadcastLane<6>(split), sum6); + sum7 = hn::MulAdd(common, hn::BroadcastLane<7>(split), sum7); } +template , HWY_IF_V_SIZE_LE_D(DF, 31)> +HWY_INLINE HWY_MAYBE_UNUSED void MulAdd8(DF df, const VF common, const VF split, + VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& sum4, VF& sum5, VF& sum6, + VF& sum7) {} + template > HWY_INLINE HWY_MAYBE_UNUSED void MulAdd4(DF df, const VF common, const VF split, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { - sum0 = hn::MulAdd(common, hn::Set(df, split.raw[0]), sum0); - sum1 = hn::MulAdd(common, hn::Set(df, split.raw[1]), sum1); - sum2 = hn::MulAdd(common, hn::Set(df, split.raw[2]), sum2); - sum3 = hn::MulAdd(common, hn::Set(df, split.raw[3]), sum3); + sum0 = hn::MulAdd(common, hn::BroadcastLane<0>(split), sum0); + sum1 = hn::MulAdd(common, hn::BroadcastLane<1>(split), sum1); + sum2 = hn::MulAdd(common, hn::BroadcastLane<2>(split), sum2); + sum3 = hn::MulAdd(common, hn::BroadcastLane<3>(split), sum3); } // For an 8xNF tile of float values in 8xNF-lane registers, multiplies 8 rows @@ -706,22 +770,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( out13 = hn::Load(df, out + i + out_offsets[13]); out14 = hn::Load(df, out + i + out_offsets[14]); out15 = hn::Load(df, out + i + out_offsets[15]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); - out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); - out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); - out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); - out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); - out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); - out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); - out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); - out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + Mul16(df, scale, out0, out1, out2, out3, out4, out5, out6, out7, out8, + out9, out10, out11, out12, out13, out14, out15); VF x0 = hn::Load(df, v.Row(pos[0]) + i); MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11, out12, out13, out14, out15); @@ -773,14 +823,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( out5 = hn::Load(df, out + i + out_offsets[5]); out6 = hn::Load(df, out + i + out_offsets[6]); out7 = hn::Load(df, out + i + out_offsets[7]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + Mul8(df, scale, out0, out1, out2, out3, out4, out5, out6, out7); VF x0 = hn::Load(df, v.Row(pos[0]) + i); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); VF x1 = hn::Load(df, v.Row(pos[1]) + i); @@ -812,10 +855,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( out1 = hn::Load(df, out + i + out_offsets[1]); out2 = hn::Load(df, out + i + out_offsets[2]); out3 = hn::Load(df, out + i + out_offsets[3]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); VF x0 = hn::Load(df, v.Row(pos[0]) + i); MulAdd4(df, x0, c0, out0, out1, out2, out3); VF x1 = hn::Load(df, v.Row(pos[1]) + i); @@ -878,22 +921,22 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out13 = hn::Load(df, out + i + out_offsets[13]); out14 = hn::Load(df, out + i + out_offsets[14]); out15 = hn::Load(df, out + i + out_offsets[15]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); - out8 = hn::Mul(out8, hn::Set(df, scale.raw[8])); - out9 = hn::Mul(out9, hn::Set(df, scale.raw[9])); - out10 = hn::Mul(out10, hn::Set(df, scale.raw[10])); - out11 = hn::Mul(out11, hn::Set(df, scale.raw[11])); - out12 = hn::Mul(out12, hn::Set(df, scale.raw[12])); - out13 = hn::Mul(out13, hn::Set(df, scale.raw[13])); - out14 = hn::Mul(out14, hn::Set(df, scale.raw[14])); - out15 = hn::Mul(out15, hn::Set(df, scale.raw[15])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); + out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); + out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); + out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); + out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); + out8 = hn::Mul(out8, hn::BroadcastLane<8>(scale)); + out9 = hn::Mul(out9, hn::BroadcastLane<9>(scale)); + out10 = hn::Mul(out10, hn::BroadcastLane<10>(scale)); + out11 = hn::Mul(out11, hn::BroadcastLane<11>(scale)); + out12 = hn::Mul(out12, hn::BroadcastLane<12>(scale)); + out13 = hn::Mul(out13, hn::BroadcastLane<13>(scale)); + out14 = hn::Mul(out14, hn::BroadcastLane<14>(scale)); + out15 = hn::Mul(out15, hn::BroadcastLane<15>(scale)); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd16(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11, out12, out13, out14, out15); @@ -923,14 +966,14 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out5 = hn::Load(df, out + i + out_offsets[5]); out6 = hn::Load(df, out + i + out_offsets[6]); out7 = hn::Load(df, out + i + out_offsets[7]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); - out4 = hn::Mul(out4, hn::Set(df, scale.raw[4])); - out5 = hn::Mul(out5, hn::Set(df, scale.raw[5])); - out6 = hn::Mul(out6, hn::Set(df, scale.raw[6])); - out7 = hn::Mul(out7, hn::Set(df, scale.raw[7])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); + out4 = hn::Mul(out4, hn::BroadcastLane<4>(scale)); + out5 = hn::Mul(out5, hn::BroadcastLane<5>(scale)); + out6 = hn::Mul(out6, hn::BroadcastLane<6>(scale)); + out7 = hn::Mul(out7, hn::BroadcastLane<7>(scale)); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd8(df, x0, c0, out0, out1, out2, out3, out4, out5, out6, out7); hn::Store(out0, df, out + i + out_offsets[0]); @@ -947,10 +990,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( out1 = hn::Load(df, out + i + out_offsets[1]); out2 = hn::Load(df, out + i + out_offsets[2]); out3 = hn::Load(df, out + i + out_offsets[3]); - out0 = hn::Mul(out0, hn::Set(df, scale.raw[0])); - out1 = hn::Mul(out1, hn::Set(df, scale.raw[1])); - out2 = hn::Mul(out2, hn::Set(df, scale.raw[2])); - out3 = hn::Mul(out3, hn::Set(df, scale.raw[3])); + out0 = hn::Mul(out0, hn::BroadcastLane<0>(scale)); + out1 = hn::Mul(out1, hn::BroadcastLane<1>(scale)); + out2 = hn::Mul(out2, hn::BroadcastLane<2>(scale)); + out3 = hn::Mul(out3, hn::BroadcastLane<3>(scale)); VF x0 = hn::Load(df, v.Row(pos) + i); MulAdd4(df, x0, c0, out0, out1, out2, out3); hn::Store(out0, df, out + i + out_offsets[0]);