diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 3cfd5a4860..cdad39fc45 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -246,7 +246,7 @@ static const char * cu_get_error_str(CUresult err) {
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
 # define BLACKWELL_MMA_AVAILABLE
-#endif
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
@@ -713,8 +713,8 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
         return 0;
     }
 
-    const float sign = x < 0.0f ? -1.0f : 1.0f;
-    float ax = fabsf(x) * e;
+    const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
+    float ax = fabsf(x) * e;
 
     // Positive LUT
     static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
@@ -729,19 +729,14 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
 
 #pragma unroll
     for (int i = 1; i < 8; ++i) {
-        float err = fabsf(ax - pos_lut[i]);
+        const float err = fabsf(ax - pos_lut[i]);
         if (err < best_err) {
             best_err = err;
             best_i = i;
         }
     }
 
-    // Positive codes: 0..7, negative: 8..15 (sign bit = MSB)
-    if (sign > 0.0f) {
-        return static_cast<uint8_t>(best_i); // 0..7
-    } else {
-        return static_cast<uint8_t>(best_i | 0x8); // 8..15
-    }
+    return static_cast<uint8_t>(best_i | sign_bit);
 }
 
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 0844e475c2..a666858c8d 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "mmq.cuh"
 #include "quantize.cuh"
 #include "mmid.cuh"
@@ -114,6 +115,8 @@ void ggml_cuda_mul_mat_q(
     const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
         || GGML_CUDA_CC_IS_CDNA(cc);
 
+    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
+
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
             get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -123,7 +126,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[3] / ts_src1;
 
-        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
+        if (use_native_mxfp4) {
             quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                 ne11, ne12, ne13, stream);
@@ -135,7 +138,7 @@ void ggml_cuda_mul_mat_q(
         }
 
         // Stride depends on quantization format
-        const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
+        const int64_t s12 = use_native_mxfp4 ?
             ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
             :
@@ -187,7 +190,7 @@ void ggml_cuda_mul_mat_q(
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[2] / ts_src1;
 
-        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
+        if (use_native_mxfp4) {
            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
        } else {
@@ -197,7 +200,7 @@ void ggml_cuda_mul_mat_q(
            CUDA_CHECK(cudaGetLastError());
        }
 
-        const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
+        const int64_t s12 = use_native_mxfp4 ?
            ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
        const int64_t s13 = ne12*s12;
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 9b82247e07..319c7d08ee 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -218,8 +218,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
 #ifdef BLACKWELL_MMA_AVAILABLE
-        case GGML_TYPE_MXFP4:
-            return MMQ_MMA_TILE_X_K_FP4;
+        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
 #else
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
 #endif
@@ -784,7 +783,6 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(BLACKWELL_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
     uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
 
@@ -833,7 +831,6 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
             x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
         }
     }
-#endif
 }
 
 template
@@ -1026,7 +1023,7 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
     const int * y_qs = (const int *) y + 2;
     const uint32_t * y_sc = (const uint32_t *) y; // E8M0 scales for Y
 
-    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale pe rtile
+    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale per tile
     uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // per tile you would only have 1 scale per thread
 
     // Block scale
@@ -3253,7 +3250,7 @@ struct mmq_type_traits {
 #else
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4;
     static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma;
-#endif
+#endif // BLACKWELL_MMA_AVAILABLE
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
@@ -3386,15 +3383,21 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    constexpr bool use_native_mxfp4 = (type == GGML_TYPE_MXFP4);
+#else
+    constexpr bool use_native_mxfp4 = false;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    constexpr size_t y_stride = type == GGML_TYPE_MXFP4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
+    constexpr size_t sz = use_native_mxfp4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
+    constexpr size_t y_stride = use_native_mxfp4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
 
-    constexpr int y_block_stride =
-        type == GGML_TYPE_MXFP4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
+    constexpr int y_block_stride = use_native_mxfp4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
         : (qk * sz / (4 * QK8_1 * sizeof(int))); // original formula for Q8_1
 
@@ -3402,7 +3405,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
 
         {
            const int * by0 =
-                type == GGML_TYPE_MXFP4 ?
+                use_native_mxfp4 ?
                y + ncols_y * ((kb0 / 4) * y_block_stride) // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
                : y + ncols_y * (kb0 * y_block_stride); // original for Q8_1
@@ -3422,7 +3425,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 
        {
            const int * by0 =
-                type == GGML_TYPE_MXFP4 ?
+                use_native_mxfp4 ?
                y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride) // advance by one block_fp4_mmq
                : y + ncols_y * (kb0 * y_block_stride +
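
Reviewer note: the change to `ggml_cuda_float_to_fp4_e2m1` above replaces the sign branch with a branch-free OR of the sign bit (`0x8`) into the 3-bit positive code found by the LUT search. Below is a minimal host-side C++ sketch of that logic for illustration only; the name `float_to_fp4_e2m1_ref` and the zero/non-finite early-out guard are assumptions (the hunk context only shows `return 0;`), not part of the patch.

```cpp
// Host-side reference sketch of the branch-free E2M1 encoding in the patch.
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t float_to_fp4_e2m1_ref(float x, float e) {
    if (x == 0.0f || !std::isfinite(x)) {
        return 0; // assumed guard; the patch only shows "return 0;" as context
    }

    const uint8_t sign_bit = x < 0.0f ? 0x8 : 0; // MSB of the 4-bit code
    const float   ax       = std::fabs(x) * e;   // scaled magnitude

    // Positive E2M1 representable values (codes 0..7).
    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

    int   best_i   = 0;
    float best_err = std::fabs(ax - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = std::fabs(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i   = i;
        }
    }

    // Codes 0..7 are positive, 8..15 negative: OR in the sign bit instead of branching.
    return static_cast<uint8_t>(best_i | sign_bit);
}

int main() {
    // -2.9 is closest to 3.0 (code 5); with the sign bit set this yields 13.
    std::printf("%u\n", (unsigned) float_to_fp4_e2m1_ref(-2.9f, 1.0f));
    return 0;
}
```

Compiled standalone, this prints 13 for the example input, i.e. code 5 for the nearest positive value 3.0 plus the sign bit, which matches the intent of `best_i | sign_bit` in the device function.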