#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require

#ifdef USE_SUBGROUPS
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_clustered : require

// NOTE(review): block indices are derived from the subgroup-local invocation index,
// so if a workgroup contained more than one subgroup, every subgroup would compute the
// same block indices. This path presumably relies on GROUP_SIZE == subgroup size --
// confirm against the host-side pipeline setup.
#define INVOCATION_ID gl_SubgroupInvocationID.x
#else
#define INVOCATION_ID gl_LocalInvocationID.x
#endif

layout (push_constant) uniform parameter
{
    uint ne;          // number of input float values
    uint num_blocks;  // number of chunks (GROUP_SIZE / 8 blocks each) to process, see main()
} p;

#include "types.glsl"

layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {vec4 data_a[];};
#ifndef QBLOCK_X4
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
#else
layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
#endif

#ifndef USE_SUBGROUPS
// Scratch space for the per-block max and sum reductions (one slot per invocation).
shared float shmem[GROUP_SIZE];
#endif

// Quantizes the GROUP_SIZE / 8 blocks of 32 values that make up chunk `wgid`.
// Each invocation handles one vec4, so 8 consecutive invocations form one block.
// For every block it stores the int8 values (packed 4 per uint) and the pair
// (d, sum * d) as f16vec2, where sum is the sum of the quantized values.
// All invocations of a workgroup must call this with the same wgid: the
// shared-memory path synchronizes with barrier().
void quantize(const uint wgid) {
    const uint tid = INVOCATION_ID;

    // Each thread handles a vec4, so 8 threads handle a block
    const uint blocks_per_group = GROUP_SIZE / 8;

    const uint block_in_wg = tid / 8;

    const uint ib = wgid * blocks_per_group + block_in_wg;
    const uint iqs = tid % 8;

#ifdef QBLOCK_X4
    // Blocks are stored in groups of 4 (128 values) per data_b element.
    const uint ibx4_outer = ib / 4;
    const uint ibx4_inner = ib % 4;

    const uint required_x4_blocks = (p.ne + 127) / 128;
    // Do not return early here: when GROUP_SIZE > 32 one workgroup spans several x4
    // groups, so this condition is not uniform across the workgroup, and in the
    // shared-memory path every invocation must still reach each barrier().
    // Out-of-range invocations read zeros (a_idx >= ne / 4 for them), take part in
    // the reductions, and simply skip the stores.
    const bool store_block = ibx4_outer < required_x4_blocks;
#else
    const bool store_block = true;
#endif

    const uint a_idx = ib * 8 + iqs;

    // NOTE(review): ne / 4 truncates, so if ne were not a multiple of 4 the trailing
    // partial vec4 would be read as zeros -- presumably ne is always a multiple of 4.
    vec4 vals = a_idx < p.ne / 4 ? data_a[a_idx] : vec4(0.0f);
    const vec4 abs_vals = abs(vals);

    // Find absolute max for each block
    const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
#ifndef USE_SUBGROUPS
    shmem[tid] = thread_max;
    barrier();
    // Tree reduction over the 8 slots of this block (strides 4, 2, 1).
    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
        if (iqs < s) {
            shmem[tid] = max(shmem[tid], shmem[tid + s]);
        }
        barrier();
    }

    const float amax = shmem[block_in_wg * 8];
#else
    const float amax = subgroupClusteredMax(thread_max, 8);
#endif

    // Scale so the largest magnitude maps to 127; an all-zero block gets d = 0 and
    // d_inv = 0, so its values quantize to 0 instead of dividing by zero.
    const float d = amax / 127.0;
    const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
    // vals is integral after this, so the int8 conversion below needs no second round().
    vals = round(vals * d_inv);

    if (store_block) {
#ifndef QBLOCK_X4
        data_b[ib].qs[iqs] = pack32(i8vec4(vals));
#else
        data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(vals));
#endif
    }

#ifndef USE_SUBGROUPS
    // Every invocation must have read amax from shmem before it is overwritten below.
    barrier();
#endif

    // Calculate the sum of the quantized values for each block
    const float thread_sum = vals.x + vals.y + vals.z + vals.w;
#ifndef USE_SUBGROUPS
    shmem[tid] = thread_sum;
    barrier();
    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
        if (iqs < s) {
            shmem[tid] += shmem[tid + s];
        }
        barrier();
    }
#else
    const float sum = subgroupClusteredAdd(thread_sum, 8);
#endif
    // The first invocation of each block writes the block's scale metadata.
    if (iqs == 0 && store_block) {
#ifndef USE_SUBGROUPS
        const float sum = shmem[tid];
#endif
#ifndef QBLOCK_X4
        data_b[ib].ds = f16vec2(vec2(d, sum * d));
#else
        data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
#endif
    }
}

// Each workgroup processes chunks gl_WorkGroupID.x, + gl_NumWorkGroups.x, ... until
// p.num_blocks chunks are covered. wgid is identical for all invocations of a
// workgroup, so the loop (and the barriers inside quantize) stay in uniform control flow.
void main() {
    uint wgid = gl_WorkGroupID.x;
    while (wgid < p.num_blocks) {
        quantize(wgid);
        wgid += gl_NumWorkGroups.x;
    }
}