first pass review: formatting
This commit is contained in:
parent
65f944bf18
commit
a6dcaa5742
ggml/src/ggml-cuda/common.cuh

@@ -246,7 +246,7 @@ static const char * cu_get_error_str(CUresult err) {
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
 # define BLACKWELL_MMA_AVAILABLE
-#endif
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
@@ -713,8 +713,8 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
         return 0;
     }
 
-    const float sign = x < 0.0f ? -1.0f : 1.0f;
-    float ax = fabsf(x) * e;
+    const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
+    float ax = fabsf(x) * e;
 
     // Positive LUT
     static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
@@ -729,19 +729,14 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
 
 #pragma unroll
     for (int i = 1; i < 8; ++i) {
-        float err = fabsf(ax - pos_lut[i]);
+        const float err = fabsf(ax - pos_lut[i]);
         if (err < best_err) {
             best_err = err;
             best_i = i;
         }
     }
 
-    // Positive codes: 0..7, negative: 8..15 (sign bit = MSB)
-    if (sign > 0.0f) {
-        return static_cast<uint8_t>(best_i); // 0..7
-    } else {
-        return static_cast<uint8_t>(best_i | 0x8); // 8..15
-    }
+    return static_cast<uint8_t>(best_i | sign_bit);
 }
 
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
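Note (reviewer sketch, not part of the commit): after this cleanup the E2M1 rounding reduces to a nearest-magnitude LUT search plus an OR of the sign bit, replacing the old if/else on a float `sign`. A host-side sketch of the resulting logic; the early-out condition and the initialization of best_i/best_err sit outside the visible hunks, so those parts are assumptions:

    #include <cmath>
    #include <cstdint>

    static uint8_t float_to_fp4_e2m1_ref(float x, float e) {
        if (x == 0.0f) {
            return 0; // early-out, assumed to match the kernel's guard
        }
        const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
        const float   ax       = std::fabs(x) * e; // e: per-block scale factor

        // The 8 non-negative E2M1 magnitudes; the FP4 code is LUT index | sign bit.
        static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

        int   best_i   = 0;
        float best_err = std::fabs(ax - pos_lut[0]); // assumed initialization
        for (int i = 1; i < 8; ++i) {
            const float err = std::fabs(ax - pos_lut[i]);
            if (err < best_err) {
                best_err = err;
                best_i   = i;
            }
        }
        return static_cast<uint8_t>(best_i | sign_bit); // positive: 0..7, negative: 8..15
    }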
ggml/src/ggml-cuda/mmq.cu

@@ -1,3 +1,4 @@
 #include "common.cuh"
 #include "mmq.cuh"
 #include "quantize.cuh"
+#include "mmid.cuh"
@@ -114,6 +115,8 @@ void ggml_cuda_mul_mat_q(
     const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
         || GGML_CUDA_CC_IS_CDNA(cc);
 
+    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
+
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
             get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -123,7 +126,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[3] / ts_src1;
-        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
+        if (use_native_mxfp4) {
             quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
                                     ne11, ne12, ne13, stream);
@@ -135,7 +138,7 @@ void ggml_cuda_mul_mat_q(
         }
 
         // Stride depends on quantization format
-        const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
+        const int64_t s12 = use_native_mxfp4 ?
                                 ne11 * ne10_padded * sizeof(block_fp4_mmq) /
                                     (4 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 128 values
                                 :
@@ -187,7 +190,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
 
-        if (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) {
+        if (use_native_mxfp4) {
             quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
                                     ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
         } else {
@@ -197,7 +200,7 @@ void ggml_cuda_mul_mat_q(
             CUDA_CHECK(cudaGetLastError());
         }
 
-        const int64_t s12 = (blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4) ?
+        const int64_t s12 = use_native_mxfp4 ?
                                 ne11 * ne10_padded * sizeof(block_fp4_mmq) / (4 * QK_MXFP4 * sizeof(int)) :
                                 ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
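Note (reviewer sketch, not part of the commit): hoisting the repeated blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4 test into a single use_native_mxfp4 flag keeps the two s12 computations and both quantize calls in sync. A unit check of the MXFP4 stride arithmetic, assuming (from the comments in this diff) that block_fp4_mmq is 18 ints = 72 bytes and that QK_MXFP4 == 32; the tensor shape is hypothetical:

    #include <cstdio>

    int main() {
        const long long ne10_padded = 4096, ne11 = 8;   // hypothetical shape
        const long long sizeof_block_fp4_mmq = 18 * 4;  // 18 ints = 72 bytes (assumption)
        const long long QK_MXFP4 = 32, sizeof_int = 4;  // MXFP4 qk-block size

        // One block_fp4_mmq covers 4*QK_MXFP4 = 128 values, so the row-batch
        // stride measured in ints is (values / 128) * 18:
        const long long s12 = ne11 * ne10_padded * sizeof_block_fp4_mmq
                              / (4 * QK_MXFP4 * sizeof_int);
        printf("s12 = %lld ints\n", s12); // 8*4096/128*18 = 4608
        return 0;
    }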

ggml/src/ggml-cuda/mmq.cuh
@@ -218,8 +218,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
 #ifdef BLACKWELL_MMA_AVAILABLE
-        case GGML_TYPE_MXFP4:
-            return MMQ_MMA_TILE_X_K_FP4;
+        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
 #else
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
 #endif
@@ -784,7 +783,6 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(BLACKWELL_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
     uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
 
@@ -833,7 +831,6 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
             x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
         }
     }
-#endif
 }
 
 template <int mmq_x, int mmq_y>
@@ -1026,7 +1023,7 @@ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __res
     const int * y_qs = (const int *) y + 2;
     const uint32_t * y_sc = (const uint32_t *) y; // E8M0 scales for Y
 
-    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale pe rtile
+    tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // 2 x 4 A tiles. Per warp there will be 1 scale per tile
     uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI8_0)]; // per tile you would only have 1 scale per thread
 
     // Block scale
@@ -3253,7 +3250,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
 #else
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
     static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
-#endif
+#endif // BLACKWELL_MMA_AVAILABLE
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
@@ -3386,15 +3383,21 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    constexpr bool use_native_mxfp4 = (type == GGML_TYPE_MXFP4);
+#else
+    constexpr bool use_native_mxfp4 = false;
+#endif // defined(BLACKWELL_MMA_AVAILBLE)
+
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
-    constexpr size_t y_stride = type == GGML_TYPE_MXFP4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
+    constexpr size_t sz = use_native_mxfp4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
+    constexpr size_t y_stride = use_native_mxfp4 ? MMQ_TILE_Y_FP4_K : MMQ_TILE_Y_K;
 
-    constexpr int y_block_stride =
-        type == GGML_TYPE_MXFP4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
+    constexpr int y_block_stride = use_native_mxfp4 ? (sz / sizeof(int)) // 18 ints per block_fp4_mmq (covers 128 values = 4 qk-blocks)
         :
         (qk * sz / (4 * QK8_1 * sizeof(int))); // original formula for Q8_1
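Note (reviewer sketch, not part of the commit): the constexpr use_native_mxfp4 leaves non-Blackwell builds untouched while switching the Y tile layout at compile time on Blackwell. A compile-time check of the FP4 branch, assuming (per the comment above) that block_fp4_mmq is 18 ints = 72 bytes; the real struct definition is not shown in this diff:

    #include <cstddef>

    struct block_fp4_mmq_assumed { int data[18]; }; // 72 bytes, covers 128 values (assumption)
    constexpr int QK_MXFP4_assumed = 32;            // MXFP4 qk-block size

    constexpr size_t sz             = sizeof(block_fp4_mmq_assumed);
    constexpr int    y_block_stride = sz / sizeof(int);

    static_assert(y_block_stride == 18, "18 ints per block_fp4_mmq");
    static_assert(4 * QK_MXFP4_assumed == 128, "one block_fp4_mmq = 4 qk-blocks = 128 values");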
@@ -3402,7 +3405,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
         {
             const int * by0 =
-                type == GGML_TYPE_MXFP4 ?
+                use_native_mxfp4 ?
                     y + ncols_y * ((kb0 / 4) * y_block_stride) // kb0/4 for MXFP4 since 4 qk-blocks per block_fp4_mmq
                     :
                     y + ncols_y * (kb0 * y_block_stride); // original for Q8_1
|
@ -3422,7 +3425,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|||
|
||||
{
|
||||
const int * by0 =
|
||||
type == GGML_TYPE_MXFP4 ?
|
||||
use_native_mxfp4 ?
|
||||
y + ncols_y * ((kb0 / 4) * y_block_stride + y_block_stride) // advance by one block_fp4_mmq
|
||||
:
|
||||
y + ncols_y * (kb0 * y_block_stride +
|
||||
|
|
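Note (reviewer sketch, not part of the commit): both Y-pointer computations divide kb0 by 4 because one block_fp4_mmq packs 4 qk-blocks, and the second computation advances by exactly one y_block_stride ("advance by one block_fp4_mmq"). A small index walk, under the assumption that kb0 moves in steps of 4 qk-blocks here, which the kb0/4 indexing suggests:

    #include <cstdio>

    int main() {
        const int y_block_stride = 18;          // ints per block_fp4_mmq (from the diff comment)
        for (int kb0 = 0; kb0 < 16; kb0 += 4) { // assumed iteration step of 4 qk-blocks
            const int off_cur  = (kb0 / 4) * y_block_stride; // current block_fp4_mmq
            const int off_next = off_cur + y_block_stride;   // next block, one stride later
            printf("kb0=%2d -> cur=%3d next=%3d\n", kb0, off_cur, off_next);
        }
        return 0;
    }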