optimize load_tiles

Aman Gupta 2025-12-10 16:12:43 +01:00
parent e214110ef7
commit 41e876a24f
1 changed file with 31 additions and 31 deletions

@@ -782,22 +782,23 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(BLACKWELL_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K); // Same offset as original: 2*MMQ_TILE_NE_K
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
constexpr int nrows = 1;
const int txi = threadIdx.x; // txi
const int kbx = txi;
const int txi = threadIdx.x;
// TODO: only 8 threads of a warp at the moment for simplicity, use more threads
if (txi >= 8) {
return;
}
# pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nrows * nwarps) {
int i = i0 + threadIdx.y;
// Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
constexpr int threads_per_row = 8; // 8 blocks per row
constexpr int rows_per_warp = warp_size / threads_per_row; // 4 rows per warp
const int kbx = txi % threads_per_row; // block id 0-7
const int row_in_warp = txi / threads_per_row; // which of the 4 rows this thread handles
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if (need_check) {
i = min(i, i_max);
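
Note on this hunk: the old scheme kept only lanes 0-7 of each warp active (nrows = 1, kbx = txi, early return for txi >= 8), while the new mapping derives a block index kbx and a row offset row_in_warp from threadIdx.x so that all 32 lanes load 4 rows of 8 blocks per iteration. A minimal host-side sketch of that index mapping follows; warp_size == 32 matches the hardware, but the nwarps and mmq_y values here are illustrative placeholders, not the kernel's actual launch configuration.

// Host-side sketch of the lane -> (row, block) mapping used above.
// nwarps and mmq_y are hypothetical values chosen for illustration.
#include <cstdio>

int main() {
    constexpr int warp_size       = 32;
    constexpr int threads_per_row = 8;                            // 8 mxfp4 blocks per tile row
    constexpr int rows_per_warp   = warp_size / threads_per_row;  // 4 rows per warp per iteration
    constexpr int nwarps          = 8;                            // hypothetical launch config
    constexpr int mmq_y           = 64;                           // hypothetical tile height

    for (int warp = 0; warp < nwarps; ++warp) {                   // plays the role of threadIdx.y
        for (int txi = 0; txi < warp_size; ++txi) {               // plays the role of threadIdx.x
            const int kbx         = txi % threads_per_row;        // block id 0-7 within the row
            const int row_in_warp = txi / threads_per_row;        // which of the 4 rows this lane owns
            for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
                const int i = i0 + warp * rows_per_warp + row_in_warp;
                // each (i, kbx) pair in the mmq_y x 8 tile is produced by exactly one lane
                (void) i;
            }
        }
    }
    printf("covered %d rows with %d warps, %d rows per warp per iteration\n",
           mmq_y, nwarps, rows_per_warp);
    return 0;
}
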
@@ -805,33 +806,32 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
// Load packed FP4 data directly (no LUT dequantization)
const int aux_q4_0 = get_int_b1(bxi->qs, 0);
const int aux_q4_1 = get_int_b1(bxi->qs, 1);
const int aux_q4_2 = get_int_b1(bxi->qs, 2);
const int aux_q4_3 = get_int_b1(bxi->qs, 3);
// Load 16 bytes more efficiently using memcpy (compiler optimizes to vector loads)
int aux_q4[4];
memcpy(aux_q4, bxi->qs, 16);
// Compress: extract low nibbles from each byte and pack into 16 bits
// Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
// Output: [lo3|lo2|lo1|lo0] as 16 bits
const auto compress = [](const int x) -> int {
uint16_t a = (x >> 24) & 0xF;
uint16_t b = (x >> 16) & 0xF;
uint16_t c = (x >> 8) & 0xF;
uint16_t d = x & 0xF;
return (a << 12) | (b << 8) | (c << 4) | d;
const int m = x & 0x0F0F0F0F; // isolate low nibbles: 0x0lo30lo20lo10lo0
// Pack nibbles: shift and combine
const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
return (t1 | (t1 >> 8)) & 0x0000FFFF; // 0x0000_lo3lo2lo1lo0
};
const int k0 = kbx * 4; // each block takes 4 bytes
const int k0 = kbx * 4;
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4_1) << 16 | compress(aux_q4_0);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4_3) << 16 | compress(aux_q4_2);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4_1 >> 4) << 16 | compress(aux_q4_0 >> 4);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4_3 >> 4) << 16 | compress(aux_q4_2 >> 4);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
if (txi % 2 == 0) {
// Load E8M0 scales: pack 2 consecutive scales into one uint32
if (kbx % 2 == 0) {
uint32_t e = bxi->e;
bxi++;
e |= (bxi->e << 8);
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + txi / 2] = e;
e |= ((bxi + 1)->e << 8);
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
}
}
#endif
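
The key change in the second hunk is the body of compress: instead of extracting the four low nibbles one at a time, the new version isolates them with a single mask and folds them together with two shift/OR steps. Below is a small host-side check, written for this note, that the two forms agree; the names compress_naive and compress_swar are chosen here for illustration and do not appear in the kernel.

// Verify that the branch-free nibble packing matches the per-nibble version it replaces.
#include <cassert>
#include <cstdint>
#include <cstdio>

// old variant: extract each low nibble separately and shift it into place
static uint16_t compress_naive(uint32_t x) {
    const uint16_t a = (x >> 24) & 0xF;
    const uint16_t b = (x >> 16) & 0xF;
    const uint16_t c = (x >>  8) & 0xF;
    const uint16_t d =  x        & 0xF;
    return (uint16_t) ((a << 12) | (b << 8) | (c << 4) | d);
}

// new variant: isolate the low nibbles of all 4 bytes at once, then fold them together
static uint16_t compress_swar(uint32_t x) {
    const uint32_t m  = x & 0x0F0F0F0F;              // 0x0lo3 0lo2 0lo1 0lo0
    const uint32_t t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00 lo3lo2 00 lo1lo0
    return (uint16_t) ((t1 | (t1 >> 8)) & 0xFFFF);   // 0x0000 lo3lo2lo1lo0
}

int main() {
    // spot checks plus a pseudo-random sweep
    assert(compress_swar(0x0D0C0B0A) == 0xDCBA);
    assert(compress_swar(0xFFFFFFFF) == 0xFFFF);
    uint32_t x = 0x12345678u;
    for (int i = 0; i < 1000000; ++i) {
        assert(compress_naive(x) == compress_swar(x));
        x = x * 1664525u + 1013904223u;              // LCG step
    }
    printf("compress_swar matches compress_naive\n");
    return 0;
}
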