Better buf index: compute the row offset as row * LOAD_VEC_A / 4

This commit is contained in:
Eve 2026-01-01 18:08:51 -05:00
parent 6fd091106e
commit bcdeea47a1
1 changed file with 6 additions and 6 deletions

View File

@@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@@ -63,7 +63,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@@ -79,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-const uint buf_idx = col * SHMEM_STRIDE + row;
+const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@@ -96,7 +96,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@@ -459,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-const uint buf_idx = col * SHMEM_STRIDE + row;
+const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@@ -473,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-const uint buf_idx = col * SHMEM_STRIDE + row;
+const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;