From 3bb2fcc8567e139a0ef70b8d43f82a3130147c00 Mon Sep 17 00:00:00 2001
From: shalinib-ibm
Date: Thu, 19 Feb 2026 11:58:53 +0530
Subject: [PATCH] llamafile: powerpc: add FP16 MMA path for Q4/Q8 matmul (#19709)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Avoid the signed→unsigned bias correction required by xvi8ger4pp by
dequantizing Q4/Q8 inputs to FP16 and using FP16×FP16→FP32 MMA. This
removes the accumulator post-processing overhead and improves
performance.

Performance impact: 1.5x to 2x improvement in prompt-processing speed
(PP_Speed) for Q4 and Q8 models, measured with llama-bench and
llama-batched-bench.

Q8 model: granite-4.0-h-micro-Q8_0.gguf (from Hugging Face)
Q4 model: Meta-Llama3-8b Q4 (generated with llama-quantize from the f32 model)

llama-bench Q8 model results:

| model                 |     size | params | backend | threads |  test |      Base t/s |     Patch t/s |
| --------------------- | -------: | -----: | ------- | ------: | ----: | ------------: | ------------: |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 |   pp8 |  64.48 ± 4.72 |  73.99 ± 0.27 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 |  pp16 |  80.11 ± 0.32 | 112.53 ± 0.40 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 |  pp32 |  89.10 ± 0.27 | 152.95 ± 0.68 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 |  pp64 |  93.65 ± 0.25 | 187.83 ± 0.83 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 | pp128 |  99.93 ± 0.02 | 201.32 ± 0.11 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 | pp256 | 102.32 ± 0.40 | 208.32 ± 0.41 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 | pp512 | 103.42 ± 0.40 | 209.98 ± 0.14 |
| granitehybrid 3B Q8_0 | 3.16 GiB | 3.19 B | CPU     |      10 | tg128 |  20.35 ± 0.01 |  19.57 ± 0.01 |

llama-bench Q4 model results:

| model         |     size | params | backend | threads |  test |     Base t/s |     Patch t/s |
| ------------- | -------: | -----: | ------- | ------: | ----: | -----------: | ------------: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 |   pp8 | 34.77 ± 0.10 |  41.23 ± 0.08 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 |  pp16 | 40.81 ± 0.04 |  64.55 ± 0.15 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 |  pp32 | 44.65 ± 0.05 |  90.84 ± 0.22 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 |  pp64 | 47.49 ± 0.03 | 114.39 ± 0.11 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 | pp128 | 49.29 ± 0.24 | 120.13 ± 0.19 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 | pp256 | 49.77 ± 0.23 | 121.51 ± 0.11 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 | pp512 | 49.89 ± 0.23 | 117.52 ± 0.10 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CPU     |      10 | tg128 | 13.40 ± 0.01 |  13.37 ± 0.00 |

Llama perplexity results:

| model                    | Base final PPL estimate | Patch final PPL estimate |
| ------------------------ | ----------------------- | ------------------------ |
| granite-4.0-h-micro-Q8_0 | 1.3862 +/- 0.04424      | 1.3868 +/- 0.04432       |
| Meta-Llama3-8b Q4        | 1.3801 +/- 0.04116      | 1.3803 +/- 0.04116       |

Signed-off-by: Shalini.Salomi.Bodapati
---
 ggml/src/ggml-cpu/llamafile/sgemm-ppc.h | 333 ------
 ggml/src/ggml-cpu/llamafile/sgemm.cpp   | 662 ++++++++++++++++++------
 2 files changed, 507 insertions(+), 488 deletions(-)
 delete mode 100644 ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
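Reviewer note (ignored by git-am, not part of the diff): the standalone
sketch below illustrates the dequantize-to-FP16 + xvf16ger2pp idea in
isolation. It mirrors the patch's convert_and_scale_q8 helper, but the
names dequant_q8_to_fp16 and fp16_mma_step are hypothetical, and it
assumes a POWER10 toolchain (-mcpu=power10, MMA enabled) and the
block_q8_0 layout from ggml.

#include <altivec.h>
#include <stdint.h>

#define QK8_0 32
typedef uint16_t ggml_half;                   // raw IEEE fp16 bits, as in ggml
struct block_q8_0 { ggml_half d; int8_t qs[QK8_0]; };

typedef vector unsigned char vec_t;

// Dequantize the first 16 quants of a q8_0 block into two vectors of eight
// scaled fp16 values each. The per-block scale d is folded in here, so the
// MMA accumulator needs no bias fix-up or scale pass afterwards.
static inline void dequant_q8_to_fp16(const struct block_q8_0 *b,
                                      vector unsigned short out[2]) {
    vector float d = vec_extract_fp32_from_shorth(vec_splats(b->d));
    vector signed char q = vec_xl(0, b->qs);
    vector signed short hi = vec_unpackh(q);  // quants 0..7  as int16
    vector signed short lo = vec_unpackl(q);  // quants 8..15 as int16
    vector float f0 = vec_ctf(vec_unpackh(hi), 0);
    vector float f1 = vec_ctf(vec_unpackl(hi), 0);
    vector float f2 = vec_ctf(vec_unpackh(lo), 0);
    vector float f3 = vec_ctf(vec_unpackl(lo), 0);
    out[0] = vec_pack_to_short_fp32(vec_mul(f0, d), vec_mul(f1, d));
    out[1] = vec_pack_to_short_fp32(vec_mul(f2, d), vec_mul(f3, d));
}

// One FP16 MMA step: xvf16ger2pp accumulates a 4x4 fp32 tile from two
// vectors holding 8 fp16 values apiece.
static inline void fp16_mma_step(vector unsigned short a,
                                 vector unsigned short b,
                                 __vector_quad *acc) {
    __builtin_mma_xvf16ger2pp(acc, (vec_t)a, (vec_t)b);
}

Because the scales are baked into the fp16 operands during packing, the
accumulator read back with __builtin_mma_disassemble_acc is already in
final fp32 scale. The old xvi8ger4pp path instead XORed one operand with
0x80 and then had to subtract 128 times the sum of the other operand's
quants per tile (the comparray/compute machinery removed below).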
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
deleted file mode 100644
index a707868728..0000000000
--- a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
+++ /dev/null
@@ -1,333 +0,0 @@
-#pragma once
-
-typedef vector unsigned char vec_t;
-typedef __vector_quad acc_t;
-
-template <typename TA>
-class tinyBLAS_Q0_PPC {
-  public:
-    tinyBLAS_Q0_PPC(int64_t k,
-                    const TA *A, int64_t lda,
-                    const block_q8_0 *B, int64_t ldb,
-                    float *C, int64_t ldc,
-                    int ith, int nth);
-
-    void matmul(int64_t m, int64_t n);
-    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
-        vec_t A_pack[mc*kc*2];
-        vec_t B_pack[nc*kc*2];
-        int comparray[mc*kc];
-        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
-        int64_t ytiles = m / mc;
-        int64_t xtiles = n / nc;
-        int64_t tiles = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles) {
-            end = tiles;
-        }
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = (job / xtiles) * mc;
-            int64_t jj = (job % xtiles) * nc;
-            for (int64_t kk = 0; kk < k; kk += kc) {
-                if constexpr(is_Ablock_q4) {
-                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
-                } else {
-                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
-                }
-                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
-                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
-            }
-        }
-    }
-
-  private:
-    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
-            }
-        }
-    }
-
-    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
-                *c_ptr += *((float*)&fin_res[idx+I]+J);
-            }
-        }
-    }
-
-    template <typename ArrayType>
-    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
-        vector signed int vec_C[4];
-        vector float CA[4] = {0};
-        vector float res[4] = {0};
-        __builtin_mma_disassemble_acc(vec_C, ACC);
-        for (int i = 0; i < 4; i++) {
-            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
-            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
-            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
-        }
-    }
-
-    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
-        const vector signed char lowMask = vec_splats((signed char)0xF);
-        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
-        const vector signed char v8 = vec_splats((signed char)0x8);
-        vector signed int vsum = {0};
-        vector signed int vsum2 = {0};
-        c[0] = vec_and(c[1], lowMask);
-        c[1] = vec_sr(c[1], v4);
-        c[0] = vec_sub(c[0], v8);
-        c[1] = vec_sub(c[1], v8);
-        vsum = vec_sum4s(c[0], vsum);
-        vsum2 = vec_sum4s(c[1], vsum2);
-        vsum = vec_add(vsum, vsum2);
-        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-    }
-
-    template <typename V1, typename V2>
-    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
-        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
-        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
-        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
-        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-        V2 t1, t2, t3, t4, t5, t6, t7, t8;
-        vector unsigned char xor_vector;
-        uint8_t flip_vec = 0x80;
-        xor_vector = vec_splats(flip_vec);
-        t1 = vec_perm(s1, s2, swiz1);
-        t2 = vec_perm(s1, s2, swiz2);
-        t3 = vec_perm(s3, s4, swiz1);
-        t4 = vec_perm(s3, s4, swiz2);
-        t5 = vec_perm(t1, t3, swiz3);
-        t6 = vec_perm(t1, t3, swiz4);
-        t7 = vec_perm(t2, t4, swiz3);
-        t8 = vec_perm(t2, t4, swiz4);
-        if (flip == true) {
-            t5 = vec_xor(t5, xor_vector);
-            t6 = vec_xor(t6, xor_vector);
-            t7 = vec_xor(t7, xor_vector);
-            t8 = vec_xor(t8, xor_vector);
-        }
-        vec_xst(t5, 0, vecOffset);
-        vec_xst(t6, 0, vecOffset+16);
-        vec_xst(t7, 0, vecOffset+32);
-        vec_xst(t8, 0, vecOffset+48);
-    }
-
-    template <int RM, int RN>
-    inline void kernel(int64_t ii, int64_t jj) {
-        if constexpr(RM == 4 && RN == 8) {
-            KERNEL_4x8(ii,jj);
-        } else if constexpr(RM == 8 && RN == 4) {
-            KERNEL_8x4(ii,jj);
-        } else if constexpr(RM == 8 && RN == 8) {
-            KERNEL_8x8(ii,jj);
-        } else {
-            assert(false && "RN/RM values not supported");
-        }
-    }
-    template <int size>
-    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
-    template <typename VA, typename VB>
-    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
-    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
-    void KERNEL_4x8(int64_t ii, int64_t jj);
-    void KERNEL_8x4(int64_t ii, int64_t jj);
-    void KERNEL_8x8(int64_t ii, int64_t jj);
-    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
-    template <int RM, int RN>
-    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
-
-    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
-        for (int I = 0; I<8; I++) {
-            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
-            for (int J = 0; J<4; J++) {
-                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
-                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
-            }
-        }
-    }
-
-    inline void process_q8_elements(const int8_t *qs, int *ca) {
-        vector signed char c1 = vec_xl(0, qs);
-        vector signed char c2 = vec_xl(16, qs);
-        vector signed int vsum1 = {0};
-        vector signed int vsum2 = {0};
-        vsum1 = vec_sum4s(c1, vsum1);
-        vsum2 = vec_sum4s(c2, vsum2);
-        vector signed int vsum = vec_add(vsum1, vsum2);
-        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-    }
-
-    template <typename VA, typename VB>
-    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
-        int64_t i, j;
-        block_q8_0 *aoffset = NULL;
-        VA *vecOffset = NULL;
-        block_q8_0* aoffsets[8];
-        __vector_pair arr[8];
-        VB c[8][2] = {0};
-        VB c1[8] = {0}; VB c2[8] = {0};
-        aoffset = const_cast<block_q8_0*>(a);
-        vecOffset = vec;
-        j = (rows >> 3);
-        int index = 0;
-        if (j > 0) {
-            do {
-                for (int it = 0; it < 8; it++)
-                    aoffsets[it] = aoffset + it*lda;
-                aoffset += 8 * lda;
-                for (int blk = 0; blk < kc; blk++) {
-                    for (int it = 0; it < 8; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
-                        c1[it] = c[it][0];
-                        c2[it] = c[it][1];
-                        if (comparray){
-                            process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
-                        }
-                    }
-                    vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
-                    vecOffset += 256;
-                }
-                j--;
-                index += 8*kc;
-            } while(j > 0);
-        }
-
-    }
-
-    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
-        int64_t i, j;
-        TA *aoffset = NULL;
-        int8_t *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
-        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        aoffset = const_cast<TA*>(a);
-        vecOffset = vec;
-        int index = 0;
-        j = (rows >> 3);
-        if (j > 0) {
-            do {
-                aoffset1 = aoffset;
-                aoffset2 = aoffset1 + lda;
-                aoffset3 = aoffset2 + lda;
-                aoffset4 = aoffset3 + lda;
-                aoffset5 = aoffset4 + lda;
-                aoffset6 = aoffset5 + lda;
-                aoffset7 = aoffset6 + lda;
-                aoffset8 = aoffset7 + lda;
-                aoffset += 8 * lda;
-                for (int blk = 0; blk < kc; blk++) {
-                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
-                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
-                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
-                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
-                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
-                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
-                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
-                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
-
-                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
-                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
-                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
-                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
-                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
-                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
-                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
-                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
-                    vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-                    vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-                    vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
-                    vecOffset += 256;
-                }
-                j--;
-                index += 8*kc;
-            } while (j > 0);
-        }
-    }
-
-    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
-        acc_t acc[8];
-        for (int i = 0; i < mc ; i += 8) {
-            for (int j = 0; j < nc; j += 8) {
-                vector float fin_res[16] = {0};
-                vector float vs[16] = {0};
-                for (int64_t kk = 0; kk < kc; kk+=2) {
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xxsetaccz(&acc[x]);
-                    }
-                    int A_block_idx = (i/8)*(16*kc) + kk*16;
-                    int B_block_idx = (j/8)*(16*kc)+ kk*16;
-                    vec_t *A_block = &vec_A[A_block_idx];
-                    vec_t *B_block = &vec_B[B_block_idx];
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
-                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
-                    }
-                    compute_scale(ii+i, jj+j, l+kk, vs);
-                    int c_index = (i/8)*(8*kc)+ kk*8;
-                    int* c_block = &comparray[c_index];
-                    compute(&acc[0], 0, 0, c_block, vs, fin_res);
-                    compute(&acc[1], 4, 4, c_block, vs, fin_res);
-                    compute(&acc[2], 0, 8, c_block, vs, fin_res);
-                    compute(&acc[3], 4, 12, c_block, vs, fin_res);
-
-                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
-                    B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
-                    A_block = &vec_A[A_block_idx];
-                    B_block = &vec_B[B_block_idx];
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
-                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
-                    }
-                    compute_scale(ii+i, jj+j, l+kk+1, vs);
-                    c_index = (i/8)*(8*kc)+ (kk+1)*8;
-                    c_block = &comparray[c_index];
-                    compute(&acc[4], 0, 0, c_block, vs, fin_res);
-                    compute(&acc[5], 4, 4, c_block, vs, fin_res);
-                    compute(&acc[6], 0, 8, c_block, vs, fin_res);
-                    compute(&acc[7], 4, 12, c_block, vs, fin_res);
-
-                }
-                if (l == 0) {
-                    save_res(ii+i, jj+j, 0, fin_res);
-                    save_res(ii+i+4, jj+j, 4, fin_res);
-                    save_res(ii+i, jj+j+4, 8, fin_res);
-                    save_res(ii+i+4, jj+j+4, 12, fin_res);
-                } else {
-                    add_save_res(ii+i, jj+j, 0, fin_res);
-                    add_save_res(ii+i+4, jj+j, 4, fin_res);
-                    add_save_res(ii+i, jj+j+4, 8, fin_res);
-                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
-                }
-            }
-        }
-    }
-
-    const TA *const A;
-    const block_q8_0 *const B;
-    float *C;
-    const int64_t k;
-    int64_t kc;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 8f980c16b9..da412fd009 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -121,7 +121,8 @@
 inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
 #endif
 #if defined(__MMA__)
-#include "sgemm-ppc.h"
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
             packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
                 mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]);
-                mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
+                mma_instr::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
             }
         }
         SAVE_ACC(&acc_0, ii, jj);
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
     const int nth;
 };
 
-template <typename TA>
-tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
-                                     const TA *A, int64_t lda,
-                                     const block_q8_0 *B, int64_t ldb,
-                                     float *C, int64_t ldc,
-                                     int ith, int nth)
+template <typename TA>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                    const TA * A, int64_t lda,
+                    const block_q8_0 * B, int64_t ldb,
+                    float * C, int64_t ldc,
+                    int ith, int nth)
     : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-    kc = 64;
 }
 
-template <typename TA>
-void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
-    int mc = 64; int nc = 64;
-    if (n % 8 == 0 && n < nc) {
-        nc = n;
-        mc = 32 ;
-        kc = 32;
+    void matmul(int64_t m, int64_t n) {
+        const int64_t mc = 64;
+        const int64_t kc = 64;
+        int64_t nc = 64;
+        int64_t n_aligned = 0;
+        if (n % 64 == 0) {
+            n_aligned = n;
+        } else if (n == 4) {
+            n_aligned = 4;
+        } else if (n < 64) {
+            n_aligned = (n / 8) * 8;
+        } else {
+            n_aligned = (n / 64) * 64;
         }
-    const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
-    if (is_aligned) {
-        this->matmul_tiled_q0(m, n, mc, nc, kc);
+
+        if (n_aligned > 0) {
+            if (n_aligned % 64 == 0) nc = 64;
+            else if (n_aligned == n) nc = n;
+            else if (n_aligned % 32 == 0) nc = 32;
+            else if (n_aligned % 24 == 0) nc = 24;
+            else if (n_aligned % 16 == 0) nc = 16;
+            else nc = 8;
+        }
+        bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
+        if (can_use_tiled) {
+            matmul_tiled(m, n_aligned, mc, nc, kc);
+            if (n > n_aligned) {
+                mnpack(0, m, n_aligned, n);
+            }
         } else {
             mnpack(0, m, 0, n);
         }
     }
 
-template <typename TA>
-template <int size>
-void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
+  private:
+    inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
+            }
+        }
+    }
+
+    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
+            }
+        }
+    }
+
+    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
+                *c_ptr += *((float *)&vec_C[I] + J);
+            }
+        }
+    }
+
+    template <typename ArrayType>
+    inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
+        }
+    }
+
+    inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+        c[0] = vec_and(c[1], lowMask);
+        c[1] = vec_sr(c[1], v4);
+        c[0] = vec_sub(c[0], v8);
+        c[1] = vec_sub(c[1], v8);
+        vsum = vec_sum4s(c[0], vsum);
+        vsum2 = vec_sum4s(c[1], vsum2);
+        vsum = vec_add(vsum, vsum2);
+        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+    template <typename V1, typename V2>
+    inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        V2 t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        t1 = vec_perm(s1, s2, swiz1);
+        t2 = vec_perm(s1, s2, swiz2);
+        t3 = vec_perm(s3, s4, swiz1);
+        t4 = vec_perm(s3, s4, swiz2);
+        t5 = vec_perm(t1, t3, swiz3);
+        t6 = vec_perm(t1, t3, swiz4);
+        t7 = vec_perm(t2, t4, swiz3);
+        t8 = vec_perm(t2, t4, swiz4);
+        if (flip == true) {
+            t5 = vec_xor(t5, xor_vector);
+            t6 = vec_xor(t6, xor_vector);
+            t7 = vec_xor(t7, xor_vector);
+            t8 = vec_xor(t8, xor_vector);
+        }
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset + 16);
+        vec_xst(t7, 0, vecOffset + 32);
+        vec_xst(t8, 0, vecOffset + 48);
+    }
+
+    inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
+        const vector signed char lowMask = vec_splats((signed char)0x0F);
+        const vector signed char v8 = vec_splats((signed char)0x08);
+        const vector unsigned char v4 = vec_splats((unsigned char)4);
+        lo = vec_and(packed, lowMask);
+        hi = vec_sr(packed, v4);
+        lo = vec_sub(lo, v8);
+        hi = vec_sub(hi, v8);
+    }
+
+    inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        for (int i = 0; i < 4; i += 2) {
+            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
+            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
+        }
+        for (int i = 4; i < 8; i += 2) {
+            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
+            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
+        }
+        s[0] = vec_perm(t[0], t[2], swiz3);
+        s[1] = vec_perm(t[0], t[2], swiz4);
+        s[2] = vec_perm(t[1], t[3], swiz3);
+        s[3] = vec_perm(t[1], t[3], swiz4);
+        s[4] = vec_perm(t[4], t[6], swiz3);
+        s[5] = vec_perm(t[4], t[6], swiz4);
+        s[6] = vec_perm(t[5], t[7], swiz3);
+        s[7] = vec_perm(t[5], t[7], swiz4);
+        for (int i = 0; i < 8; ++i) {
+            vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
+        }
+    }
+
+    static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
+        vector signed short i16_hi = vec_unpackh(raw);
+        vector signed short i16_lo = vec_unpackl(raw);
+
+        vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
+        vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
+        vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
+        vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
+        out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
+        out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
+    }
+
+    void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        unsigned char * vecOffset = vec;
+        for (int i = 0; i < rows; i += 8) {
+            const block_q4_0 * rows_base[8];
+            for (int r = 0; r < 8; r++) {
+                rows_base[r] = a + (i + r) * lda;
+            }
+            for (int blk = 0; blk < blocks; blk++) {
+                vector unsigned short hp_res[8][4];
+                for (int r = 0; r < 8; r++) {
+                    const block_q4_0 * current_blk = rows_base[r] + blk;
+                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
+                    vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
+                    vector signed char c1, c2;
+                    unpack_q4_to_q8(v_qs, c1, c2);
+                    convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
+                    convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
+                }
+                for (int c = 0; c < 4; c++) {
+                    vector unsigned char c_arr[8];
+                    for (int r = 0; r < 8; r++) {
+                        c_arr[r] = (vector unsigned char)hp_res[r][c];
+                    }
+                    vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
+                    vecOffset += 128;
+                }
+            }
+        }
+    }
+
+    template <int chunk_size>
+    static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        unsigned char * vecOffset = vec;
+        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        for (int i = 0; i < rows; i += chunk_size) {
+            const block_q8_0 * rows_base[chunk_size];
+            for (int r = 0; r < chunk_size; r++) {
+                rows_base[r] = a + (i + r) * lda;
+            }
+            for (int blk = 0; blk < blocks; blk++) {
+                vector unsigned short hp_res[chunk_size][4];
+                for (int r = 0; r < chunk_size; r++) {
+                    const block_q8_0 * b = rows_base[r] + blk;
+                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
+                    vector signed char c[2];
+                    __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
+                    __builtin_vsx_disassemble_pair(c, & pair);
+                    convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
+                    convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
+                }
+                for (int col = 0; col < 4; col++) {
+                    if constexpr (chunk_size == 8) {
+                        vec_t t[8];
+                        t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
+                        t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
+                        t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
+                        t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
+                        t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
+                        t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
+                        t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
+                        t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
+
+                        vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
+                        vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
+                        vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
+                        vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
+                        vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
+                        vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
+                        vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
+                        vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
+                        vecOffset += 128;
+                    } else {
+                        vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
+                        vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
+                        vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
+                        vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
+
+                        vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
+                        vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
+                        vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
+                        vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
+                        vecOffset += 64;
+                    }
+                }
+            }
+        }
+    }
+
+    void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        if (rows == 4) {
+            pack_q8_block<4>(a, lda, rows, blocks, vec);
+        } else {
+            pack_q8_block<8>(a, lda, rows, blocks, vec);
+        }
+    }
+
+    template <int size>
+    void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
         int64_t i, j;
-    TA *aoffset = NULL;
-    int8_t *vecOffset = NULL;
-    TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-    TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        TA * aoffset = NULL;
+        int8_t * vecOffset = NULL;
+        TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
+        TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
         vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
         vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-    aoffset = const_cast<TA*>(a);
+        aoffset = const_cast<TA *>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
             c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
             c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
 
-            process_q4_elements(c1, &comparray[0]);
-            process_q4_elements(c2, &comparray[1]);
-            process_q4_elements(c3, &comparray[2]);
-            process_q4_elements(c4, &comparray[3]);
-            process_q4_elements(c5, &comparray[4]);
-            process_q4_elements(c6, &comparray[5]);
-            process_q4_elements(c7, &comparray[6]);
-            process_q4_elements(c8, &comparray[7]);
+            process_q4_elements(c1, & comparray[0]);
+            process_q4_elements(c2, & comparray[1]);
+            process_q4_elements(c3, & comparray[2]);
+            process_q4_elements(c4, & comparray[3]);
+            process_q4_elements(c5, & comparray[4]);
+            process_q4_elements(c6, & comparray[5]);
+            process_q4_elements(c7, & comparray[6]);
+            process_q4_elements(c8, & comparray[7]);
             vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-            vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-            vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-            vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+            vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
+            vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
+            vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
             aoffset1 += lda;
             aoffset2 += lda;
             aoffset3 += lda;
@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
             c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
             c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
 
-            process_q4_elements(c1, &comparray[0]);
-            process_q4_elements(c2, &comparray[1]);
-            process_q4_elements(c3, &comparray[2]);
-            process_q4_elements(c4, &comparray[3]);
+            process_q4_elements(c1, & comparray[0]);
+            process_q4_elements(c2, & comparray[1]);
+            process_q4_elements(c3, & comparray[2]);
+            process_q4_elements(c4, & comparray[3]);
             vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-            vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+            vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
             aoffset1 += lda;
             aoffset2 += lda;
             aoffset3 += lda;
@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
                 case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                         break;
             }
-            process_q4_elements(c1, &comparray[0]);
-            process_q4_elements(c2, &comparray[1]);
-            process_q4_elements(c3, &comparray[2]);
-            process_q4_elements(c4, &comparray[3]);
+            process_q4_elements(c1, & comparray[0]);
+            process_q4_elements(c2, & comparray[1]);
+            process_q4_elements(c3, & comparray[2]);
+            process_q4_elements(c4, & comparray[3]);
             vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-            vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+            vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
             aoffset1 += lda;
             aoffset2 += lda;
             aoffset3 += lda;
@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
         }
     }
 
-    template <typename TA>
     template <typename VA, typename VB>
-    void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
         int64_t i, j;
-    block_q8_0 *aoffset = NULL;
-    VA *vecOffset = NULL;
-    block_q8_0* aoffsets[8];
+        block_q8_0 * aoffset = NULL;
+        VA * vecOffset = NULL;
+        block_q8_0 * aoffsets[8];
         __vector_pair arr[8];
         VB c[8][2] = {0};
         VB c1[8] = {0}; VB c2[8] = {0};
-    aoffset = const_cast<block_q8_0*>(a);
+        aoffset = const_cast<block_q8_0 *>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
                 aoffsets[0] = aoffset;
                 for (int it = 1; it < 8; it++)
-                aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffsets[it] = aoffsets[it - 1] + lda;
                 aoffset += 8 * lda;
                 i = (cols >> 3);
                 if (i > 0) {
                     do {
                         for (int it = 0; it < 8; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                            __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                             c1[it] = c[it][0];
                             c2[it] = c[it][1];
                         }
                         vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                        vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                        vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                        vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                        vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
+                        vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
+                        vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
                         for (int it = 0; it < 8; it++)
                             aoffsets[it] += lda;
                         vecOffset += 256;
@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
             if (i > 0) {
                 do {
                     for (int it = 0; it < 4; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                         c1[it] = c[it][0];
                         c2[it] = c[it][1];
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                     for (int it = 0; it < 4; it++) {
                         aoffsets[it] += lda;
                     }
@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
         if (rows & 3) {
             aoffsets[0] = aoffset;
             for (int it = 1; it < 3; it++ )
-                aoffsets[it] = aoffsets[it-1] + lda;
+                aoffsets[it] = aoffsets[it - 1] + lda;
             i = (cols >> 3);
             if (i > 0) {
                 do {
                     switch(rows) {
-                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
-                                __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
+                                __builtin_vsx_disassemble_pair(c[2], & arr[2]);
                                 c1[2] = c[2][0];
                                 c2[2] = c[2][1];
-                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
-                                __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
+                                __builtin_vsx_disassemble_pair(c[1], & arr[1]);
                                 c1[1] = c[1][0];
                                 c2[1] = c[1][1];
-                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
-                                __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
+                                __builtin_vsx_disassemble_pair(c[0], & arr[0]);
                                 c1[0] = c[0][0];
                                 c2[0] = c[0][1];
                                 break;
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                     for (int it = 0; it < 3; it++)
                         aoffsets[it] += lda;
                     vecOffset += 128;
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
         }
     }
 
-template <typename TA>
-void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int m_rem = MIN(m - m0, 16);
         int n_rem = MIN(n - n0, 16);
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
     }
 
-template <typename TA>
-void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
         std::array<int, 4> comparray {};
@@ -2594,26 +2848,26 @@
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
         bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
-        __builtin_mma_xxsetaccz(&acc_0);
-        __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
             if (std::is_same_v<TA, block_q4_0>) {
-            packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+                packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
             } else {
-            packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+                packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
             }
-        packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-            __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-            __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
             }
             for (int I = 0; I<4; I++) {
                 for (int J = 0; J<4; J++) {
-                *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
-                *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
+                    *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-            auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 4; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
                     aoffset += lda;
                 }
             }
-        compute(&acc_0, 0, 0, comparray, vs, fin_res);
-        compute(&acc_1, 0, 4, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 0, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-    save_res(ii, jj+4, 4, fin_res);
+        save_res(ii, jj + 4, 4, fin_res);
     }
 
-template <typename TA>
-void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
         std::array<int, 8> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
         bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
-        __builtin_mma_xxsetaccz(&acc_0);
-        __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
             if (std::is_same_v<TA, block_q4_0>) {
-            packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+                packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
             } else {
-            packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+                packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
             }
-        packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-            __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-            __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
             }
-        for (int I = 0; I<8; I++) {
-            for (int J = 0; J<4; J++) {
-                *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+            for (int I = 0; I < 8; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-            auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 8; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
                     aoffset += lda;
                 }
             }
-        compute(&acc_0, 0, 0, comparray, vs, fin_res);
-        compute(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 4, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-    save_res(ii+4, jj, 4, fin_res);
+        save_res(ii + 4, jj, 4, fin_res);
     }
 
-template <typename TA>
-void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
         acc_t acc_4, acc_5, acc_6, acc_7;
         std::array<int, 8> comparray {};
@@ -2686,30 +2938,30 @@
         vector float fin_res[16] = {0};
         vector float vs[16] = {0};
         bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
-        __builtin_mma_xxsetaccz(&acc_0);
-        __builtin_mma_xxsetaccz(&acc_1);
-        __builtin_mma_xxsetaccz(&acc_2);
-        __builtin_mma_xxsetaccz(&acc_3);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
+            __builtin_mma_xxsetaccz(& acc_2);
+            __builtin_mma_xxsetaccz(& acc_3);
             if (std::is_same_v<TA, block_q4_0>) {
-            packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+                packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
             } else {
-            packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+                packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
             }
-        packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-            __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-            __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
-            __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
-            __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
+                __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
             }
-        for (int I = 0; I<8; I++) {
-            for (int J = 0; J<4; J++) {
-                *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
-                *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+            for (int I = 0; I < 8 ; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
+                    *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-            auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 8; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
                    aoffset += lda;
                 }
             }
-        compute(&acc_0, 0, 0, comparray, vs, fin_res);
-        compute(&acc_1, 4, 4, comparray, vs, fin_res);
-        compute(&acc_2, 0, 8, comparray, vs, fin_res);
-        compute(&acc_3, 4, 12, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 4, 4, comparray, vs, fin_res);
+            compute(& acc_2, 0, 8, comparray, vs, fin_res);
+            compute(& acc_3, 4, 12, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-    save_res(ii+4, jj, 4, fin_res);
-    save_res(ii, jj+4, 8, fin_res);
-    save_res(ii+4, jj+4, 12, fin_res);
+        save_res(ii + 4, jj, 4, fin_res);
+        save_res(ii, jj + 4, 8, fin_res);
+        save_res(ii + 4, jj + 4, 12, fin_res);
     }
 
-template <typename TA>
-void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
+        acc_t acc[8];
+        for (int i = 0; i < mc ; i += 16) {
+            for (int j = 0; j < nc; j += 8) {
+                int A0_base = (i / 16) * (2 * 32 * kc);
+                int B0_base = (j / 8) * (32 * kc);
+                for (int x = 0; x < 8; x++) {
+                    __builtin_mma_xxsetaccz(&acc[x]);
+                }
+                for (int64_t kk = 0; kk < kc; kk++) {
+                    int A0_block_idx = A0_base + kk * 32;
+                    int B0_block_idx = B0_base + kk * 32;
+                    int A1_block_idx = A0_block_idx + 32 * kc;
+                    int B1_block_idx = B0_block_idx + 32 * kc;
+                    vec_t * A0_block = & vec_A[A0_block_idx];
+                    vec_t * B0_block = & vec_B[B0_block_idx];
+                    vec_t * A1_block = & vec_A[A1_block_idx];
+                    for (int it = 0; it < 4; it++) {
+                        for (int x = 0; x < 4; x++) {
+                            __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
+                        }
+                    }
+                }
+                if (l == 0) {
+                    save_acc(& acc[0], ii + i, jj + j);
+                    save_acc(& acc[1], ii + i, jj + j + 4);
+                    save_acc(& acc[2], ii + i + 4, jj + j);
+                    save_acc(& acc[3], ii + i + 4, jj + j + 4);
+                    save_acc(& acc[4], ii + i + 8, jj + j);
+                    save_acc(& acc[5], ii + i + 8, jj + j + 4);
+                    save_acc(& acc[6], ii + i + 12, jj + j);
+                    save_acc(& acc[7], ii + i + 12, jj + j + 4);
+                } else {
+                    add_save_acc(& acc[0], ii + i, jj + j);
+                    add_save_acc(& acc[1], ii + i, jj + j + 4);
+                    add_save_acc(& acc[2], ii + i + 4, jj + j);
+                    add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
+                    add_save_acc(& acc[4], ii + i + 8, jj + j);
+                    add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
+                    add_save_acc(& acc[6], ii + i + 12, jj + j);
+                    add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
+                }
+            }
+        }
+    }
+
+    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        vec_t A_pack[mc * kc * 4];
+        vec_t B_pack[nc * kc * 4];
+        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                if constexpr(is_Ablock_q4) {
+                    packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
+                } else {
+                    packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
+                }
+                packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
+                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
+            }
+        }
+    }
+
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2754,32 +3086,32 @@
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
             vector float CA[4] = {0};
-        __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
-        __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
             for (int l = 0; l < k; l++) {
-            __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
-            __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
-            __builtin_mma_xxsetaccz(&acc_0);
+                __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(& acc_0);
                 if (isAblock_q4) {
-                packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                    packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
                 } else {
-                packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                    packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
                 }
-            packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
-            for(int x = 0; x < 8; x+=4) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
+                for (int x = 0; x < 8; x += 4) {
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
                 }
-            for (int I = 0; I<RM; I++) {
-                for (int J = 0; J<RN; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                for (int I = 0; I < RM; I++) {
+                    for (int J = 0; J < RN; J++) {
+                        *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
                     }
                 }
-            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                __builtin_mma_disassemble_acc(vec_C, & acc_0);
                 if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                    auto aoffset = A + (ii * lda) + l;
                     for (int i = 0; i < RM; i++) {
                         comparray[i] = 0;
                         int ca = 0;
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
         }
     }
 
-template <typename TA>
+    template <int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else {
+            assert(false && "RN/RM values not supported");
+        }
+    }
+
     template <int RM, int RN>
-NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-        this->kernel<RM, RN>(ii, jj);
+            kernel<RM, RN>(ii, jj);
         }
     }
-
-template class tinyBLAS_Q0_PPC<block_q4_0>;
-template class tinyBLAS_Q0_PPC<block_q8_0>;
+    const TA * const A;
+    const block_q8_0 * const B;
+    float * C;
+    const int64_t k;
+    int64_t kc;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
 
 class tinyBLAS_PPC {
   public: