dynamic subgroups for intel

This commit is contained in:
Ruben Ortlam 2026-02-08 15:41:32 +01:00
parent b626e3296d
commit 4819fd3014
3 changed files with 95 additions and 42 deletions

View File

@@ -3191,6 +3191,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
};
const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL;
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
@@ -3206,10 +3208,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
break;
case FA_COOPMAT1:
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
if (disable_subgroups) {
wg_size = 128;
} else {
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
}
break;
default:
if (device->subgroup_size > 32 && rows_cols[0] < 4) {
if (disable_subgroups) {
wg_size = 128;
} else if (device->subgroup_size > 32 && rows_cols[0] < 4) {
wg_size = device->subgroup_size * 2;
} else {
wg_size = device->subgroup_size * 4;
@@ -3227,7 +3235,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
// AMD prefers loading K directly from global memory
const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags};
const uint32_t subgroup_size = disable_subgroups ? 0xFFFFFFFF : device->subgroup_size;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags};
};
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3243,15 +3252,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
} \
} \
} \

View File

@@ -43,7 +43,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_
return elem;
}
const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1;
// If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions
const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize;
shared float tmpsh[tmpsh_size];
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
@@ -67,6 +68,7 @@ void main() {
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
@@ -359,20 +361,33 @@ void main() {
float rowmaxf = Mf[r];
// Compute max across the row
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
}
if (row_split == 1) {
// Reduce inside workgroup with shmem
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
if (SubGroupSize != SUBGROUPS_DISABLED) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
}
barrier();
rowmaxf = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
if (row_split == 1) {
// Reduce inside workgroup with shmem
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
}
barrier();
rowmaxf = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
}
}
} else {
barrier();
tmpsh[tid] = rowmaxf;
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
}
barrier();
}
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
}
float Moldf = Mf[r];
@@ -385,37 +400,64 @@ void main() {
Lf[r] = eMf*Lf[r];
// Compute sum across the row
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Lf[r] += subgroupShuffleXor(Lf[r], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
if (SubGroupSize != SUBGROUPS_DISABLED) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Lf[r] += subgroupShuffleXor(Lf[r], s);
}
barrier();
Lf[r] = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Lf[r] += tmpsh[s * D_split + d_tid];
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
}
barrier();
Lf[r] = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Lf[r] += tmpsh[s * D_split + d_tid];
}
}
} else {
barrier();
tmpsh[tid] = Lf[r];
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
}
barrier();
}
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
if (SubGroupSize != SUBGROUPS_DISABLED) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
}
barrier();
Of[r][d] = tmpsh_accv4[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
}
barrier();
Of[r][d] = tmpsh_accv4[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
}
}
} else {
barrier();
tmpsh_accv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
Of[r][d] += tmpsh_accv4[tid ^ s];
tmpsh_accv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpsh_accv4[row_tid * threads_per_rowgroup + d_tid];
}
}
}

View File

@@ -66,6 +66,8 @@ layout (push_constant) uniform parameter {
#define SINK_ENABLE_BIT (1<<24)
#define N_LOG2_MASK 0xFFFF
#define SUBGROUPS_DISABLED 0xFFFFFFFF
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};