vulkan: skip all-negative-inf blocks in FA (#17186)
commit 234ae7d7bd
parent 38eaf32af1
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -521,6 +521,7 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool subgroup_ballot;
     bool subgroup_clustered;
+    bool subgroup_vote;
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
@@ -4188,6 +4189,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);

+    device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
+
     const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -13572,8 +13576,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             default:
                 return false;
             }
-            if (!coopmat2 && !device->subgroup_shuffle) {
-                // scalar FA uses subgroupShuffle
+            if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
+                // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
                 return false;
             }
             return true;
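Why the new capability bit matters: the scalar and coopmat1 FA shaders below now call subgroupAll(), which is only available when the device exposes the subgroup vote feature detected above. A minimal sketch of the vote primitive follows — a hypothetical standalone shader, not code from this commit; the buffer layout and names are illustrative only:

#version 450
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_vote : enable

layout(local_size_x = 64) in;
layout(set = 0, binding = 0) readonly  buffer InBuf  { float vals[]; };
layout(set = 0, binding = 1) writeonly buffer OutBuf { uint  flags[]; };

void main() {
    // subgroupAll() returns true only if the predicate holds on every active
    // invocation in the subgroup -- a single uniform bool per subgroup.
    bool lane_neg = vals[gl_GlobalInvocationID.x] < 0.0;
    if (subgroupAll(lane_neg) && gl_SubgroupInvocationID == 0) {
        // one flag per subgroup, written by that subgroup's first lane
        flags[gl_WorkGroupID.x * gl_NumSubgroups + gl_SubgroupID] = 1u;
    }
}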
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #extension GL_KHR_shader_subgroup_shuffle : enable
+#extension GL_KHR_shader_subgroup_vote : enable

 #include "types.glsl"
 #include "flash_attn_base.glsl"
@@ -108,6 +109,38 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {

+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            float max_mask = NEG_FLT_MAX_OVER_2;
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                        masksh[c][r] = m;
+                        max_mask = max(max_mask, m);
+                    } else {
+                        masksh[c][r] = float(0);
+                    }
+                }
+            }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
+        }
+
         float Sf[Br][cols_per_thread];
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
@@ -153,21 +186,6 @@ void main() {
         }

         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                    } else {
-                        masksh[c][r] = float(0);
-                    }
-                }
-            }
-            barrier();
-
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];
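Taken together, these two hunks move the mask load to the top of the KV loop and wrap it in a workgroup-uniform early-out: each invocation computes a private max over its share of the Bc x Br mask tile, subgroupAll() collapses the "<= NEG_FLT_MAX_OVER_2" predicate within each subgroup, lane 0 of each subgroup publishes the verdict to shared memory (tmpsh), and after a barrier every invocation folds those verdicts back in, so the continue is taken (or not) by the whole workgroup at once. A distilled sketch of the same pattern, under stated assumptions — hypothetical buffer layout and tile size, not code from this commit:

#version 450
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_vote : enable

// Hypothetical stand-in for the shaders' NEG_FLT_MAX_OVER_2 constant.
const float NEG_FLT_MAX_OVER_2 = -3.4028235e38 / 2.0;

#define TILE 512
layout(local_size_x = 128) in;
layout(set = 0, binding = 0) readonly  buffer Mask { float mask[]; };
layout(set = 0, binding = 1) writeonly buffer Skip { uint  skip_block[]; };

shared float tmpsh[32]; // enough subgroups for local_size 128 at subgroup size >= 4

void main() {
    uint base = gl_WorkGroupID.x * TILE;

    // 1) Private max over this invocation's strided slice of the tile.
    float max_mask = NEG_FLT_MAX_OVER_2;
    for (uint idx = gl_LocalInvocationID.x; idx < TILE; idx += gl_WorkGroupSize.x) {
        max_mask = max(max_mask, mask[base + idx]);
    }

    // 2) One vote per subgroup: true iff no lane saw anything above the threshold.
    bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
    barrier(); // mirrors the committed code, where tmpsh is reused across loop iterations
    if (gl_SubgroupInvocationID == 0) {
        tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0;
    }
    barrier();

    // 3) Fold the per-subgroup verdicts back in; the comparison below is then
    //    uniform across the workgroup, so every invocation takes the same branch.
    for (uint s = 0; s < gl_NumSubgroups; ++s) {
        max_mask = max(max_mask, tmpsh[s]);
    }
    if (gl_LocalInvocationID.x == 0) {
        skip_block[gl_WorkGroupID.x] = (max_mask <= NEG_FLT_MAX_OVER_2) ? 1u : 0u;
    }
}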
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_vote : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_cooperative_matrix : enable
@@ -148,6 +149,37 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {

+        float mask_cache[Bc * Br / WorkGroupSize];
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            float max_mask = NEG_FLT_MAX_OVER_2;
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                        mask_cache[idx / WorkGroupSize] = m;
+                        max_mask = max(max_mask, m);
+                    }
+                }
+            }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
+        }
+
         [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
             uint32_t d = (idx + tid) % (HSK / 4);
             uint32_t c = (idx + tid) / (HSK / 4);
@@ -208,7 +240,8 @@ void main() {
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
                     if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+                        float f = mask_cache[idx / WorkGroupSize];
+                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
                     }
                 }
             }
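Besides the early-out, the coopmat1 path now reads the mask from global memory only once: the values are kept in the per-invocation mask_cache register array and reused when slope[r] * mask is added into sfsh, instead of issuing a second data_m load. A minimal sketch of that load-once/reuse-from-registers pattern, with hypothetical sizes and buffers (not code from this commit):

#version 450

#define TILE    512
#define WG_SIZE 128
layout(local_size_x = WG_SIZE) in;
layout(set = 0, binding = 0) readonly buffer SrcBuf { float src[]; };
layout(set = 0, binding = 1)          buffer DstBuf { float dst[]; };

void main() {
    float cache[TILE / WG_SIZE]; // one register slot per strided iteration
    uint tid  = gl_LocalInvocationID.x;
    uint base = gl_WorkGroupID.x * TILE;

    // First pass: a single global-memory read per element, kept in registers.
    for (uint idx = 0; idx < TILE; idx += WG_SIZE) {
        cache[idx / WG_SIZE] = src[base + idx + tid];
    }

    // Second pass: consume the cached values instead of re-reading src[].
    for (uint idx = 0; idx < TILE; idx += WG_SIZE) {
        dst[base + idx + tid] += cache[idx / WG_SIZE];
    }
}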
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return max(x, y);
 }

+float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
+    return max(x, y);
+}
+
 ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return x;
 }
@@ -142,6 +146,44 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {

+        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            if (nem1_bounds_check) {
+                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
+
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
+                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
+            } else {
+                tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+                // Don't clamp against nem1 when GQA is enabled
+                uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
+                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
+            }
+        }
+
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

         coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
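Two details are worth noting in this hunk. First, gl_CooperativeMatrixReduceRowAndColumnNV reduces across both rows and columns, so after coopMatReduceNV every element of mvmax holds the max of the entire Br x Bc tile and testing mvmax[0] alone is sufficient. Second, the clamp value 0xfc00 is the IEEE 754 binary16 bit pattern for -infinity (sign 1, exponent all ones, mantissa 0), so out-of-bounds mask rows clamp to -inf and can never defeat the skip test. The mask tile loaded into mv here is also reused by the hunk below, which collapses the old per-branch reload into a single slopeMat add. A tiny sketch of the bit pattern, assuming the uint16BitsToFloat16 built-in from GL_EXT_shader_explicit_arithmetic_types is available:

#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require

layout(local_size_x = 1) in;
layout(set = 0, binding = 0) writeonly buffer OutBuf { uint ok; };

void main() {
    // 0xfc00 = 1 11111 0000000000: sign set, exponent all ones, mantissa zero.
    float16_t neg_inf = uint16BitsToFloat16(uint16_t(0xfc00));
    // -inf sorts below every finite half, including the fp16 minimum -65504.0,
    // so a max-reduction over a clamped tile stays at -inf only if every
    // in-bounds element is -inf as well.
    ok = (neg_inf < -65504.0hf) ? 1u : 0u; // writes 1
}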
@@ -158,31 +200,7 @@ void main() {
         }

         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            if (nem1_bounds_check) {
-                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
-
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
-            } else {
-                tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-                // Don't clamp against nem1 when GQA is enabled
-                uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
-
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
-            }
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }

         // Clear padding elements to -inf, so they don't contribute to rowmax