diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 0ce95c7006..b8b1d9aefd 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -782,22 +782,23 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

 #if defined(BLACKWELL_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K); // Same offset as original: 2*MMQ_TILE_NE_K
+    uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);

-    constexpr int nrows = 1;
-    const int txi = threadIdx.x; // txi
-    const int kbx = txi;
+    const int txi = threadIdx.x;

-    // TODO: only 8 threads of a warp at the moment for simplicity, use more threads
-    if (txi >= 8) {
-        return;
-    }
-# pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nrows * nwarps) {
-        int i = i0 + threadIdx.y;
+    // Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
+    constexpr int threads_per_row = 8;                           // 8 blocks per row
+    constexpr int rows_per_warp   = warp_size / threads_per_row; // 4 rows per warp
+    const int kbx         = txi % threads_per_row; // block id 0-7
+    const int row_in_warp = txi / threads_per_row; // which of the 4 rows this thread handles
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;

         if (need_check) {
             i = min(i, i_max);
         }
@@ -805,33 +806,32 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr

         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;

-        // Load packed FP4 data directly (no LUT dequantization)
-        const int aux_q4_0 = get_int_b1(bxi->qs, 0);
-        const int aux_q4_1 = get_int_b1(bxi->qs, 1);
-        const int aux_q4_2 = get_int_b1(bxi->qs, 2);
-        const int aux_q4_3 = get_int_b1(bxi->qs, 3);
+        // Load all 16 bytes with one memcpy (the compiler optimizes this to vector loads)
+        int aux_q4[4];
+        memcpy(aux_q4, bxi->qs, 16);

+        // Compress: extract the low nibble of each byte and pack the four nibbles into 16 bits
+        // Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
+        // Output: [lo3|lo2|lo1|lo0] as 16 bits
         const auto compress = [](const int x) -> int {
-            uint16_t a = (x >> 24) & 0xF;
-            uint16_t b = (x >> 16) & 0xF;
-            uint16_t c = (x >> 8) & 0xF;
-            uint16_t d = x & 0xF;
-
-            return (a << 12) | (b << 8) | (c << 4) | d;
+            const int m  = x & 0x0F0F0F0F;              // isolate low nibbles: 0x0lo3_0lo2_0lo1_0lo0
+            // Pack nibbles: shift and combine
+            const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
+            return (t1 | (t1 >> 8)) & 0x0000FFFF;       // 0x0000_lo3lo2lo1lo0
         };

-        const int k0 = kbx * 4; // each block takes 4 bytes
+        const int k0 = kbx * 4;

-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4_1) << 16 | compress(aux_q4_0);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4_3) << 16 | compress(aux_q4_2);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4_1 >> 4) << 16 | compress(aux_q4_0 >> 4);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4_3 >> 4) << 16 | compress(aux_q4_2 >> 4);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
+        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);

-        if (txi % 2 == 0) {
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32_t
+        if (kbx % 2 == 0) {
             uint32_t e = bxi->e;
-            bxi++;
-            e |= (bxi->e << 8);
-            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + txi / 2] = e;
+            e |= ((bxi + 1)->e << 8);
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
         }
     }
 #endif
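
Review note, not part of the diff: the new lane mapping replaces the early-exit that left lanes 8-31 idle. Each warp now splits into 4 rows x 8 blocks, so every lane loads exactly one block per iteration and each loop step covers rows_per_warp * nwarps tile rows instead of nwarps. A minimal host-side sketch of the index arithmetic (assuming warp_size == 32, as the "4 rows per warp" comment in the diff does):

    #include <cstdio>

    int main() {
        constexpr int warp_size       = 32;
        constexpr int threads_per_row = 8;                           // 8 mxfp4 blocks per tile row
        constexpr int rows_per_warp   = warp_size / threads_per_row; // = 4

        // Each lane handles exactly one (row, block) pair per iteration.
        for (int txi = 0; txi < warp_size; ++txi) {
            const int kbx         = txi % threads_per_row;
            const int row_in_warp = txi / threads_per_row;
            printf("lane %2d -> row %d, block %d\n", txi, row_in_warp, kbx);
        }
        return 0;
    }

Lane 13, for example, loads block 5 of the warp's second row (row_in_warp == 1). The scale path is reindexed the same way: with kbx in 0-7, even lanes pack the E8M0 scales of blocks kbx and kbx + 1 into one uint32_t, without mutating bxi as the old txi-based code did.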
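The SWAR rewrite of compress is the subtle part of the change. Below is a self-contained host-side check (hypothetical, not part of the PR) that the three-step mask/shift packing matches the old per-nibble version; since both functions read only the low nibble of each byte, exhausting all 16^4 low-nibble patterns under a few junk fillings of the high nibbles covers every distinct input class:

    #include <cstdint>
    #include <cstdio>

    // Old version: extract each low nibble individually and reassemble.
    static int compress_ref(const int x) {
        uint16_t a = (x >> 24) & 0xF;
        uint16_t b = (x >> 16) & 0xF;
        uint16_t c = (x >> 8) & 0xF;
        uint16_t d = x & 0xF;
        return (a << 12) | (b << 8) | (c << 4) | d;
    }

    // New version: pack all four low nibbles in three SWAR steps.
    static int compress_swar(const int x) {
        const int m  = x & 0x0F0F0F0F;              // 0x0D0C0B0A
        const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00DC00BA
        return (t1 | (t1 >> 8)) & 0x0000FFFF;       // 0x0000DCBA
    }

    int main() {
        const uint32_t junk[3] = { 0x00000000u, 0xA0A0A0A0u, 0xF0F0F0F0u };
        for (uint32_t v = 0; v < 0x10000; ++v) {
            // Spread the 4 test nibbles into the low nibble of each byte.
            const uint32_t lo = ((v & 0xF000) << 12) | ((v & 0x0F00) << 8)
                              | ((v & 0x00F0) << 4) |  (v & 0x000F);
            for (uint32_t j : junk) { // high nibbles must not affect the result
                const int x = (int) (lo | j);
                if (compress_ref(x) != compress_swar(x)) {
                    printf("mismatch at x = %08x\n", (unsigned) x);
                    return 1;
                }
            }
        }
        printf("ok: all 16^4 low-nibble patterns match\n");
        return 0;
    }

The SWAR form roughly halves the bit-twiddling per 32-bit word, and its result feeds the existing "<< 16 |" combining into x_qs unchanged.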