vulkan: fix OOB check in flash_attn_mask_opt (#20296)
This commit is contained in:
parent
5866e3bbc8
commit
aa429cf507
|
|
@ -8840,7 +8840,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
||||||
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
|
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
|
||||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
||||||
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,28 +33,11 @@ layout (push_constant) uniform parameter {
|
||||||
shared float minsh[NUM_SUBGROUPS];
|
shared float minsh[NUM_SUBGROUPS];
|
||||||
shared float maxsh[NUM_SUBGROUPS];
|
shared float maxsh[NUM_SUBGROUPS];
|
||||||
|
|
||||||
// For each Br x Bc block of the mask (input) buffer, read all values and check
|
|
||||||
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
|
|
||||||
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
|
|
||||||
// 32-bit result mask.
|
|
||||||
//
|
|
||||||
// TODO: This is a lot of work per workgroup, might make sense to split this into
|
|
||||||
// more workgroups in the future.
|
|
||||||
void main() {
|
|
||||||
// Each workgroup handles a row
|
|
||||||
const uint tid = gl_LocalInvocationIndex;
|
|
||||||
const uint i0 = gl_WorkGroupID.x;
|
|
||||||
const uint i1 = gl_WorkGroupID.y;
|
|
||||||
const uint i2 = gl_WorkGroupID.z % nem2;
|
|
||||||
const uint i3 = gl_WorkGroupID.z / nem2;
|
|
||||||
|
|
||||||
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
|
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
|
||||||
|
|
||||||
uint result = 0;
|
void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
|
||||||
|
const uint tid = gl_LocalInvocationIndex;
|
||||||
|
|
||||||
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
|
|
||||||
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
|
|
||||||
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
|
|
||||||
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
||||||
float min_v = FLT_MAX_OVER_2;
|
float min_v = FLT_MAX_OVER_2;
|
||||||
float max_v = -FLT_MAX_OVER_2;
|
float max_v = -FLT_MAX_OVER_2;
|
||||||
|
|
@ -66,11 +49,21 @@ void main() {
|
||||||
j0 += (i0 * 16 + block_x) * Bc;
|
j0 += (i0 * 16 + block_x) * Bc;
|
||||||
j1 += i1 * Br;
|
j1 += i1 * Br;
|
||||||
|
|
||||||
|
if (!need_bounds_check || j0 + 3 < nem0) {
|
||||||
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
|
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
|
||||||
[[unroll]] for (int c = 0; c < 4; ++c) {
|
[[unroll]] for (int c = 0; c < 4; ++c) {
|
||||||
min_v = min(min_v, f[c]);
|
min_v = min(min_v, f[c]);
|
||||||
max_v = max(max_v, f[c]);
|
max_v = max(max_v, f[c]);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
[[unroll]] for (int c = 0; c < 4; ++c) {
|
||||||
|
if (j0 + c < nem0) {
|
||||||
|
float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
|
||||||
|
min_v = min(min_v, f);
|
||||||
|
max_v = max(max_v, f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
min_v = subgroupMin(min_v);
|
min_v = subgroupMin(min_v);
|
||||||
max_v = subgroupMax(max_v);
|
max_v = subgroupMax(max_v);
|
||||||
|
|
@ -93,6 +86,33 @@ void main() {
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each Br x Bc block of the mask (input) buffer, read all values and check
|
||||||
|
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
|
||||||
|
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
|
||||||
|
// 32-bit result mask.
|
||||||
|
//
|
||||||
|
// TODO: This is a lot of work per workgroup, might make sense to split this into
|
||||||
|
// more workgroups in the future.
|
||||||
|
void main() {
|
||||||
|
// Each workgroup handles a row
|
||||||
|
const uint tid = gl_LocalInvocationIndex;
|
||||||
|
const uint i0 = gl_WorkGroupID.x;
|
||||||
|
const uint i1 = gl_WorkGroupID.y;
|
||||||
|
const uint i2 = gl_WorkGroupID.z % nem2;
|
||||||
|
const uint i3 = gl_WorkGroupID.z / nem2;
|
||||||
|
|
||||||
|
uint result = 0;
|
||||||
|
|
||||||
|
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
|
||||||
|
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
|
||||||
|
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
|
||||||
|
if ((i0 + 1) * 16 * Bc <= nem0) {
|
||||||
|
loadvec4(result, i0, i1, i2, i3, false);
|
||||||
|
} else {
|
||||||
|
loadvec4(result, i0, i1, i2, i3, true);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
||||||
float min_v = FLT_MAX_OVER_2;
|
float min_v = FLT_MAX_OVER_2;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue