CUDA: experimental native MXFP4 support for Blackwell
This commit is contained in:
parent c6f6e4f96a
commit e214110ef7

@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
    # 80  == Ampere, asynchronous data loading, faster tensor core instructions
    # 86  == RTX 3000, needs CUDA v11.1
    # 89  == RTX 4000, needs CUDA v11.8
    # 100 == Blackwell, needs CUDA v12.8, native FP4 tensor cores
    #
    # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
    # XX-real    == compile CUDA code as device code for this specific architecture

@@ -34,6 +35,10 @@ if (CUDAToolkit_FOUND)
            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
                list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
            endif()

            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
                list(APPEND CMAKE_CUDA_ARCHITECTURES 100-real)
            endif()
        endif()
    endif()
    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

@@ -50,6 +50,7 @@
#define GGML_CUDA_CC_TURING       750
#define GGML_CUDA_CC_AMPERE       800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_BLACKWELL    1000
#define GGML_CUDA_CC_OFFSET_AMD      0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)

@@ -243,6 +244,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
#define BLACKWELL_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

@@ -313,6 +318,10 @@ static bool cp_async_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static bool blackwell_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
    return 64;

@@ -698,6 +707,41 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}

__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
    // Handle exact zero early
    if (x == 0.0f) {
        return 0;
    }

    const float sign = x < 0.0f ? -1.0f : 1.0f;
    float ax = fabsf(x) * e;

    // Positive LUT
    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

    // Saturate to max representable magnitude
    if (ax > pos_lut[7]) {
        ax = pos_lut[7];
    }

    int best_i = 0;
    float best_err = fabsf(ax - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        float err = fabsf(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i = i;
        }
    }

    // Positive codes: 0..7, negative: 8..15 (sign bit = MSB)
    if (sign > 0.0f) {
        return static_cast<uint8_t>(best_i);       // 0..7
    } else {
        return static_cast<uint8_t>(best_i | 0x8); // 8..15
    }
}
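
For reference (not part of this commit), decoding a 4-bit E2M1 code is just a lookup into the same table with the sign reapplied; a minimal host-side sketch, assuming the code layout produced above:

// Hypothetical reference decoder, for illustration only: maps the 4-bit E2M1
// code produced by ggml_cuda_float_to_fp4_e2m1 back to its unscaled value.
static float fp4_e2m1_to_float_ref(uint8_t code) {
    static const float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
    const float mag = pos_lut[code & 0x7];
    return (code & 0x8) ? -mag : mag; // MSB of the nibble is the sign bit
}
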
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)

@@ -812,6 +812,25 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                            const tile<16, 8, int> & A,
                                                            const tile<8, 8, int> & B,
                                                            uint32_t a_scale,
                                                            uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        float     * Dxi = (float *) D.x;

        asm volatile(
            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
            "%10, {0, 0}, %11, {0, 0};"
            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
#endif // BLACKWELL_MMA_AVAILABLE
    }
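
The operand shapes implied by the instruction, summarized as an annotation rather than commit code; the meaning of the two packed scales is an assumption based on the scale_vec::2X qualifier and on how the loaders below pack two e8m0 bytes per uint32_t:

// D/C: 16x8 f32 accumulator fragment (4 floats per lane).
// A:   16x64 tile of packed FP4 (E2M1) codes, eight codes per 32-bit register.
// B:   64x8 tile of packed FP4 codes.
// a_scale / b_scale: two UE8M0 exponents each, one per 32-value block along K,
// so per element D[i][j] += sum_k fp4(A[i][k]) * 2^(sa[i][k/32] - 127)
//                                * fp4(B[k][j]) * 2^(sb[j][k/32] - 127).
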
    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE

@@ -123,12 +123,23 @@ void ggml_cuda_mul_mat_q(
        const int64_t s11 = src1->nb[1] / ts_src1;
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[3] / ts_src1;
        quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
            ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
            quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                                    ne11, ne12, ne13, stream);
        } else {
            quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                                   ne11, ne12, ne13, stream);
        }
        CUDA_CHECK(cudaGetLastError());
    }

    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
    // Stride depends on quantization format
    const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
                            ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
                            :
                            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
    const int64_t s13 = ne12*s12;

    const mmq_args args = {
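
For the MXFP4 branch the stride arithmetic works out as follows (using sizeof(block_fp4_mmq) = 72 bytes, i.e. 2 uint32_t of scales plus 64 bytes of packed FP4, per the struct added in mmq.cuh):

// bytes per 128 padded values: sizeof(block_fp4_mmq) = 8 + 64 = 72 bytes = 18 ints
// s12 = ne11 * ne10_padded * 72 / (4 * 32 * 4) = ne11 * (ne10_padded / 128) * 18 ints,
// versus 36 ints per 128 values (128 int8 + 4 half2 scales) for block_q8_1_mmq.
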
@@ -175,12 +186,20 @@ void ggml_cuda_mul_mat_q(
        const int64_t s11 = src1->nb[1] / ts_src1;
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[2] / ts_src1;
        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);

        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
        } else {
            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
        }
        CUDA_CHECK(cudaGetLastError());
    }

    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
    const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
                            ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
                            ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
    const int64_t s13 = ne12*s12;

    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.

@@ -44,6 +44,12 @@ struct block_q8_1_mmq {
    };
    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};

struct block_fp4_mmq {
    uint32_t d4[2];    // one 8-bit (e8m0) scale per 32 values; two scales packed LSB-first into each of d4[0] and d4[1]
    int8_t   qs[2*32]; // 128 values quantized to 4 bit each (4 blocks)
};
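
The new block is not covered by the size checks below; an equivalent assertion would be the following sketch (assuming QK_MXFP4 == 32, as elsewhere in ggml):

static_assert(sizeof(block_fp4_mmq) == 2*sizeof(uint32_t) + 4*QK_MXFP4/2, "Unexpected block_fp4_mmq size"); // 8 + 64 = 72 bytes = 18 ints
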
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");

@@ -191,6 +197,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}

#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4  (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_0)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)

@@ -201,6 +208,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");

static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
    switch (type) {

@@ -209,7 +217,12 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
        case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
        case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
        case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
#ifdef BLACKWELL_MMA_AVAILABLE
        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
#else
        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
#endif
        case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
        case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
        case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;

@@ -229,6 +242,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {

// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
#define MMQ_TILE_Y_K     (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
#define MMQ_TILE_Y_FP4_K (MMQ_TILE_Y_K / 2)

static int mmq_get_granularity_host(const int mmq_x, const int cc) {
    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {

@@ -761,6 +775,68 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
    }
}

template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
                                                            int * __restrict__ x_tile,
                                                            const int kbx0,
                                                            const int i_max,
                                                            const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();

#if defined(BLACKWELL_MMA_AVAILABLE)
    int      * x_qs = (int *) x_tile;
    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K); // Same offset as original: 2*MMQ_TILE_NE_K

    constexpr int nrows = 1;
    const int txi = threadIdx.x; // thread index within the warp
    const int kbx = txi;

    // TODO: only 8 threads of a warp at the moment for simplicity, use more threads
    if (txi >= 8) {
        return;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows * nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;

        // Load packed FP4 data directly (no LUT dequantization)
        const int aux_q4_0 = get_int_b1(bxi->qs, 0);
        const int aux_q4_1 = get_int_b1(bxi->qs, 1);
        const int aux_q4_2 = get_int_b1(bxi->qs, 2);
        const int aux_q4_3 = get_int_b1(bxi->qs, 3);

        const auto compress = [](const int x) -> int {
            uint16_t a = (x >> 24) & 0xF;
            uint16_t b = (x >> 16) & 0xF;
            uint16_t c = (x >> 8) & 0xF;
            uint16_t d = x & 0xF;

            return (a << 12) | (b << 8) | (c << 4) | d;
        };

        const int k0 = kbx * 4; // each block occupies 4 ints of x_qs

        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4_1) << 16 | compress(aux_q4_0);
        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4_3) << 16 | compress(aux_q4_2);
        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4_1 >> 4) << 16 | compress(aux_q4_0 >> 4);
        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4_3 >> 4) << 16 | compress(aux_q4_2 >> 4);

        if (txi % 2 == 0) {
            uint32_t e = bxi->e;
            bxi++;
            e |= (bxi->e << 8);
            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + txi / 2] = e;
        }
    }
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
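
A worked example of the nibble packing done by the compress lambda above: it gathers one nibble from each byte of its 32-bit input into a 16-bit value, so compress(aux) collects the low nibbles of the four packed bytes and compress(aux >> 4) the high nibbles:

// compress(0xA1B2C3D4)      == 0x1234  (low nibble of each byte)
// compress(0xA1B2C3D4 >> 4) == 0xABCD  (high nibble of each byte)
// Two such 16-bit halves are then OR-ed into one int of x_qs, giving eight FP4
// codes per 32-bit register, which is the packing the m16n8k64 e2m1 MMA consumes.
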
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {

@@ -930,6 +1006,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
                                                               const int * __restrict__ y,
                                                               float * __restrict__ sum,
                                                               const int k00) {
    typedef tile<16, 8, int>   tile_A;
    typedef tile<8, 8, int>    tile_B;
    typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA

    constexpr int granularity = mmq_get_granularity_device(mmq_x);
    constexpr int rows_per_warp = 2 * granularity;
    constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.

    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);

    // Match layout from load_tiles_mxfp4_fp4
    const int      * x_qs = (const int *) x;
    const uint32_t * x_sc = (const uint32_t *) (x_qs + MMQ_TILE_NE_K); // E8M0 scales at same offset as load
    const int      * y_qs = (const int *) y + 2;
    const uint32_t * y_sc = (const uint32_t *) y; // E8M0 scales for Y

    tile_A   A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)];      // 2 x 4 A tiles. Per warp there will be 1 scale per tile.
    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // per tile each thread only holds 1 scale

    // Block scale:
    // each thread has to point to a 4 byte scale value, see
    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling

    const int i0 = (threadIdx.y / ntx) * rows_per_warp;

#pragma unroll
    for (int n = 0; n < ntx; ++n) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
            const int k0 = k00 / 2 + k01;

            load_ldmatrix(A[n][k01 / QI8_0], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
                          MMQ_MMA_TILE_X_K_FP4);

            // Based on the block-scaling document, 2 threads in each quad need to supply the scale value.
            const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
            scaleA[n][k01 / QI8_0] = *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / QI8_0);
        }
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
        for (int k01 = 0; k01 < MMQ_TILE_NE_K / 2; k01 += QI8_0) {
            tile_B   B;
            uint32_t scaleB; // 2xN scales

            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);

            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / QI8_0];

#pragma unroll
            for (int n = 0; n < ntx; ++n) {
                tile_C C;

                mma_block_scaled(C, A[n][k01 / QI8_0], B, scaleA[n][k01 / QI8_0], scaleB);
#pragma unroll
                for (int l = 0; l < tile_C::ne; ++l) {
                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
                }
            }
        }
    }
}
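
The lane-to-scale mapping used above, spelled out (an annotation of the tidx expression, not additional commit code):

// tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8
// lane:  0  1  2  3   4  5  6  7  ...  28  29  30  31
// tidx:  0  8  0  8   1  9  1  9  ...   7  15   7  15
// Within each quad, the even lanes fetch the packed e8m0 pair for row (lane / 4)
// and the odd lanes the pair for row (lane / 4 + 8), i.e. the two rows of the
// 16-row A tile whose results that quad accumulates.
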
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {

@@ -3102,8 +3248,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
    static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4_fp4<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
#else
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
#endif
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

@@ -3240,13 +3391,24 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

    constexpr size_t sz       = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
    constexpr size_t y_stride = type == GGML_TYPE_MXFP4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;

    constexpr int y_block_stride =
        type == GGML_TYPE_MXFP4 ? (sz / sizeof(int))                     // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
                                : (qk * sz / (4 * QK8_1 * sizeof(int))); // original formula for Q8_1

    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
        load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);

        {
            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
            const int * by0 =
                type == GGML_TYPE_MXFP4 ? y + ncols_y * ((kb0 / 4) * y_block_stride) // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
                                        : y + ncols_y * (kb0 * y_block_stride);      // original for Q8_1
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                tile_y[l] = by0[l];

@@ -3260,9 +3422,14 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
        __syncthreads();

        {
            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
            const int * by0 =
                type == GGML_TYPE_MXFP4 ? y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride)      // advance by one block_fp4_mmq
                                        : y + ncols_y * (kb0 * y_block_stride + (int) (sz / sizeof(int))); // original for Q8_1 (advance by one block)
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
            for (int l0 = 0; l0 < mmq_x * y_stride; l0 += nwarps * warp_size) {
                int l = l0 + threadIdx.y*warp_size + threadIdx.x;

                tile_y[l] = by0[l];

@@ -3456,7 +3623,8 @@ static __global__ void mul_mat_q(
            __syncthreads();
        }

        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
        constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
        offset_y   += (col_low + jt * mmq_x) * (sz / sizeof(int));
        offset_dst += it*mmq_y;

        const int tile_x_max_i = nrows_x - it*mmq_y - 1;

@@ -3523,7 +3691,8 @@ static __global__ void mul_mat_q(
            __syncthreads();
        }

        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
        constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
        offset_y   += (col_low + jt * mmq_x) * (sz / sizeof(int));
        offset_dst += it*mmq_y;

        const int tile_x_max_i = nrows_x - it*mmq_y - 1;

@@ -3704,7 +3873,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
    const size_t nbs_ids = mmq_x*sizeof(int);
    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
    const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
    return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}

@@ -47,6 +47,153 @@ static __global__ void quantize_q8_1(
    y[ib].ds = make_half2(d, sum);
}

static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                          const int32_t * __restrict__ ids,
                                          void * __restrict__ vy,
                                          const int64_t ne00,
                                          const int64_t s01,
                                          const int64_t s02,
                                          const int64_t s03,
                                          const int64_t ne0,
                                          const int ne1,
                                          const int ne2) {
    constexpr int vals_per_scale = 32;
    constexpr int vals_per_warp  = 2 * vals_per_scale; // Each warp processes 2 blocks of 32

    // Each warp processes 2 adjacent blocks of 32 values (64 values total)
    const int64_t warp_start_offset = blockIdx.y * vals_per_warp;
    const int64_t i0_block0 = warp_start_offset + threadIdx.x;                  // First block: 0-31
    const int64_t i0_block1 = warp_start_offset + vals_per_scale + threadIdx.x; // Second block: 32-63

    if (i0_block0 >= ne0) {
        return;
    }

    const int64_t i1 = blockIdx.x;
    const int64_t i2 = blockIdx.z % ne2;
    const int64_t i3 = blockIdx.z / ne2;

    const int64_t i01 = ids ? ids[i1] : i1;
    const int64_t i02 = i2;
    const int64_t i03 = i3;

    block_fp4_mmq * y = (block_fp4_mmq *) vy;

    const int64_t block_fp4_mmq_size = 4 * QK_MXFP4; // 128 values

    const int64_t ib0 =
        blockIdx.z * ((int64_t) gridDim.x * gridDim.y * vals_per_warp / block_fp4_mmq_size); // first block of channel
    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;    // block index in channel
    const int64_t pair_idx_in_block =
        (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; // 0-1: which pair of blocks within block_fp4_mmq

    uint8_t e_packed[2];

    // Process first block (values 0-31)
    {
        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block0;
        const float xi = i0_block0 < ne00 ? x[global_src_pos] : 0.0f;

        float amax = fabsf(xi);

        // Reduce max across all 32 threads in the warp
#pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
        }

        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;

        float val   = ggml_cuda_e8m0_to_fp32(e);
        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;

        // Quantize: each thread processes 1 value
        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);

        if (e == 0) {
            e = 127;
        }

        // Pack 4 values into char2: threads 0,1,2,3 -> first char2, etc.
        const int lane_id  = threadIdx.x % 4;
        const int group_id = threadIdx.x / 4;

        // Use shuffle to gather values from 4 consecutive threads
        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);

        char2 q;
        if (lane_id == 0) {
            q.x = (q1 << 4) | q0;
            q.y = (q3 << 4) | q2;

            // Write to output: first block in pair uses positions based on pair_idx_in_block.
            // Each pair has 2 blocks of 32 = 64 values = 16 char2 elements.
            char2 * yqs2 = (char2 *) y[ib].qs;
            yqs2[pair_idx_in_block * 16 + group_id] = q;
        }

        if (threadIdx.x == 0) {
            e_packed[0] = e;
        }
    }

    // Process second block (values 32-63)
    {
        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block1;
        const float xi = i0_block1 < ne00 ? x[global_src_pos] : 0.0f;

        float amax = fabsf(xi);

#pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
        }

        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;

        float val   = ggml_cuda_e8m0_to_fp32(e);
        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;

        if (e == 0) {
            e = 127;
        }

        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);

        const int lane_id  = threadIdx.x % 4;
        const int group_id = threadIdx.x / 4;

        // Use shuffle to gather values from 4 consecutive threads
        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);

        char2 q;
        if (lane_id == 0) {
            q.x = (q1 << 4) | q0;
            q.y = (q3 << 4) | q2;

            // Write to output: second block in pair uses positions 8-15 within the pair
            char2 * yqs2 = (char2 *) y[ib].qs;
            yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
        }

        if (threadIdx.x == 0) {
            e_packed[1] = e;
        }
    }

    // Write packed exponents: d4[0-1] each store 2 scales (for 2 blocks of 32);
    // pair_idx_in_block tells us which d4 entry to use (0-1)
    if (threadIdx.x == 0) {
        y[ib].d4[pair_idx_in_block] = (e_packed[1] << 8) | e_packed[0];
    }
}
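
A worked example of the scale selection above (plain arithmetic, not additional kernel code), assuming ggml_cuda_e8m0_to_fp32(e) decodes to 2^(e - 127):

amax  = 3.0
e     = floor(log2(3.0 / 6.0)) + 127 = -1 + 127 = 126
scale = 2^(126 - 127) = 0.5,  inv_s = 1 / 0.5 = 2.0
3.0 * inv_s = 6.0  ->  the block's largest value maps to the top E2M1 code (6.0)

For an all-zero block, inv_s is forced to 0 so every code is 0, and the stored exponent is bumped from 0 to 127 so the packed scale decodes to 1.0.
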
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
    const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,

@@ -190,3 +337,26 @@ void quantize_mmq_q8_1_cuda(
            break;
    }
}

void quantize_mmq_mxfp4_cuda(const float * x,
                             const int32_t * ids,
                             void * vy,
                             [[maybe_unused]] const ggml_type type_src0,
                             const int64_t ne00,
                             const int64_t s01,
                             const int64_t s02,
                             const int64_t s03,
                             const int64_t ne0,
                             const int64_t ne1,
                             const int64_t ne2,
                             const int64_t ne3,
                             cudaStream_t stream) {
    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); // Each warp processes 64 values

    // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64
    const int64_t block_num_y   = (ne0 + vals_per_warp - 1) / vals_per_warp;
    const dim3    num_blocks(ne1, block_num_y, ne2 * ne3);
    const dim3    block_size(32, 1, 1); // Warp size
    quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}

@@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
    const float * x, const int32_t * ids, void * vy,
    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);

void quantize_mmq_mxfp4_cuda(const float * x,
                             const int32_t * ids,
                             void * vy,
                             ggml_type type_src0,
                             int64_t ne00,
                             int64_t s01,
                             int64_t s02,
                             int64_t s03,
                             int64_t ne0,
                             int64_t ne1,
                             int64_t ne2,
                             int64_t ne3,
                             cudaStream_t stream);