optimize load_tiles
This commit is contained in:
parent
e214110ef7
commit
41e876a24f
|
|
@ -782,22 +782,23 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
|||
const int i_max,
|
||||
const int stride) {
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
int * x_qs = (int *) x_tile;
|
||||
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K); // Same offset as original: 2*MMQ_TILE_NE_K
|
||||
uint32_t * x_sc = (uint32_t *) (x_qs + MMQ_TILE_NE_K);
|
||||
|
||||
constexpr int nrows = 1;
|
||||
const int txi = threadIdx.x; // txi
|
||||
const int kbx = txi;
|
||||
const int txi = threadIdx.x;
|
||||
|
||||
// TODO: only 8 threads of a warp at the moment for simplicity, use more threads
|
||||
if (txi >= 8) {
|
||||
return;
|
||||
}
|
||||
# pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nrows * nwarps) {
|
||||
int i = i0 + threadIdx.y;
|
||||
// Use all 32 threads: 8 threads per row, process 4 rows per warp per iteration
|
||||
constexpr int threads_per_row = 8; // 8 blocks per row
|
||||
constexpr int rows_per_warp = warp_size / threads_per_row; // 4 rows per warp
|
||||
const int kbx = txi % threads_per_row; // block id 0-7
|
||||
const int row_in_warp = txi / threads_per_row; // which of the 4 rows this thread handles
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
|
||||
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
|
||||
|
||||
if (need_check) {
|
||||
i = min(i, i_max);
|
||||
|
|
@ -805,33 +806,32 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
|||
|
||||
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
|
||||
|
||||
// Load packed FP4 data directly (no LUT dequantization)
|
||||
const int aux_q4_0 = get_int_b1(bxi->qs, 0);
|
||||
const int aux_q4_1 = get_int_b1(bxi->qs, 1);
|
||||
const int aux_q4_2 = get_int_b1(bxi->qs, 2);
|
||||
const int aux_q4_3 = get_int_b1(bxi->qs, 3);
|
||||
// Load 16 bytes more efficiently using memcpy (compiler optimizes to vector loads)
|
||||
int aux_q4[4];
|
||||
memcpy(aux_q4, bxi->qs, 16);
|
||||
|
||||
// Compress: extract low nibbles from each byte and pack into 16 bits
|
||||
// Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
|
||||
// Output: [lo3|lo2|lo1|lo0] as 16 bits
|
||||
const auto compress = [](const int x) -> int {
|
||||
uint16_t a = (x >> 24) & 0xF;
|
||||
uint16_t b = (x >> 16) & 0xF;
|
||||
uint16_t c = (x >> 8) & 0xF;
|
||||
uint16_t d = x & 0xF;
|
||||
|
||||
return (a << 12) | (b << 8) | (c << 4) | d;
|
||||
const int m = x & 0x0F0F0F0F; // isolate low nibbles: 0x0lo30lo20lo10lo0
|
||||
// Pack nibbles: shift and combine
|
||||
const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
|
||||
return (t1 | (t1 >> 8)) & 0x0000FFFF; // 0x0000_lo3lo2lo1lo0
|
||||
};
|
||||
|
||||
const int k0 = kbx * 4; // each block takes 4 bytes
|
||||
const int k0 = kbx * 4;
|
||||
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4_1) << 16 | compress(aux_q4_0);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4_3) << 16 | compress(aux_q4_2);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4_1 >> 4) << 16 | compress(aux_q4_0 >> 4);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4_3 >> 4) << 16 | compress(aux_q4_2 >> 4);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
|
||||
|
||||
if (txi % 2 == 0) {
|
||||
// Load E8M0 scales: pack 2 consecutive scales into one uint32
|
||||
if (kbx % 2 == 0) {
|
||||
uint32_t e = bxi->e;
|
||||
bxi++;
|
||||
e |= (bxi->e << 8);
|
||||
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + txi / 2] = e;
|
||||
e |= ((bxi + 1)->e << 8);
|
||||
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue