vulkan: Flash Attention DP4A shader for quantized KV cache (#20797)

* use integer dot product for quantized KV flash attention

* small improvements

* fix SHMEM_STAGING indexing

* add missing KV type quants

* fixes

* add supported quants to FA tests

* readd fast paths for <8bit quants

* fix mmq gate and shmem checks
This commit is contained in:
Ruben Ortlam 2026-04-13 14:21:31 +02:00 committed by GitHub
parent aa00911d12
commit 75f3bc94e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 431 additions and 28 deletions

View File

@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
}
};
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
GGML_UNUSED(kv_type);
vk_fa_tuning_params result{};
result.path = FA_SCALAR;
@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
result.block_rows /= 2;
}
@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
}
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
GGML_UNUSED(f32acc);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
uint32_t Qf, kvsh, kblocksh_size;
if (mmq) {
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
Qf = Br * (hsk / 32) * block_b_size;
const uint32_t D = std::max(hsk, hsv);
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
const uint32_t D = std::max(hsk, hsv);
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
kblocksh_size = 0;
}
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}

View File

@ -10,6 +10,13 @@
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif
#ifdef MMQ
#extension GL_EXT_integer_dot_product : require
#extension GL_KHR_shader_subgroup_clustered : require
#include "mul_mmq_shmem_types.glsl"
#endif
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
#ifndef MMQ
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
#else
const uint32_t qf_stride = HSK / 32;
shared block_b_cache Qf[Br * qf_stride];
#endif
#ifndef MMQ
const uint32_t D = HSK > HSV ? HSK : HSV;
#else
const uint32_t D = HSV;
#endif
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
#ifdef MMQ
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
#endif
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
#ifdef MMQ
#include "flash_attn_mmq_funcs.glsl"
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@ -82,10 +108,39 @@ void main() {
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
#ifndef MMQ
if (is_in_bounds) {
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
#else
const uint buf_ib = r * qf_stride + d / 8;
const uint buf_iqs = d % 8;
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
const FLOAT_TYPEV4 abs_vals = abs(vals);
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
vals = round(vals * qd_inv);
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
}
#else // Q4_0, Q4_1, Q5_0, Q5_1
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
if (buf_iqs == 0) {
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
}
#endif
#endif
}
barrier();
@ -195,6 +250,7 @@ void main() {
if (SHMEM_STAGING != 0) {
barrier();
#ifndef MMQ
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@ -214,9 +270,29 @@ void main() {
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
#else // MMQ
const uint ints_per_block = 8 / QUANT_R_MMQ;
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
const uint32_t iqs = (idx + tid) % ints_per_block;
const uint32_t ib = (idx + tid) / ints_per_block;
const uint32_t c = ib / (HSK / 32);
const uint32_t block = ib % (HSK / 32);
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
const uint buf_ib = c * qf_stride + block;
if (!KV_bounds_check || j * Bc + c < KV) {
const uint global_ib = (j * Bc + c) * k_stride + block;
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
} else {
k_block_to_shmem_zero(buf_ib, iqs);
}
}
}
#endif // MMQ
barrier();
}
#ifndef MMQ
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
@ -275,6 +351,110 @@ void main() {
}
}
}
#else // MMQ
const uint hsk4 = HSK_per_thread / 4;
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
(hsk4 % 4 == 0) ? 4 :
(hsk4 % 2 == 0) ? 2 : 1;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
int32_t k_quants[d_per_step];
ACC_TYPEV2 k_dm;
if (SHMEM_STAGING != 0) {
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
#else
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
uint vui = kblocksh[buf_ib].qs[d];
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
}
}
} else {
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
const uint ib = coord / BLOCK_SIZE;
const uint iqs = (coord % BLOCK_SIZE);
#if QUANT_AUXF == 1
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
#else
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
if (d_per_step == 8) {
#if defined(DATA_A_Q5_0)
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
k_packed.k_data_packed16[k_offset + ib].qh[1]));
#elif defined(DATA_A_Q5_1)
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
#endif
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
#if defined(A_TYPE_PACKED32)
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
#else
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
#endif
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint qh_lo = (qh >> (d * 4)) & 0xF;
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
#endif
}
} else
#endif
{
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
}
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
int32_t acc = 0;
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
}
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
Sf[r][c] += k_dot_correction(qib, k_dm);
}
}
}
}
#endif // MMQ
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split

View File

@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
#endif
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif

View File

@ -0,0 +1,149 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q4_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
return int32_t(vui & 0x0F0F0F0F);
}
#endif
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
#ifdef DATA_A_Q5_0
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
k_packed.k_data_packed16[a_offset + ib].qh[1]));
#else
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
#endif
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
uint qh_bits = (qh >> iqs) & 0xF;
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
}
#endif
#if defined(DATA_A_Q8_0)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
}
#endif
#if defined(DATA_A_IQ4_NL)
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
uint shift = (iqs & 0x10) >> 2;
vui >>= shift;
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
kvalues_iq4nl_const[idx.y],
kvalues_iq4nl_const[idx.z],
kvalues_iq4nl_const[idx.w]));
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
}
#else
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
}
#endif
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
#if defined(DATA_A_Q4_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_Q4_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
#elif defined(DATA_A_Q5_0)
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
}
#elif defined(DATA_A_Q5_1)
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
if (iqs == 0) {
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
}
#elif defined(DATA_A_Q8_0)
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
#elif defined(DATA_A_IQ4_NL)
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
#endif
}
}
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
uint sub = pos % 4;
uint shift = ((pos % 8) >= 4) ? 4 : 0;
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
return kblocksh[buf_ib].qs[pos];
#endif
}
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
#if defined(DATA_A_Q4_0)
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q5_0)
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
#else
return ACC_TYPE(0.0);
#endif
}
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
kblocksh[buf_ib].qs[iqs] = 0;
#if defined(DATA_A_IQ4_NL)
kblocksh[buf_ib].qs[iqs + 4] = 0;
#endif
if (iqs == 0) {
#if QUANT_AUXF == 1
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
#else
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
#endif
}
}

View File

@ -32,6 +32,12 @@ struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_IQ4_NL)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {

View File

@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
#if defined(DATA_A_IQ4_NL)
#define QUANT_K QUANT_K_IQ4_NL
#define QUANT_R QUANT_R_IQ4_NL
#define QUANT_AUXF 1
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif

View File

@ -406,8 +406,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
}
static std::vector<std::future<void>> compiles;
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
std::string out_path = join_paths(output_dir, name + ".spv");
if (input_filepath == "") {
@ -625,15 +625,16 @@ void process_shaders() {
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
}
// flash attention
for (const bool& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
if (fp16 && f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
@ -672,6 +673,12 @@ void process_shaders() {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (tname != "f32") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
}
#endif
}
}
}

View File

@ -8580,7 +8580,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 75, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) {
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));