vulkan: fix OOB check in flash_attn_mask_opt (#20296)
This commit is contained in:
parent
5866e3bbc8
commit
aa429cf507
|
|
@ -8840,7 +8840,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
||||||
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
|
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
|
||||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
||||||
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,61 @@ layout (push_constant) uniform parameter {
|
||||||
shared float minsh[NUM_SUBGROUPS];
|
shared float minsh[NUM_SUBGROUPS];
|
||||||
shared float maxsh[NUM_SUBGROUPS];
|
shared float maxsh[NUM_SUBGROUPS];
|
||||||
|
|
||||||
|
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
|
||||||
|
|
||||||
|
void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
|
||||||
|
const uint tid = gl_LocalInvocationIndex;
|
||||||
|
|
||||||
|
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
||||||
|
float min_v = FLT_MAX_OVER_2;
|
||||||
|
float max_v = -FLT_MAX_OVER_2;
|
||||||
|
[[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
|
||||||
|
uint j0 = (i + tid) % (Bc / 4);
|
||||||
|
uint j1 = (i + tid) / (Bc / 4);
|
||||||
|
|
||||||
|
j0 *= 4;
|
||||||
|
j0 += (i0 * 16 + block_x) * Bc;
|
||||||
|
j1 += i1 * Br;
|
||||||
|
|
||||||
|
if (!need_bounds_check || j0 + 3 < nem0) {
|
||||||
|
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
|
||||||
|
[[unroll]] for (int c = 0; c < 4; ++c) {
|
||||||
|
min_v = min(min_v, f[c]);
|
||||||
|
max_v = max(max_v, f[c]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
[[unroll]] for (int c = 0; c < 4; ++c) {
|
||||||
|
if (j0 + c < nem0) {
|
||||||
|
float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
|
||||||
|
min_v = min(min_v, f);
|
||||||
|
max_v = max(max_v, f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
min_v = subgroupMin(min_v);
|
||||||
|
max_v = subgroupMax(max_v);
|
||||||
|
if (gl_SubgroupInvocationID == 0) {
|
||||||
|
minsh[gl_SubgroupID] = min_v;
|
||||||
|
maxsh[gl_SubgroupID] = max_v;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
if (tid == 0) {
|
||||||
|
[[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
|
||||||
|
min_v = min(min_v, minsh[i]);
|
||||||
|
max_v = max(max_v, maxsh[i]);
|
||||||
|
}
|
||||||
|
if (max_v <= -FLT_MAX_OVER_2) {
|
||||||
|
result |= 1 << (2*block_x);
|
||||||
|
}
|
||||||
|
if (min_v == 0.0f && max_v == 0.0f) {
|
||||||
|
result |= 2 << (2*block_x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// For each Br x Bc block of the mask (input) buffer, read all values and check
|
// For each Br x Bc block of the mask (input) buffer, read all values and check
|
||||||
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
|
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
|
||||||
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
|
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
|
||||||
|
|
@ -48,50 +103,15 @@ void main() {
|
||||||
const uint i2 = gl_WorkGroupID.z % nem2;
|
const uint i2 = gl_WorkGroupID.z % nem2;
|
||||||
const uint i3 = gl_WorkGroupID.z / nem2;
|
const uint i3 = gl_WorkGroupID.z / nem2;
|
||||||
|
|
||||||
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
|
|
||||||
|
|
||||||
uint result = 0;
|
uint result = 0;
|
||||||
|
|
||||||
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
|
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
|
||||||
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
|
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
|
||||||
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
|
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
|
||||||
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
if ((i0 + 1) * 16 * Bc <= nem0) {
|
||||||
float min_v = FLT_MAX_OVER_2;
|
loadvec4(result, i0, i1, i2, i3, false);
|
||||||
float max_v = -FLT_MAX_OVER_2;
|
} else {
|
||||||
[[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
|
loadvec4(result, i0, i1, i2, i3, true);
|
||||||
uint j0 = (i + tid) % (Bc / 4);
|
|
||||||
uint j1 = (i + tid) / (Bc / 4);
|
|
||||||
|
|
||||||
j0 *= 4;
|
|
||||||
j0 += (i0 * 16 + block_x) * Bc;
|
|
||||||
j1 += i1 * Br;
|
|
||||||
|
|
||||||
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
|
|
||||||
[[unroll]] for (int c = 0; c < 4; ++c) {
|
|
||||||
min_v = min(min_v, f[c]);
|
|
||||||
max_v = max(max_v, f[c]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
min_v = subgroupMin(min_v);
|
|
||||||
max_v = subgroupMax(max_v);
|
|
||||||
if (gl_SubgroupInvocationID == 0) {
|
|
||||||
minsh[gl_SubgroupID] = min_v;
|
|
||||||
maxsh[gl_SubgroupID] = max_v;
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
|
|
||||||
min_v = min(min_v, minsh[i]);
|
|
||||||
max_v = max(max_v, maxsh[i]);
|
|
||||||
}
|
|
||||||
if (max_v <= -FLT_MAX_OVER_2) {
|
|
||||||
result |= 1 << (2*block_x);
|
|
||||||
}
|
|
||||||
if (min_v == 0.0f && max_v == 0.0f) {
|
|
||||||
result |= 2 << (2*block_x);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue