use iter_k as 512, cleanup

Aman Gupta 2025-12-12 09:43:41 +01:00
parent a1672f620b
commit 61c41a0d18
4 changed files with 74 additions and 88 deletions

View File

@@ -719,11 +719,6 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
     // Positive LUT
     static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
 
-    // Saturate to max representable magnitude
-    if (ax > pos_lut[7]) {
-        ax = pos_lut[7];
-    }
-
     int best_i = 0;
     float best_err = fabsf(ax - pos_lut[0]);
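For context on why the clamp can go (an inference, not stated in the commit): the nearest-entry search that follows already maps any magnitude above 6.0 to pos_lut[7], so the explicit saturation was redundant. A minimal host-side sketch of that search, with the loop body assumed to mirror the device code:

#include <cassert>
#include <cmath>

static const float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

static int nearest_e2m1_index(float ax) {    // ax = |x| / scale, assumed >= 0
    int   best_i   = 0;
    float best_err = std::fabs(ax - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = std::fabs(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i   = i;
        }
    }
    return best_i;
}

int main() {
    assert(nearest_e2m1_index(100.0f) == 7); // saturates to 6.0 without an explicit clamp
    return 0;
}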

View File

@@ -140,7 +140,7 @@ void ggml_cuda_mul_mat_q(
         // Stride depends on quantization format
         const int64_t s12 = use_native_mxfp4 ?
             ne11 * ne10_padded * sizeof(block_fp4_mmq) /
-                (4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
+                (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
             :
             ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
@@ -200,9 +200,8 @@ void ggml_cuda_mul_mat_q(
             CUDA_CHECK(cudaGetLastError());
         }
 
-        const int64_t s12 = use_native_mxfp4 ?
-            ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
-            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
+        const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
+            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
 
         // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
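A quick size check of the new divisor, as a standalone sketch (the struct mirrors the block_fp4_mmq layout introduced below; the *_sketch names are illustrative, not part of the commit):

#include <cstdint>
#include <cstddef>

// Mirrors the layout from this commit: 8 E8M0 scales + 256 packed FP4 values.
struct block_fp4_mmq_sketch {
    uint32_t d4[4];
    int8_t   qs[4 * 32];
};

constexpr int QK_MXFP4_sketch = 32;

// One block now covers 8*QK_MXFP4 = 256 values in 144 bytes, which is why the
// stride divisor changed from 4*QK_MXFP4*sizeof(int) to 8*QK_MXFP4*sizeof(int).
static_assert(sizeof(block_fp4_mmq_sketch) == 144, "144 bytes per 256 values");
static_assert(8 * QK_MXFP4_sketch * sizeof(int) == 1024, "new divisor: values per block times sizeof(int)");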

View File

@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -46,13 +47,13 @@ struct block_q8_1_mmq {
 };
 
 struct block_fp4_mmq {
-    uint32_t d4[2];      // 1 8 bit (e8m0) scale per 32 values, packed LSB as d0-d1 in d4[0] and d4[1]
-    int8_t   qs[2 * 32]; // 128 values to 4 bit each (4 blocks)
+    uint32_t d4[4];      // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t   qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
 };
 
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
-static_assert(sizeof(block_fp4_mmq) == 72, "Unexpected block_fp4_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
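How the packed scales are read back, as a small sketch (fp4_mmq_scale is a hypothetical helper; the two-scales-per-uint32 layout follows the store in quantize_mmq_mxfp4 later in this commit):

#include <cstdint>

// Hypothetical helper, not part of the commit: E8M0 scale of 32-value block k (0..7).
static inline uint8_t fp4_mmq_scale(const uint32_t d4[4], int k) {
    return (uint8_t) ((d4[k / 2] >> (8 * (k % 2))) & 0xFF); // two scales live in the low 16 bits of each word
}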
@@ -136,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
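What the larger iteration size means for the K loop, assuming qk == 32 for MXFP4 (a worked-numbers sketch, not code from the commit):

constexpr int ITER_K_mxfp4_sketch   = 512; // MMQ_ITER_K_MXFP4_FP4, Blackwell MXFP4 x MXFP4 path
constexpr int ITER_K_default_sketch = 256; // MMQ_ITER_K, everywhere else
static_assert(ITER_K_mxfp4_sketch   / 32 == 16, "16 qk-blocks of 32 values per K iteration");
static_assert(ITER_K_default_sketch / 32 ==  8, "8 qk-blocks per K iteration on the generic path");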
@@ -198,7 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_FP4 (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_0)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -209,7 +218,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -218,11 +227,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
-#ifdef BLACKWELL_MMA_AVAILABLE
-        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
-#else
-        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
-#endif
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
+        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -242,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
 #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
-#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
+//#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
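Why the alias works, in numbers (assumes MMQ_TILE_NE_K == 32, QI8_1 == 8 and a 4-byte int, as in ggml; a sanity sketch rather than code from the commit):

static_assert(32 + 32 / 8 == 36, "MMQ_TILE_Y_K: ints per y tile row");
static_assert(36 * 4 == 144, "144 bytes per row, i.e. one block_q8_1_mmq and, after this commit, one block_fp4_mmq");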
@@ -785,15 +792,14 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     int * x_qs = (int *) x_tile;
-    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
 
     const int txi = threadIdx.x;
 
-    // Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
-    constexpr int threads_per_row = 8; // 8 blocks per row
-    constexpr int rows_per_warp = warp_size / threads_per_row; // 4 rows per warp
-    const int kbx = txi % threads_per_row; // block id 0-7
-    const int row_in_warp = txi / threads_per_row; // which of the 4 rows this thread handles
+    constexpr int threads_per_row = 16; // 16 blocks per row for 512 values (ITER_K=512)
+    constexpr int rows_per_warp = warp_size / threads_per_row;
+    const int kbx = txi % threads_per_row;
+    const int row_in_warp = txi / threads_per_row;
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
@@ -805,6 +811,7 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
 
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
         const int k0 = kbx * 4;
 
         memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
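The new thread mapping in numbers (assumes a 32-thread warp and QK_MXFP4 == 32; the *_sketch names are illustrative):

constexpr int warp_size_sketch = 32;
constexpr int threads_per_row_sketch = 16; // one thread per 32-value MXFP4 block, 16 packed bytes copied as 4 ints (k0 = kbx * 4)
static_assert(threads_per_row_sketch * 32 == 512, "16 blocks of 32 values cover ITER_K = 512 per row");
static_assert(warp_size_sketch / threads_per_row_sketch == 2, "each warp now loads 2 tile rows per pass instead of 4");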
@@ -1003,12 +1010,12 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
     // Match layout from load_tiles_mxfp4_fp4
     const int * x_qs = (const int *) x;
-    const uint32_t * x_sc = (const uint32_t *) (x_qs + MMQ_TILE_NE_K); // E8M0 scales at same offset as load
-    const int * y_qs = (const int *) y + 2;
-    const uint32_t * y_sc = (const uint32_t *) y; // E8M0 scales for Y
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+    const int * y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
 
-    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale per tile
-    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // per tile you would only have 1 scale per thread
+    tile_A A[ntx][MMQ_TILE_NE_K / QI8_0];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / QI8_0];
 
     // Block scale
     // Each thread has to point to a 4 byte scale value
@@ -1019,8 +1026,8 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
-            const int k0 = k00 / 2 + k01;
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+            const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01 / QI8_0], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
                           MMQ_MMA_TILE_X_K_FP4);
@@ -1034,7 +1041,7 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             tile_B B;
             uint32_t scaleB; // 2xN scales
@@ -3367,34 +3374,24 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-#if defined(BLACKWELL_MMA_AVAILABLE)
-    constexpr bool use_native_mxfp4 = (type == GGML_TYPE_MXFP4);
-#else
-    constexpr bool use_native_mxfp4 = false;
-#endif // defined(BLACKWELL_MMA_AVAILBLE)
-
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    constexpr size_t sz = use_native_mxfp4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    constexpr size_t y_stride = use_native_mxfp4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
-    constexpr int y_block_stride = use_native_mxfp4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
-                                                    :
-                                                      (qk * sz / (4 * QK8_1 * sizeof(int))); // original formula for Q8_1
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+
+    // blocks_per_mmq: number of qk-blocks per Y-block structure
+    // MXFP4: block_fp4_mmq holds 8 qk-blocks (256 values)
+    // Others: block_q8_1_mmq holds 4 qk-blocks (128 values)
+    constexpr int blocks_per_mmq = (type == GGML_TYPE_MXFP4) ? 8 : 4;
 
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
 
         {
-            const int * by0 =
-                use_native_mxfp4 ?
-                    y + ncols_y * ((kb0 / 4) * y_block_stride) // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
-                    :
-                    y + ncols_y * (kb0 * y_block_stride); // original for Q8_1
+            const int * by0 = y + ncols_y * (kb0 / blocks_per_mmq) * sz;
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3408,14 +3405,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 =
-                use_native_mxfp4 ?
-                    y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride) // advance by one block_fp4_mmq
-                    :
-                    y + ncols_y * (kb0 * y_block_stride +
-                                   (int) (sz / sizeof(int))); // original for Q8_1 (advance by one block)
+            const int * by0 = y + ncols_y * ((kb0 / blocks_per_mmq) * sz + sz);
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
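How the two copies line up on the MXFP4 path, as a worked-numbers sketch (assumes sizeof(block_q8_1_mmq) == 144 bytes, qk == 32 and ITER_K == 512; illustrative only, the *_sketch names are not from the commit):

constexpr int sz_sketch = 144 / 4;               // 36 ints per y block (block_fp4_mmq has the same size)
constexpr int blocks_per_mmq_sketch = 8;         // qk-blocks covered by one block_fp4_mmq
constexpr int blocks_per_iter_sketch = 512 / 32; // 16 qk-blocks consumed per K iteration
// One iteration therefore spans two y blocks: the first copy reads block kb0/8,
// the second reads block kb0/8 + 1 via the "+ sz" offset above.
static_assert(blocks_per_iter_sketch / blocks_per_mmq_sketch == 2, "two y blocks per K iteration");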
@@ -3547,8 +3539,10 @@ static __global__ void mul_mat_q(
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
+    constexpr int ITER_K = get_iter_k(type);
+
     const int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3609,8 +3603,7 @@ static __global__ void mul_mat_q(
             __syncthreads();
         }
 
-        constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-        offset_y += (col_low + jt * mmq_x) * (sz / sizeof(int));
+        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
         offset_dst += it*mmq_y;
 
         const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3677,8 +3670,7 @@ static __global__ void mul_mat_q(
            __syncthreads();
        }
 
-        constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-        offset_y += (col_low + jt * mmq_x) * (sz / sizeof(int));
+        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
         offset_dst += it*mmq_y;
 
         const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3701,7 +3693,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int ncols_max) {
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
     const int64_t blocks_per_ne00 = ncols_x / qk;
 
     constexpr int nwarps = mmq_get_nwarps_device();

View File

@@ -50,7 +50,7 @@ static __global__ void quantize_q8_1(
 // Helper to compute E8M0 scale from amax using fast math
 __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
     if (amax == 0.0f) {
-        return 127; // Special case: use scale of 1.0 for zero input
+        return 127;
     }
     // log2(amax / 6.0) = log2(amax) - log2(6) ≈ log2(amax) - 2.585
     // Use __log2f for fast approximate log2
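A host-side reference of the same computation, for orientation (the exponent rounding is an assumption since only the clamp is visible in this hunk; the clamp to [1, 254] matches the device code shown below):

#include <algorithm>
#include <cmath>
#include <cstdint>

static uint8_t e8m0_scale_ref(float amax) {
    if (amax == 0.0f) {
        return 127;                                        // biased exponent 127 -> scale 1.0
    }
    const int e = (int) std::ceil(std::log2(amax / 6.0f)); // smallest scale with amax/scale <= 6.0 (rounding assumed)
    return (uint8_t) std::clamp(e + 127, 1, 254);          // E8M0 bias 127, clamped like the kernel
}
// e8m0_scale_ref(6.0f) == 127 (scale 1.0), e8m0_scale_ref(12.0f) == 128 (scale 2.0).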
@@ -59,6 +59,8 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
     return static_cast<uint8_t>(max(1, min(254, e_int)));
 }
 
+// quantize values in the format mxfp4 is stored which is interleaved nibbles
+// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                           const int32_t * __restrict__ ids,
                                           void * __restrict__ vy,
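The interleaving spelled out as a tiny sketch (pack_mxfp4_block is hypothetical; fp4_codes are assumed to be already-encoded E2M1 nibbles):

#include <cstdint>

// Byte i of a 32-value block holds a_i in the low nibble and a_(i+16) in the high nibble.
static void pack_mxfp4_block(const uint8_t fp4_codes[32], uint8_t out[16]) {
    for (int i = 0; i < 16; ++i) {
        out[i] = (uint8_t) ((fp4_codes[i + 16] << 4) | (fp4_codes[i] & 0x0F));
    }
}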
@@ -70,9 +72,8 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                           const int ne1,
                                           const int ne2) {
     constexpr int vals_per_scale = 32;
-    constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32
-    // Multiple warps per block - each warp handles different data
+    constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
 
     const int warp_id = threadIdx.y;
     const int lane_id_32 = threadIdx.x;
@@ -94,17 +95,17 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     block_fp4_mmq * y = (block_fp4_mmq *) vy;
 
-    const int64_t block_fp4_mmq_size = 4 * QK_MXFP4; // 128 values
-    const int64_t ib0 = blockIdx.z * ((int64_t) gridDim.x * gridDim.y * nwarps * vals_per_warp / block_fp4_mmq_size);
+    const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
+    const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
     const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
-    const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+    const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
 
     const int group_id = lane_id_32 / 4;
     const int lane_in_group = lane_id_32 % 4;
     const int base = group_id * 2;
 
     char2 * yqs2 = (char2 *) y[ib].qs;
-    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+    int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
 
     uint8_t scales[2];
@@ -124,21 +125,17 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
         const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
 
 #if CUDART_VERSION >= 12040
-        // Use hardware FP4 conversion: pre-scale and gather 4 floats, then convert+pack
         const float scaled_val = xi * inv_s;
-        // Gather 4 scaled floats in the order matching __nv_fp4x4_e2m1 packing:
-        // float4(x,y,z,w) -> 16-bit with bits [3:0]=x, [7:4]=y, [11:8]=z, [15:12]=w
-        // This produces byte0 = (y<<4)|x, byte1 = (w<<4)|z
-        const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);      // -> low nibble byte 0
-        const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); // -> high nibble byte 0
-        const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);  // -> low nibble byte 1
-        const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); // -> high nibble byte 1
+        const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
+        const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
+        const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
+        const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
 
         if (lane_in_group == 0) {
-            // Convert 4 floats -> packed 16-bit FP4 in one step
             __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
-            yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = reinterpret_cast<const char2&>(fp4_packed);
+
+            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
        }
 #else
         // Fallback: manual FP4 conversion using LUT
@@ -153,13 +150,14 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
             char2 q;
             q.x = (q_hi_0 << 4) | q_lo_0;
             q.y = (q_hi_1 << 4) | q_lo_1;
-            yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
+            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
         }
 #endif // CUDART_VERSION >= 12040
     }
 
     if (lane_id_32 == 0) {
-        y[ib].d4[pair_idx_in_block] = (scales[1] << 8) | scales[0];
+        // Store 2 scales packed into 1 uint32
+        y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
     }
 }
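Where each group's bytes land, as a small sketch (fp4_dst_byte is a hypothetical helper mirroring the char2 indexing above; quad is the 64-value chunk inside the 256-value block_fp4_mmq, b the inner iteration, group = lane/4):

// Byte offset into y[ib].qs for the char2 written by lane 0 of a group.
static inline int fp4_dst_byte(int quad, int b, int group) {
    return 2 * (quad * 16 + b * 8 + group);
}
// e.g. quad 0, b 0, group 3 -> bytes 6..7, holding a6,a7 in the low nibbles and a22,a23 in the high nibbles.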
@@ -320,15 +318,15 @@ void quantize_mmq_mxfp4_cuda(const float * x,
                              const int64_t ne2,
                              const int64_t ne3,
                              cudaStream_t stream) {
-    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); // Each warp processes 64 values
+    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
 
     constexpr int nwarps = 8;
-    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64 values per warp
-    constexpr int vals_per_block = nwarps * vals_per_warp; // 512 values per block
+    constexpr int vals_per_warp = 2 * QK_MXFP4;
+    constexpr int vals_per_block = nwarps * vals_per_warp;
 
     const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
     const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
-    const dim3 block_size(WARP_SIZE, nwarps, 1); // 32 threads x 8 warps = 256 threads per block
+    const dim3 block_size(WARP_SIZE, nwarps, 1);
 
     quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
 }
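The resulting launch shape, recomputed as a quick sketch (WARP_SIZE == 32 and QK_MXFP4 == 32 assumed; the ne0 value is only an example):

static_assert(8 * (2 * 32) == 512, "nwarps * vals_per_warp: values covered per thread block");
static_assert(32 * 8 == 256, "threads per thread block");
static_assert((4096 + 512 - 1) / 512 == 8, "block_num_y for an example row length ne0 = 4096");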