use integer dot product for quantized KV flash attention

This commit is contained in:
Ruben Ortlam 2026-03-19 14:22:31 +01:00
parent cca358b43b
commit 2df9eca0cb
3 changed files with 217 additions and 10 deletions

View File

@ -3435,13 +3435,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->fp16) {
    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
    // Quantized K/V: prefer the int8 integer-dot-product (MMQ) shader
    // variants when the device supports them, otherwise fall back to the
    // dequantize-on-load scalar variants.
    if (device->integer_dot_product) {
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
    } else {
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
    }
} else {
    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
    // Fix: condition was inverted. The _fp32_int8 shaders use integer dot
    // product instructions, so they must only be selected when
    // device->integer_dot_product is true — mirroring the fp16 branch above.
    if (device->integer_dot_product) {
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
    } else {
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
    }
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {

View File

@ -10,6 +10,13 @@
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif
#ifdef MMQ
#extension GL_EXT_integer_dot_product : require
#extension GL_KHR_shader_subgroup_clustered : require
#include "mul_mmq_shmem_types.glsl"
#endif
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
@ -41,15 +48,53 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
#ifndef MMQ
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
#else
const uint32_t qf_stride = HSK / 32;
shared block_b_cache Qf[Br * qf_stride];
#endif
#ifndef MMQ
const uint32_t D = HSK > HSV ? HSK : HSV;
#else
const uint32_t D = HSV;
#endif
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
#ifdef MMQ
shared block_a_cache kblocksh[Bc * qf_stride];
#endif
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
#ifdef MMQ
#if defined(DATA_A_Q4_0)
// Fetch four consecutive 4-bit K quants from block `ib` (relative to
// `a_offset`) as one packed int32, one quant per byte lane.
// `iqs & 0xF` picks which pair of 16-bit words holds the nibbles; bit 4 of
// `iqs` selects low vs. high nibbles (shift is 0 or 4). The raw unsigned
// nibbles are returned — the Q4_0 "-8" bias is applied by the caller
// (see the ds.y correction term in the MMQ dot-product loop).
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
#endif
#if defined(DATA_A_Q8_0)
// Fetch four consecutive signed 8-bit K quants from block `ib` as one
// packed int32 (two 16-bit words repacked into byte lanes).
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
}
#endif
// Per-block dequantization scale `d` of K block `ib`, widened to FLOAT_TYPE.
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
}
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@ -82,10 +127,33 @@ void main() {
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
#ifndef MMQ
if (is_in_bounds) {
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
#else
const uint buf_ib = r * qf_stride + d / 8;
const uint buf_iqs = d % 8;
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
const FLOAT_TYPEV4 abs_vals = abs(vals);
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
vals = round(vals * qd_inv);
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
#endif
}
barrier();
@ -195,6 +263,7 @@ void main() {
if (SHMEM_STAGING != 0) {
barrier();
#ifndef MMQ
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@ -214,9 +283,43 @@ void main() {
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
const uint32_t ib = (idx + tid) / ints_per_block;
const uint32_t c = ib / (HSK / 32);
const uint32_t block = ib % (HSK / 32);
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
const uint buf_ib = c * qf_stride + block;
if (!KV_bounds_check || j * Bc + c < KV) {
const uint global_ib = (j * Bc + c) * k_stride + block;
#if defined(DATA_A_Q4_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[k_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[k_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_Q8_0)
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[k_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[k_offset + global_ib].qs[iqs * 2 + 1]));
#endif
if (iqs == 0) {
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[k_offset + global_ib].d);
}
} else {
kblocksh[buf_ib].qs[iqs] = 0;
if (iqs == 0) {
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
}
}
}
}
#endif // MMQ
barrier();
}
#ifndef MMQ
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
@ -275,6 +378,94 @@ void main() {
}
}
}
#else // MMQ
const uint hsk4 = HSK_per_thread / 4;
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
(hsk4 % 4 == 0) ? 4 :
(hsk4 % 2 == 0) ? 2 : 1;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
int32_t k_quants[d_per_step];
ACC_TYPE k_scale;
if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
k_scale = ACC_TYPE(kblocksh[buf_ib].dm);
#if defined(DATA_A_Q4_0)
if (d_per_step < 8) {
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
uint sub = (d_block + d) % 4;
uint vui = kblocksh[buf_ib].qs[sub];
uint shift = ((d_block + d) >= 4) ? 4 : 0;
k_quants[d] = int32_t((vui >> shift) & 0x0F0F0F0F);
}
} else {
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
}
}
#elif defined(DATA_A_Q8_0)
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = kblocksh[buf_ib].qs[d_block % 8 + d];
}
#endif
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);
k_scale = ACC_TYPE(get_k_d(ib, k_offset));
#if defined(DATA_A_Q4_0)
if (d_per_step < 8) {
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
} else {
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = (uint(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]) << 16) |
uint(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0]);
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
}
}
#elif defined(DATA_A_Q8_0)
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
#else
#error unimplemented
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
int32_t acc = 0;
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
}
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_scale;
#if defined(DATA_A_Q4_0)
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
Sf[r][c] -= ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_scale;
}
#endif
}
}
}
#endif // MMQ
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split

View File

@ -404,8 +404,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
}
static std::vector<std::future<void>> compiles;
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
std::string out_path = join_paths(output_dir, name + ".spv");
if (input_filepath == "") {
@ -623,9 +623,9 @@ void process_shaders() {
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
}
// flash attention
@ -670,6 +670,12 @@ void process_shaders() {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// Scalar FA shader that dequantizes K/V on the fly (all supported K types).
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
// Integer-dot-product (MMQ) variant. Restrict to the block formats the MMQ
// path actually implements (get_k_qs/get_k_d exist only for Q4_0/Q8_0 and
// the non-staged load path hits "#error unimplemented" otherwise):
// `tname != "f32"` would also emit an f16 MMQ shader that cannot compile.
if (tname == "q4_0" || tname == "q8_0") {
    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
}
#endif
}
}
}