dynamic subgroups for intel
This commit is contained in:
parent
b626e3296d
commit
4819fd3014
|
|
@ -3191,6 +3191,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
|
||||
};
|
||||
|
||||
const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL;
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
|
|
@ -3206,10 +3208,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
if (disable_subgroups) {
|
||||
wg_size = 128;
|
||||
} else {
|
||||
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
}
|
||||
break;
|
||||
default:
|
||||
if (device->subgroup_size > 32 && rows_cols[0] < 4) {
|
||||
if (disable_subgroups) {
|
||||
wg_size = 128;
|
||||
} else if (device->subgroup_size > 32 && rows_cols[0] < 4) {
|
||||
wg_size = device->subgroup_size * 2;
|
||||
} else {
|
||||
wg_size = device->subgroup_size * 4;
|
||||
|
|
@ -3227,7 +3235,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
// AMD prefers loading K directly from global memory
|
||||
const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
|
||||
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags};
|
||||
const uint32_t subgroup_size = disable_subgroups ? 0xFFFFFFFF : device->subgroup_size;
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags};
|
||||
};
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
|
|
@ -3243,15 +3252,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_
|
|||
return elem;
|
||||
}
|
||||
|
||||
const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1;
|
||||
// If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions
|
||||
const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize;
|
||||
shared float tmpsh[tmpsh_size];
|
||||
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
|
||||
|
||||
|
|
@ -67,6 +68,7 @@ void main() {
|
|||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||
|
||||
|
|
@ -359,20 +361,33 @@ void main() {
|
|||
float rowmaxf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
|
||||
}
|
||||
if (row_split == 1) {
|
||||
// Reduce inside workgroup with shmem
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
|
||||
if (SubGroupSize != SUBGROUPS_DISABLED) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
|
||||
}
|
||||
barrier();
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
||||
if (row_split == 1) {
|
||||
// Reduce inside workgroup with shmem
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
|
||||
}
|
||||
barrier();
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpsh[tid] = rowmaxf;
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
|
||||
float Moldf = Mf[r];
|
||||
|
|
@ -385,37 +400,64 @@ void main() {
|
|||
Lf[r] = eMf*Lf[r];
|
||||
|
||||
// Compute sum across the row
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Lf[r] += subgroupShuffleXor(Lf[r], s);
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
|
||||
if (SubGroupSize != SUBGROUPS_DISABLED) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Lf[r] += subgroupShuffleXor(Lf[r], s);
|
||||
}
|
||||
barrier();
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Lf[r] += tmpsh[s * D_split + d_tid];
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
|
||||
}
|
||||
barrier();
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Lf[r] += tmpsh[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpsh[tid] = Lf[r];
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
|
||||
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
|
||||
if (SubGroupSize != SUBGROUPS_DISABLED) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
||||
}
|
||||
barrier();
|
||||
Of[r][d] = tmpsh_accv4[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
Of[r][d] = tmpsh_accv4[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpsh_accv4[tid] = Of[r][d];
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
Of[r][d] += tmpsh_accv4[tid ^ s];
|
||||
tmpsh_accv4[tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpsh_accv4[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,6 +66,8 @@ layout (push_constant) uniform parameter {
|
|||
#define SINK_ENABLE_BIT (1<<24)
|
||||
#define N_LOG2_MASK 0xFFFF
|
||||
|
||||
#define SUBGROUPS_DISABLED 0xFFFFFFFF
|
||||
|
||||
layout (binding = 4) readonly buffer S {float data_s[];};
|
||||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
|
|
|||
Loading…
Reference in New Issue