diff --git a/gemma/ops.h b/gemma/ops.h
index 8bb64f2..318af52 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -141,12 +141,55 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
   tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
 }
 
+// Completes the tile by summing across the vectors, and adds the biases.
+template <size_t kNumRows, class DF, class VF = hn::Vec<DF>, typename AddT>
+HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
+                                       VF c10, VF c11, VF c12, VF c13,  //
+                                       VF c20, VF c21, VF c22, VF c23,  //
+                                       VF c30, VF c31, VF c32, VF c33,
+                                       const AddT* add,
+                                       float* HWY_RESTRICT tile_c,
+                                       size_t stride_c) {
+  // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
+  // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
+  // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
+  // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
+  float addon0 = hwy::ConvertScalarTo<float>(add[0]);
+  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00) + addon0;
+  float addon1 = hwy::ConvertScalarTo<float>(add[1]);
+  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01) + addon1;
+  float addon2 = hwy::ConvertScalarTo<float>(add[2]);
+  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02) + addon2;
+  float addon3 = hwy::ConvertScalarTo<float>(add[3]);
+  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03) + addon3;
+  if (kNumRows == 1) return;
+
+  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10) + addon0;
+  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11) + addon1;
+  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12) + addon2;
+  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13) + addon3;
+  if (kNumRows == 2) return;
+
+  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20) + addon0;
+  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21) + addon1;
+  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22) + addon2;
+  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23) + addon3;
+  if (kNumRows == 3) return;
+
+  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30) + addon0;
+  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31) + addon1;
+  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32) + addon2;
+  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
+}
+
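For reference, here is the new helper's effect on one output tile as a scalar sketch (annotation only, not part of the patch; `acc[r][c]` and `HorizontalSum` stand in for the `c$r$c` accumulator vectors and `hn::ReduceSum`, and `bias[c]` for the float-converted `add[c]`):

    // One kNumRows x 4 tile: each C entry is the horizontal sum of its
    // accumulator vector plus the bias of its output column. The same four
    // biases are reused by every row, which is why the helper converts them
    // once before storing the first row.
    for (size_t r = 0; r < kNumRows; ++r) {
      for (size_t c = 0; c < 4; ++c) {
        tile_c[stride_c * r + c] = HorizontalSum(acc[r][c]) + bias[c];
      }
    }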
 // Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
 // can iterate over both A and B with consecutive vector loads. kNumRows<=4.
 // Shared between parallelized and sequential (loop) callers.
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatT,
-          HWY_IF_F32(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
+          HWY_IF_F32(MatT), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B,
                               MatT* HWY_RESTRICT C,
+                              const AddT* add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -203,21 +246,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -227,9 +270,16 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-      c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 #undef GEMMA_NATIVE_BF16
@@ -241,10 +291,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
 // As above, for MatT=bf16
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatT,
-          HWY_IF_BF16(MatT)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatT,
+          HWY_IF_BF16(MatT), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                               const MatT* HWY_RESTRICT B,
                               float* HWY_RESTRICT C,
+                              const AddT* add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -311,21 +362,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
@@ -348,7 +399,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const VF a1 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab));
@@ -356,7 +407,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const VF a2 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab));
@@ -364,7 +415,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const VF a3 =
         hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab));
@@ -381,17 +432,27 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 #endif
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-      c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Same as above, but with mixed Mat types: (f32, compressed).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
+          typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C,
+                              const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -450,21 +511,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -474,17 +535,27 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
-      c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
-// Same as above, but with mixed Mat types: (bf16, compressed)).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
+// Same as above, but with mixed Mat types: (bf16, compressed).
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
+          typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C,
+                              const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -549,7 +620,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const V a1 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
@@ -557,7 +628,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const V a2 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
@@ -565,7 +636,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const V a3 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
@@ -576,18 +647,26 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Same as above, but with mixed Mat types: (f32, bf16).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_BF16(MatTB)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_F32(MatTA),
+          typename MatTB, HWY_IF_BF16(MatTB), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C, const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -651,21 +730,21 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab);
     c10 = hn::MulAdd(a1, b0, c10);
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab);
     c20 = hn::MulAdd(a2, b0, c20);
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab);
     c30 = hn::MulAdd(a3, b0, c30);
@@ -675,17 +754,25 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(
-      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21,
-      c22, c23, c30, c31, c32, c33, tile_c, stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
 // Same as above, but with mixed Mat types: (bf16, f32).
-template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB)>
+template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
+          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_F32(MatTB),
+          typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              float* HWY_RESTRICT C, const AddT* add,
+                              const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
@@ -746,7 +833,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c01 = hn::MulAdd(a0, b1, c01);
     c02 = hn::MulAdd(a0, b2, c02);
     c03 = hn::MulAdd(a0, b3, c03);
-    if (kNumRows == 1) continue;
+    if constexpr (kNumRows == 1) continue;
 
     const VF a1 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
@@ -754,7 +841,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c11 = hn::MulAdd(a1, b1, c11);
     c12 = hn::MulAdd(a1, b2, c12);
     c13 = hn::MulAdd(a1, b3, c13);
-    if (kNumRows == 2) continue;
+    if constexpr (kNumRows == 2) continue;
 
     const VF a2 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
@@ -762,7 +849,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
     c21 = hn::MulAdd(a2, b1, c21);
     c22 = hn::MulAdd(a2, b2, c22);
     c23 = hn::MulAdd(a2, b3, c23);
-    if (kNumRows == 3) continue;
+    if constexpr (kNumRows == 3) continue;
 
     const VF a3 =
         hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
@@ -773,19 +860,27 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
-                                c20, c21, c22, c23, c30, c31, c32, c33, tile_c,
-                                stride_c);
+  if constexpr (kAdd) {
+    const AddT* tile_add = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+  } else {
+    StoreHorizontalSums<kNumRows>(
+        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
+        c30, c31, c32, c33, tile_c, stride_c);
+  }
 }
 
+
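Every GEMM_4x4_Tile overload above ends with the same compile-time dispatch on kAdd. A self-contained sketch of the pattern (illustrative names, not code from the patch):

    #include <cstddef>

    // kAdd is a template parameter, so `if constexpr` discards the unused
    // branch at compile time: the no-bias instantiation keeps its original
    // code path and pays no per-element check.
    template <bool kAdd>
    void StoreEntry(float sum, const float* add, size_t c, float* out) {
      if constexpr (kAdd) {
        out[c] = sum + add[c];  // fused bias add in the single pass over C
      } else {
        out[c] = sum;  // `add` is never read, so passing nullptr is safe
      }
    }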
 // Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
 // and kColsBC is 24k or 3k. Note: B is transposed (column-major).
 // NOTE that batch_size is the number of rows of A and C.
 // This function processes tiles in parallel with a work-stealing thread pool.
-template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA, typename MatTB,
-          typename OutT>
-HWY_NOINLINE void MatMul_4x4_Batch(
+template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
+          typename MatTB, typename OutT, typename AddT>
+HWY_NOINLINE void MatMul_4x4_Batch_Add(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+    OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
   // so that we finish one L3-sized piece of C at a time.
@@ -810,51 +905,31 @@ HWY_NOINLINE void MatMul_4x4_Batch(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
         break;
      case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
         break;
       default:
-        GEMM_4x4_Tile<4, kColsA_RowsB>(
-            A, B, C, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(
+            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
     }
   });
 }
 
-// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
-// ops_test across instruction sets.
-template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
-HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
-                           const MatTB* HWY_RESTRICT b,
-                           float* HWY_RESTRICT out) {
-  for (size_t i = 0; i < kM; ++i) {
-    for (size_t k = 0; k < kN; ++k) {
-      for (size_t j = 0; j < kK; ++j) {
-        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
-        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
-        out[i * kK + j] += a1 * b1;
-      }
-    }
-  }
-}
-
-template <size_t kM, size_t kN, size_t kK, typename MatTA>
-HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
-                           const SfpStream* HWY_RESTRICT b_sfp_stream,
-                           float* HWY_RESTRICT out) {
-  const hn::ScalableTag<float> d;
-  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<SfpStream>::Decompress(d,
-                                        /*in_capacity=*/0, b_sfp_stream, 0,
-                                        b.get(), kK * kN);
-  MatMulSlow<kM, kN, kK>(a, b.get(), out);
+template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA, typename MatTB,
+          typename OutT>
+HWY_NOINLINE void MatMul_4x4_Batch(
+    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
+    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+  MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/false>(
+      batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
 }
 
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
@@ -863,6 +938,7 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b,
+                                const float* add,
                                 float* HWY_RESTRICT out) {
   for (size_t i = 0; i < batch_size; ++i) {
     for (size_t k = 0; k < kN; ++k) {
@@ -872,6 +948,11 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
         out[i * kK + j] += a1 * b1;
       }
     }
+    if (add != nullptr) {
+      for (size_t j = 0; j < kK; ++j) {
+        out[i * kK + j] += add[j];
+      }
+    }
   }
 }
 
@@ -881,13 +962,13 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b_compr,
+                                const float* add,
                                 float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> d;
   hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<MatTB>::Decompress(d,
-                                    /*in_capacity=*/0, b_compr, 0, b.get(),
+  CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
                                     kK * kN);
-  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
 }
 
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
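A hypothetical call site for the new entry point (shapes and buffer names are made up for illustration; the template arguments follow the signature above):

    // C (32 x 512) = A (32 x 128) * B^T + bias, with the 512-entry bias row
    // broadcast to every row of C. b_trans is B transposed (column-major),
    // as MatMul_4x4_Batch already required; all buffers are float here.
    hwy::ThreadPool pool(4);
    MatMul_4x4_Batch_Add<128, 512, /*kAdd=*/true>(
        /*batch_size=*/32, a, b_trans, c, /*add=*/bias, pool);

    // Existing callers are unchanged: the wrapper forwards kAdd=false and a
    // null `add` that the discarded branch never reads.
    MatMul_4x4_Batch<128, 512>(/*batch_size=*/32, a, b_trans, c, pool);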
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index fe48833..b6b3405 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -13,6 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <stddef.h>
 #include <stdio.h>
 
 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS HWY_SCALAR
@@ -506,68 +507,80 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
   }
 }
 
-template <size_t kM, size_t kN, size_t kK, typename MatTA,
-          typename MatTB = MatTA>
+template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
+          typename MatTB = MatTA>
 void TestTiledBatchMatMul() {
-  fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu", kM, kN, kK);
+  fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu, add=%d ", kM, kN, kK,
+          kAdd);
   hwy::ThreadPool pool(3);
   std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
       GenerateMatHeap<MatTA, kM, kN>(0, pool);
   std::unique_ptr<CompressedArray<MatTB, kN * kK>> b =
       GenerateMatHeap<MatTB, kN, kK>(0, pool);
+  std::unique_ptr<CompressedArray<float, kK>> add =
+      GenerateMatHeap<float, 1, kK>(0, pool);
   std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
       GenerateZeroMatHeap<float, kM, kK>(pool);
-  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), c_slow->data());
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
+                          kAdd ? add->data() : nullptr, c_slow->data());
 
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
 
   std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
       GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
 
-  MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+  if (kAdd) {
+    MatMul_4x4_Batch_Add<kN, kK, /*kAdd=*/true>(
+        kM, a->data(), b_trans->data(), c.get(), add->data(), pool);
+  } else {
+    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+  }
 
   AssertClose(c_slow->data(), c.get(), kM * kK);
 }
 
 void TestAllTiledBatchMatMul() {
   // medium-sized square test
-  TestTiledBatchMatMul<512, 512, 512, float>();
-  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<512, 512, 512, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<512, 512, 512, float, SfpStream>();
-  TestTiledBatchMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t,
+                       SfpStream>();
 
   // minimal non-square test
-  TestTiledBatchMatMul<35, 128, 32, float>();
-  TestTiledBatchMatMul<34, 128, 32, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<33, 128, 32, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<33, 128, 32, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<31, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<29, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<4, 128, 8, float>();
-  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<4, 128, 8, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<4, 128, 8, float, SfpStream>();
-  TestTiledBatchMatMul<4, 128, 8, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<3, 128, 32, float>();
-  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<3, 128, 32, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<3, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<3, 128, 32, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<2, 128, 16, float>();
-  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<2, 128, 16, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<2, 128, 16, float, SfpStream>();
-  TestTiledBatchMatMul<2, 128, 16, hwy::bfloat16_t, SfpStream>();
-  TestTiledBatchMatMul<1, 128, 32, float>();
-  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<1, 128, 32, float, hwy::bfloat16_t>();
-  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, float>();
-  TestTiledBatchMatMul<1, 128, 32, float, SfpStream>();
-  TestTiledBatchMatMul<1, 128, 32, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t,
+                       SfpStream>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>();
+  TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>();
+  TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t,
+                       SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>();
+  TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
 
   // large-scale test
   // TODO(philculliton): investigate rounding issues with large matrices.
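The semantics the tests check, as a standalone scalar reference (a minimal sketch mirroring MatMulSlowBatch, not code from the patch):

    #include <cstddef>

    // out[i*kK + j] = sum over k of a[i*kN + k] * b[k*kK + j], plus add[j]
    // when a bias is given. a is kM x kN and b is kN x kK, both row-major;
    // add has kK entries, one per output column.
    void ReferenceMatMulAdd(size_t kM, size_t kN, size_t kK, const float* a,
                            const float* b, const float* add, float* out) {
      for (size_t i = 0; i < kM; ++i) {
        for (size_t j = 0; j < kK; ++j) {
          float sum = 0.0f;
          for (size_t k = 0; k < kN; ++k) {
            sum += a[i * kN + k] * b[k * kK + j];
          }
          out[i * kK + j] = sum + (add != nullptr ? add[j] : 0.0f);
        }
      }
    }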