diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 7dc36d4f8a..8f980c16b9 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +struct mma_instr; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvbf16ger2pp(acc, a, b); + } +}; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvf16ger2pp(acc, a, b); + } +}; + template -class tinyBLAS_BF16_PPC { +class tinyBLAS_HP16_PPC { public: - tinyBLAS_BF16_PPC(int64_t k, + tinyBLAS_HP16_PPC(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, @@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); - __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_2, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]); } } @@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B); for (int x = 0; x<2; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B); for (int x = 0; x<4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return tb.matmul(m, n); } #elif defined(__MMA__) - if ((k % 8)) - return false; - if(Btype == GGML_TYPE_BF16) { - tinyBLAS_BF16_PPC tb{ k, - (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc, - params->ith, params->nth}; - tb.matmul(m, n); - return true; + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_BF16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; } #elif defined(__riscv_zvfbfwma) #if LMUL == 1 @@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #endif return tb.matmul(m, n); } +#elif defined(__MMA__) + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_F16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } #endif return false; }