Aman Gupta 2025-12-17 12:19:26 +08:00 committed by GitHub
commit 8d3a9670e6
8 changed files with 401 additions and 17 deletions

View File

@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 100 == Blackwell, needs CUDA v12.8, native FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
# XX-real == compile CUDA code as device code for this specific architecture

View File

@ -50,6 +50,7 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_BLACKWELL 1000
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@ -246,6 +247,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#define BLACKWELL_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@ -316,6 +321,10 @@ static bool cp_async_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool blackwell_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
return 64;
@ -701,6 +710,33 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
// Handle exact zero early
if (x == 0.0f) {
return 0;
}
const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
float ax = fabsf(x) * e;
// Positive LUT
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
int best_i = 0;
float best_err = fabsf(ax - pos_lut[0]);
#pragma unroll
for (int i = 1; i < 8; ++i) {
const float err = fabsf(ax - pos_lut[i]);
if (err < best_err) {
best_err = err;
best_i = i;
}
}
return static_cast<uint8_t>(best_i | sign_bit);
}
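
For reference, the positive E2M1 code points are exactly the eight LUT entries above; the helper scales |x| by e and picks the nearest one, keeping the sign in bit 3. Below is a minimal host-side sketch of the same rounding with a couple of illustrative inputs; the name fp4_e2m1_round_ref and the sample values are not part of the commit.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Host reference of the device helper above: nearest-value search over the positive
// E2M1 table, with the sign kept in bit 3. e is the reciprocal of the block scale.
static uint8_t fp4_e2m1_round_ref(float x, float e) {
    static const float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
    if (x == 0.0f) {
        return 0;
    }
    const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
    const float   ax       = fabsf(x) * e;
    int   best_i   = 0;
    float best_err = fabsf(ax - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = fabsf(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i   = i;
        }
    }
    return (uint8_t) (best_i | sign_bit);
}

int main() {
    // With e = 1.0f: 2.4 is nearest to 2.0 (code 4); -5.1 is nearest to 6.0 in magnitude (code 7 | 0x8).
    printf("%d %d\n", (int) fp4_e2m1_round_ref(2.4f, 1.0f), (int) fp4_e2m1_round_ref(-5.1f, 1.0f));
    return 0;
}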
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)

View File

@ -847,6 +847,25 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
float * Dxi = (float *) D.x;
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
#endif // BLACKWELL_MMA_AVAILABLE
}
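
In plain arithmetic, the instruction above accumulates D += (scaled A) * (scaled B)^T for one 16x8x64 tile, where A and B hold E2M1 (FP4) values and each 32-element group along K carries one unsigned E8M0 scale (that is what scale_vec::2X refers to). The scalar sketch below is only a model of that arithmetic under the stated assumptions; it ignores the per-thread fragment layout the real instruction operates on, and the names are illustrative.

#include <cmath>

// Scalar model of one m16n8k64 block-scaled MMA:
//   D[i][j] += sum_k (A[i][k] * 2^(ea[i][k/32] - 127)) * (B[j][k] * 2^(eb[j][k/32] - 127))
// A and B hold the decoded E2M1 values, ea and eb the biased E8M0 exponents (2 per row/column).
static void mma_block_scaled_ref(float D[16][8],
                                 const float A[16][64], const unsigned char ea[16][2],
                                 const float B[8][64],  const unsigned char eb[8][2]) {
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 8; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < 64; ++k) {
                const float sa = ldexpf(1.0f, (int) ea[i][k / 32] - 127);
                const float sb = ldexpf(1.0f, (int) eb[j][k / 32] - 127);
                acc += (A[i][k] * sa) * (B[j][k] * sb);
            }
            D[i][j] += acc;
        }
    }
}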
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE

View File

@ -1,3 +1,4 @@
#include "common.cuh"
#include "mmq.cuh"
#include "quantize.cuh"
#include "mmid.cuh"
@ -114,6 +115,8 @@ void ggml_cuda_mul_mat_q(
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| GGML_CUDA_CC_IS_CDNA(cc);
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@ -123,12 +126,23 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
}
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
// Stride depends on quantization format
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
:
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
const mmq_args args = {
@ -175,12 +189,19 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
}
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.

View File

@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@ -44,8 +45,15 @@ struct block_q8_1_mmq {
};
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
struct block_fp4_mmq {
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
#if defined(BLACKWELL_MMA_AVAILABLE)
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
#else
return MMQ_ITER_K;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@ -201,6 +218,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@ -209,6 +227,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 on Blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@ -229,6 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
//#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K / 2
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@ -761,6 +782,48 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
}
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kbx0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
int * x_qs = (int *) x_tile;
uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int txi = threadIdx.x;
constexpr int threads_per_row = 16; // 16 blocks per row for 512 values (ITER_K=512)
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = txi % threads_per_row;
const int row_in_warp = txi / threads_per_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if (need_check) {
i = min(i, i_max);
}
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
// quantize_mmq_mxfp4 permutes nibbles to match the quantized format
const int k0 = kbx * 4;
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
// Load E8M0 scales: pack 2 consecutive scales into one uint32
if (kbx % 2 == 0) {
uint32_t e = bxi->e;
e |= ((bxi + 1)->e << 8);
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
}
}
}
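
For readers unfamiliar with the lane layout: the warp is split into two rows of 16 lanes, each lane copies one 16-byte block_mxfp4 of nibbles, and the even lanes additionally pack their block's scale with the next block's scale. The host-side sketch below only reproduces that index arithmetic for warp_size == 32; it is illustrative and not part of the commit.

#include <cstdio>

// Reproduces the lane -> (row, block) mapping of load_tiles_mxfp4_fp4 for warp_size == 32:
// 16 lanes per row (threads_per_row), hence 2 rows per warp, one 16-byte block_mxfp4 per lane.
int main() {
    const int warp_size       = 32;
    const int threads_per_row = 16;   // 16 blocks of 32 values == an ITER_K of 512
    for (int txi = 0; txi < warp_size; ++txi) {
        const int kbx         = txi % threads_per_row;  // which MXFP4 block within the row
        const int row_in_warp = txi / threads_per_row;  // 0 or 1
        const int k0          = kbx * 4;                // offset in ints (4 ints == 16 bytes of nibbles)
        printf("lane %2d -> row %d, block %2d, qs offset %2d ints%s\n",
               txi, row_in_warp, kbx, k0, kbx % 2 == 0 ? ", also packs two scales" : "");
    }
    return 0;
}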
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -930,6 +993,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
// Match layout from load_tiles_mxfp4_fp4
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
tile_A A[ntx][MMQ_TILE_NE_K / QI8_0];
uint32_t scaleA[ntx][MMQ_TILE_NE_K / QI8_0];
// Block scales:
// each thread has to supply a 4-byte value holding the packed scales
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01 / QI8_0], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
MMQ_MMA_TILE_X_K_FP4);
// based on the block-scaling document, 2 threads in each quad need to supply the scale value
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
scaleA[n][k01 / QI8_0] = *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / QI8_0);
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
tile_B B;
uint32_t scaleB; // 2xN scales
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / QI8_0];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
tile_C C;
mma_block_scaled(C, A[n][k01 / QI8_0], B, scaleA[n][k01 / QI8_0], scaleB);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
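
The tidx expression above selects which tile row's packed scale pair each lane supplies: for quad q (lanes 4q..4q+3), even lanes map to row q and odd lanes to row q + 8, matching the rows that quad accumulates in the 16x8 output tile (assuming the standard m16n8 fragment row ownership; see the linked block-scaling documentation). A trivial host-side print of that mapping, for illustration only:

#include <cstdio>

// Prints which tile row each lane's scale load corresponds to in vec_dot_mxfp4_mxfp4_mma.
int main() {
    for (int lane = 0; lane < 32; ++lane) {
        const int quad = lane / 4;
        const int tidx = lane / 4 + (lane % 2) * 8;  // row whose packed E8M0 pair this lane supplies
        printf("lane %2d (quad %d) -> scale row %2d\n", lane, quad, tidx);
    }
    return 0;
}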
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -3102,8 +3235,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
@ -3236,15 +3374,22 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
constexpr int ITER_K = get_iter_k(type);
constexpr int blocks_per_iter = ITER_K / qk;
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
// blocks_per_mmq: number of qk-blocks per Y-block structure
// MXFP4: block_fp4_mmq holds 8 qk-blocks (256 values)
// Others: block_q8_1_mmq holds 4 qk-blocks (128 values)
constexpr int blocks_per_mmq = (type == GGML_TYPE_MXFP4) ? 8 : 4;
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
const int * by0 = y + ncols_y * (kb0 / blocks_per_mmq) * sz;
#pragma unroll
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
@ -3260,7 +3405,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
__syncthreads();
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
const int * by0 = y + ncols_y * ((kb0 / blocks_per_mmq) * sz + sz);
#pragma unroll
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
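
For the MXFP4 path the offsets above work out as follows: one block_fp4_mmq spans 36 ints of the Y tile, each outer iteration consumes 16 qk-blocks, i.e. two such structures, and those are exactly the two loads at +0 and +sz. A compile-time sketch of that arithmetic, under the assumption qk == QK_MXFP4 == 32 (not part of the commit):

// Worked arithmetic for the MXFP4 case of the by0 offsets above
// (assumes qk == QK_MXFP4 == 32 and sizeof(block_fp4_mmq) == 144 bytes, as asserted in mmq.cuh).
constexpr int qk              = 32;             // QK_MXFP4
constexpr int iter_k          = 512;            // MMQ_ITER_K_MXFP4_FP4
constexpr int blocks_per_iter = iter_k / qk;    // 16 qk-blocks consumed per outer iteration
constexpr int blocks_per_mmq  = 8;              // one block_fp4_mmq holds 8 qk-blocks (256 values)
constexpr int sz              = 144 / sizeof(int);
static_assert(sz == 36, "one block_fp4_mmq spans 36 ints of the Y tile");
static_assert(blocks_per_iter / blocks_per_mmq == 2,
              "each iteration consumes exactly the two Y structures loaded at +0 and +sz");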
@ -3394,8 +3539,10 @@ static __global__ void mul_mat_q(
}
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
constexpr int ITER_K = get_iter_k(type);
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
constexpr int blocks_per_iter = ITER_K / qk;
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@ -3546,7 +3693,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
constexpr int ITER_K = get_iter_k(type);
constexpr int blocks_per_iter = ITER_K / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int nwarps = mmq_get_nwarps_device();
@ -3704,7 +3853,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}

View File

@ -47,6 +47,120 @@ static __global__ void quantize_q8_1(
y[ib].ds = make_half2(d, sum);
}
// Helper to compute E8M0 scale from amax using fast math
__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
if (amax == 0.0f) {
return 127;
}
// log2(amax / 6.0) = log2(amax) - log2(6) ≈ log2(amax) - 2.585
// Use __log2f for fast approximate log2
const float log2_amax = __log2f(amax) - 2.5849625007211563f; // log2(6)
const int e_int = __float2int_rd(log2_amax) + 127; // floor + bias
return static_cast<uint8_t>(max(1, min(254, e_int)));
}
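
A worked example of the helper above with amax = 5.0f: log2(5) - log2(6) ≈ -0.263, the floor is -1, so e_int = -1 + 127 = 126 and the decoded scale is 2^(126 - 127) = 0.5. The host-side check below mirrors the same arithmetic with the standard log2f; the device version uses the faster, less accurate __log2f, and the function name here is illustrative.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Host mirror of compute_e8m0_scale: floor(log2(amax) - log2(6)) + 127, clamped to [1, 254].
static uint8_t compute_e8m0_scale_ref(float amax) {
    if (amax == 0.0f) {
        return 127;  // biased exponent 127 decodes to a scale of 1.0
    }
    const int e_int = (int) floorf(log2f(amax) - 2.5849625007211563f) + 127;
    return (uint8_t) (e_int < 1 ? 1 : (e_int > 254 ? 254 : e_int));
}

int main() {
    const uint8_t e     = compute_e8m0_scale_ref(5.0f);  // log2(5/6) ~= -0.263 -> floor -1 -> 126
    const float   scale = ldexpf(1.0f, (int) e - 127);   // 2^(126 - 127) == 0.5
    printf("e = %d, scale = %g\n", (int) e, scale);
    return 0;
}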
// quantize values into the interleaved-nibble layout in which mxfp4 is stored,
// i.e. a block a0-a31 is represented as a0a16, a1a17, ..., a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const int32_t * __restrict__ ids,
void * __restrict__ vy,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int ne1,
const int ne2) {
constexpr int vals_per_scale = 32;
constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
const int warp_id = threadIdx.y;
const int lane_id_32 = threadIdx.x;
const int nwarps = blockDim.y;
const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
if (warp_start_offset >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;
block_fp4_mmq * y = (block_fp4_mmq *) vy;
const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
const int group_id = lane_id_32 / 4;
const int lane_in_group = lane_id_32 % 4;
const int base = group_id * 2;
char2 * yqs2 = (char2 *) y[ib].qs;
int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
uint8_t scales[2];
#pragma unroll
for (int b = 0; b < 2; ++b) {
const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
float amax = fabsf(xi);
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
}
const uint8_t e = compute_e8m0_scale(amax);
scales[b] = e;
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
#if CUDART_VERSION >= 12040
const float scaled_val = xi * inv_s;
const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
}
#else
// Fallback: manual FP4 conversion using LUT
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
if (lane_in_group == 0) {
char2 q;
q.x = (q_hi_0 << 4) | q_lo_0;
q.y = (q_hi_1 << 4) | q_lo_1;
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
}
#endif // CUDART_VERSION >= 12040
}
if (lane_id_32 == 0) {
// Store 2 scales packed into 1 uint32
y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
}
}
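
The nibble order produced by the stores above matches the comment before the kernel: within a 32-value block, output byte k holds value k in its low nibble and value k + 16 in its high nibble. A minimal host-side sketch of just that packing step, assuming the 4-bit codes are already computed (the real kernel does this with warp shuffles and, on CUDA 12.4+, __nv_fp4x4_e2m1); the helper name is illustrative.

#include <cstdint>

// Packs 32 already-quantized 4-bit codes into 16 bytes using the interleaved layout
// a0a16, a1a17, ..., a15a31 that load_tiles_mxfp4_fp4 expects (low nibble: value k, high: value k + 16).
static void pack_mxfp4_nibbles(const uint8_t q[32], uint8_t out[16]) {
    for (int k = 0; k < 16; ++k) {
        out[k] = (uint8_t) (((q[k + 16] & 0x0F) << 4) | (q[k] & 0x0F));
    }
}

int main() {
    uint8_t q[32];
    for (int k = 0; k < 32; ++k) {
        q[k] = (uint8_t) (k % 8);  // arbitrary 4-bit codes just to exercise the packing
    }
    uint8_t packed[16];
    pack_mxfp4_nibbles(q, packed);  // packed[0] == (q[16] << 4) | q[0], etc.
    return 0;
}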
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
@ -190,3 +304,29 @@ void quantize_mmq_q8_1_cuda(
break;
}
}
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}
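
Spelling out the launch geometry above: each warp quantizes 2 * QK_MXFP4 = 64 values, each block of 8 warps covers 512 values of one row, so the grid is ne1 wide in x, ceil(ne0 / 512) in y, and ne2 * ne3 in z. A compile-time sketch of that arithmetic for an illustrative ne0 of 4096 (not part of the commit):

// Worked launch geometry for quantize_mmq_mxfp4_cuda, assuming QK_MXFP4 == 32 and WARP_SIZE == 32.
constexpr int nwarps         = 8;
constexpr int vals_per_warp  = 2 * 32;                   // 2 * QK_MXFP4 == 64 values per warp
constexpr int vals_per_block = nwarps * vals_per_warp;   // 512 values per thread block
static_assert((4096 + vals_per_block - 1) / vals_per_block == 8, "ne0 = 4096 -> block_num_y = 8");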

View File

@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
const float * x, const int32_t * ids, void * vy,
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
ggml_type type_src0,
int64_t ne00,
int64_t s01,
int64_t s02,
int64_t s03,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
cudaStream_t stream);

View File

@ -10,6 +10,10 @@
#include <cuda_fp8.h>
#endif // CUDART_VERSION >= 12050
#if CUDART_VERSION >= 12040
#include <cuda_fp4.h>
#endif // CUDART_VERSION >= 12040
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH