[MUSA] enable fp16/fast_fp16/bf16_mma on PH1 (#17551)

* [MUSA] enable fp16/fast_fp16/bf16_mma on PH1

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-tile.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Address review comments

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
commit c6f7a423c8 (parent 2e7ef98f18)
Author: R0CKSTAR
Date:   2025-11-28 21:08:29 +08:00

5 changed files with 21 additions and 14 deletions

ggml/src/ggml-cuda/common.cuh

@@ -84,12 +84,12 @@
 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000

 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
-#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
-#define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
+#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
+#define GGML_CUDA_CC_IS_PH1(cc)      (cc >= GGML_CUDA_CC_PH1)

 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 #    define GGML_CUDA_USE_CUB
@@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) {
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))

-#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
-#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
@@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) {
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)

 static bool fp16_available(const int cc) {
-    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
 }

 static bool fast_fp16_available(const int cc) {
     return GGML_CUDA_CC_IS_AMD(cc) ||
-           (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
+           (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) {
 }

 static bool bf16_mma_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
+           GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
 }

 static bool fp32_mma_hardware_available(const int cc) {
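
For context on how these helpers are consumed: the rest of the backend queries them when selecting kernel variants. Below is a minimal host-side sketch of the same range-based gating; the two offset constants are assumptions mirroring ggml's headers, and only the +0x310 PH1 threshold comes from this diff.

// Standalone sketch of the arch-range gating added above (illustrative,
// not ggml's actual code). Offsets are assumed values.
#include <cassert>

constexpr int OFFSET_MTHREADS = 0x0100000; // assumed, mirrors GGML_CUDA_CC_OFFSET_MTHREADS
constexpr int OFFSET_AMD      = 0x1000000; // assumed, mirrors GGML_CUDA_CC_OFFSET_AMD
constexpr int CC_PH1          = OFFSET_MTHREADS + 0x310; // MTT S5000 (from this diff)

constexpr bool is_mthreads(int cc) { return cc >= OFFSET_MTHREADS && cc < OFFSET_AMD; }
constexpr bool is_ph1(int cc)      { return cc >= CC_PH1; }

// PH1 now passes the bf16 MMA capability check that previously required
// an NVIDIA Ampere+ or AMD CDNA/RDNA3 arch.
constexpr bool bf16_mma_on_mthreads(int cc) { return is_mthreads(cc) && is_ph1(cc); }

int main() {
    assert(bf16_mma_on_mthreads(CC_PH1));                   // MTT S5000: enabled
    assert(!bf16_mma_on_mthreads(OFFSET_MTHREADS + 0x220)); // QY2 (MTT S4000): not enabled
}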

ggml/src/ggml-cuda/cpy.cu

@@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
             }
         }
     }
+
+    GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
+                     nb12, nb13);
 }

 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
@@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
         ne00n = ne00;
         ne01n = ne01;
         ne02n = ne02;
-    } else if (nb00 > nb02) {
+    } else {
         ne00n = ne00;
         ne01n = ne01*ne02;
         ne02n = 1;
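
Two independent fixes in this file: GGML_UNUSED_VARS marks parameters the transpose kernel never reads, silencing unused-parameter warnings, and the trailing `else if (nb00 > nb02)` becomes a plain `else`, presumably so the compiler can prove `ne00n`/`ne01n`/`ne02n` are initialized on every path. Roughly what such an unused-vars macro accomplishes (a stand-in, not ggml's actual definition):

// Minimal stand-in: "use" each argument once so -Wunused-parameter stays
// quiet on code paths that never touch them. Not ggml's definition.
template <typename... Args>
__host__ __device__ constexpr void unused_vars(const Args & ...) {}

// e.g. at the end of a kernel whose arguments only matter on other paths:
//   unused_vars(ne02, nb00, nb01, nb02, nb03);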

ggml/src/ggml-cuda/fattn-tile.cuh

@@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         float KQ_sum_add = 0.0f;
 #pragma unroll
         for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
-            const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
+            const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
                 expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
             KQ_sum_add += val;
             tmp[i0/(np*warp_size)][jc1] = val;
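
The only change here is the `static_cast`: `threadIdx.x` has type `unsigned int`, so the whole index expression is unsigned, and comparing it against the signed `k_VKQ_sup` draws a signed/unsigned-comparison warning on stricter toolchains. A reduced host-side analogue:

#include <cstdint>

// tid stands in for threadIdx.x (unsigned int on the device). The cast is
// safe as long as bound is non-negative, which holds for k_VKQ_sup.
bool in_bounds(int i0, unsigned int tid, int bound) {
    // return i0 + tid < bound;                       // -Wsign-compare
    return i0 + tid < static_cast<uint32_t>(bound);   // explicitly unsigned
}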

ggml/src/ggml-cuda/fattn-vec.cuh

@@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
         for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;

-            if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
+            if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
                 tmp_q_i32[i] = 0;
             }
         }
@@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
                 KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);

-                if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
+                if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
                     KQ_reg[j] = sum;
                 }
             }

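
Both hunks are the same class of fix as in fattn-tile.cuh, with a second flavor in the first one: `D/sizeof(int)` has type `size_t` because `sizeof` yields `size_t`, so the signed `i0 + WARP_SIZE` needs the `int(...)` casts; in the second hunk the signed `i_KQ_0` is instead cast up to match the unsigned `threadIdx.x` expression. A reduced analogue of the first:

#include <cstddef>

constexpr int WARP_SIZE = 32; // stand-in for the CUDA warp width

// Without the int(...) cast, i0 + WARP_SIZE (int) would be compared
// against D / sizeof(int) (size_t), drawing -Wsign-compare.
bool tile_covered(int i0, std::size_t D) {
    return i0 + WARP_SIZE <= int(D / sizeof(int));
}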

ggml/src/ggml-cuda/mma.cuh

@@ -889,8 +889,8 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
 #else
-        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
-        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
+        const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
         mma(D16[0], A16[0], B);
         mma(D16[1], A16[1], B);
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
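
Pure style change on the fallback path (which decomposes the wider MMA into two 16-row `mma` calls): C-style pointer casts become `reinterpret_cast`, and the alias of the read-only `A` operand gains `const`. A reduced illustration of the same aliasing pattern with a stand-in `tile` type (ggml's real tile layout differs):

// Stand-in tile; one fragment per lane of a 32-thread warp. Illustrative only.
template <int I, int J, typename T>
struct tile {
    T x[I * J / 32];
};

void view_as_halves(tile<32, 8, float> & D, const tile<32, 8, float> & A) {
    // Reinterpret one 32-row tile as two stacked 16-row tiles; const on the
    // A alias keeps the read-only intent visible to the compiler.
    tile<16, 8, float>       * D16 = reinterpret_cast<tile<16, 8, float> *>(&D);
    const tile<16, 8, float> * A16 = reinterpret_cast<const tile<16, 8, float> *>(&A);
    D16[0].x[0] += A16[0].x[0]; // first half
    D16[1].x[0] += A16[1].x[0]; // second half
}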