This commit is contained in:
Eve 2026-01-02 16:32:27 -06:00 committed by GitHub
commit a578deb3c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 42 deletions

View File

@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
return dm;
}
#endif

View File

@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint vui = data_a_packed32[ib].qs[iqs];
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint uint_qh = data_a_packed16[ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint uint_qh = data_a_packed32[ib].qh;
const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
const uint vui = data_a_packed32[ib].qs[iqs];
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
const uint scales = data_a[ib].scales[scalesi];
const vec2 dm = vec2(data_a[ib].dm);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
const vec2 q = vec2(unpack8(qs | qh).xy);
const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
const vec4 q = vec4(unpack8(qs | qh));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;

View File

@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";
if (tname == "bf16") {