first pass review: formatting

Aman Gupta 2025-12-11 03:44:14 +01:00
parent 65f944bf18
commit a6dcaa5742
3 changed files with 27 additions and 26 deletions


@@ -246,7 +246,7 @@ static const char * cu_get_error_str(CUresult err) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
# define BLACKWELL_MMA_AVAILABLE
#endif
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
@@ -713,8 +713,8 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
return 0;
}
const float sign = x < 0.0f ? -1.0f : 1.0f;
float ax = fabsf(x) * e;
const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
float ax = fabsf(x) * e;
// Positive LUT
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
@@ -729,19 +729,14 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
#pragma unroll
for (int i = 1; i < 8; ++i) {
float err = fabsf(ax - pos_lut[i]);
const float err = fabsf(ax - pos_lut[i]);
if (err < best_err) {
best_err = err;
best_i = i;
}
}
// Positive codes: 0..7, negative: 8..15 (sign bit = MSB)
if (sign > 0.0f) {
return static_cast<uint8_t>(best_i); // 0..7
} else {
return static_cast<uint8_t>(best_i | 0x8); // 8..15
}
return static_cast<uint8_t>(best_i | sign_bit);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
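The two hunks above simplify the E2M1 converter: instead of carrying a float sign and branching on it after the LUT search, the new code precomputes the 4-bit sign bit up front and ORs it into the selected code. Below is a small host-side C++ sketch of the resulting logic; the LUT values are taken from the diff, while the helper name and the plain zero check standing in for the original early return are assumptions.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Reference version of the nearest-code search, mirroring the edited
// ggml_cuda_float_to_fp4_e2m1: the sign is handled as a precomputed bit, not a branch.
static uint8_t float_to_fp4_e2m1_ref(float x, float e) {
    if (x == 0.0f) {
        return 0; // stands in for the early return shown as context in the diff
    }
    const uint8_t sign_bit = x < 0.0f ? 0x8 : 0; // MSB of the 4-bit code
    const float   ax       = fabsf(x) * e;

    // Positive E2M1 magnitudes, indexed by code 0..7 (values as in the diff).
    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

    int   best_i   = 0;
    float best_err = fabsf(ax - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = fabsf(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i   = i;
        }
    }
    return static_cast<uint8_t>(best_i | sign_bit); // 0..7 positive, 8..15 negative
}

int main() {
    // -2.7 is closest to magnitude 3.0 (code 5), so the result is 5 | 0x8 = 13.
    printf("%d\n", (int) float_to_fp4_e2m1_ref(-2.7f, 1.0f));
    return 0;
}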


@@ -1,3 +1,4 @@
#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@@ -114,6 +115,8 @@ void ggml_cuda_mul_mat_q(
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -123,7 +126,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
@@ -135,7 +138,7 @@ void ggml_cuda_mul_mat_q(
}
// Stride depends on quantization format
const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
:
@@ -187,7 +190,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
@@ -197,7 +200,7 @@ void ggml_cuda_mul_mat_q(
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
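Hoisting the Blackwell-plus-MXFP4 check into use_native_mxfp4 also makes the two src1 stride formulas easier to compare. The sketch below is only a sanity check of that arithmetic, using sizes implied by the diff's own comments (18 ints per block_fp4_mmq covering 128 values) and an assumed 36-byte block_q8_1; the constants are not read from the real headers.

#include <assert.h>
#include <stdio.h>

int main() {
    // Assumed constants: QK_MXFP4 and QK8_1 taken as 32 values per block,
    // block_fp4_mmq as 18 ints (from the block_fp4_mmq comments in the diff),
    // block_q8_1 as 32 int8 quants plus a half2 scale.
    const long QK_MXFP4          = 32;
    const long QK8_1             = 32;
    const long sizeof_block_fp4  = 18 * 4;
    const long sizeof_block_q8_1 = 36;
    const long sizeof_int        = 4;

    const long ne11 = 64, ne10_padded = 1024; // arbitrary example shape

    // MXFP4 path: one block_fp4_mmq per 4*QK_MXFP4 = 128 values, stride in ints.
    const long s12_fp4 = ne11 * ne10_padded * sizeof_block_fp4 / (4 * QK_MXFP4 * sizeof_int);
    // Same count written as rows * blocks-per-row * ints-per-block.
    assert(s12_fp4 == ne11 * (ne10_padded / 128) * 18);

    // Q8_1 path: the original formula kept by the diff.
    const long s12_q8 = ne11 * ne10_padded * sizeof_block_q8_1 / (QK8_1 * sizeof_int);

    printf("s12 (MXFP4) = %ld ints, s12 (Q8_1) = %ld ints\n", s12_fp4, s12_q8);
    return 0;
}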


@@ -218,8 +218,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
#ifdef BLACKWELL_MMA_AVAILABLE
case GGML_TYPE_MXFP4:
return MMQ_MMA_TILE_X_K_FP4;
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
#else
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
#endif
@@ -784,7 +783,6 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(BLACKWELL_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
@@ -833,7 +831,6 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
}
}
#endif
}
template <int mmq_x, int mmq_y>
@@ -1026,7 +1023,7 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
const int * y_qs = (const int *) y + 2;
const uint32_t * y_sc = (const uint32_t *) y; // E8M0 scales for Y
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale pe rtile
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale per tile
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // per tile you would only have 1 scale per thread
// Block scale
@@ -3253,7 +3250,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
#endif
#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
@@ -3386,15 +3383,21 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#if defined(BLACKWELL_MMA_AVAILABLE)
constexpr bool use_native_mxfp4 = (type == GGML_TYPE_MXFP4);
#else
constexpr bool use_native_mxfp4 = false;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
constexpr size_t y_stride = type == GGML_TYPE_MXFP4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
constexpr size_t sz = use_native_mxfp4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
constexpr size_t y_stride = use_native_mxfp4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
constexpr int y_block_stride =
type == GGML_TYPE_MXFP4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
constexpr int y_block_stride = use_native_mxfp4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
:
(qk * sz / (4 * QK8_1 * sizeof(int))); // original formula for Q8_1
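Inside the tile kernel the same check becomes a compile-time constant: use_native_mxfp4 is defined under BLACKWELL_MMA_AVAILABLE and drives every later size and stride selection, so non-Blackwell builds fall back to the Q8_1 layout with no runtime branching. A minimal sketch of that pattern, with illustrative names and sizes rather than the real MMQ constants:

#include <stdio.h>

// #define BLACKWELL_MMA_AVAILABLE  // normally set by the arch checks in common.cuh

template <bool type_is_mxfp4>
static void process_tile_sketch() {
#if defined(BLACKWELL_MMA_AVAILABLE)
    constexpr bool use_native_mxfp4 = type_is_mxfp4;
#else
    constexpr bool use_native_mxfp4 = false; // non-Blackwell builds take the Q8_1 path
#endif // defined(BLACKWELL_MMA_AVAILABLE)

    // Every downstream choice is a compile-time constant, so the compiler drops
    // the unused branch entirely (sizes here are placeholders, not MMQ_TILE_Y_*).
    constexpr int y_stride = use_native_mxfp4 ? 18 : 36;
    printf("use_native_mxfp4=%d  y_stride=%d\n", (int) use_native_mxfp4, y_stride);
}

int main() {
    process_tile_sketch<true>();  // MXFP4-like instantiation
    process_tile_sketch<false>(); // Q8_1-like instantiation
    // With BLACKWELL_MMA_AVAILABLE left undefined, both instantiations report
    // the fallback stride, matching the #else branch above.
    return 0;
}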
@@ -3402,7 +3405,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
{
const int * by0 =
type == GGML_TYPE_MXFP4 ?
use_native_mxfp4 ?
y + ncols_y * ((kb0 / 4) * y_block_stride) // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
:
y + ncols_y * (kb0 * y_block_stride); // original for Q8_1
@@ -3422,7 +3425,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
{
const int * by0 =
type == GGML_TYPE_MXFP4 ?
use_native_mxfp4 ?
y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride) // advance by one block_fp4_mmq
:
y + ncols_y * (kb0 * y_block_stride +