vulkan: Preprocess FA mask to detect all-neg-inf and all-zero. (#19281)
Write out a 2-bit code per block and avoid loading the mask when it matches these two common cases. Apply this optimization when the mask is relatively large (i.e. prompt processing).
commit 449ec2ab07
parent 3795cc1e89
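For reference, the per-block codes follow the shader's MASK_OPT_* defines and are packed 16 to a 32-bit word, so one dword covers 16 consecutive Bc-wide blocks of one Br-tall tile row. A minimal CPU-side sketch of the packing and decoding (illustrative only, not part of the change; the zero "neither" code is named here purely for illustration):

#include <cstdint>

// Two-bit per-block codes; the nonzero values match the shader's defines.
enum : uint32_t {
    MASK_OPT_NONE        = 0, // mixed values: the mask block must be loaded as usual
    MASK_OPT_ALL_NEG_INF = 1, // whole block is -inf: the KV block can be skipped entirely
    MASK_OPT_ALL_ZERO    = 2, // whole block is zero: skip loading the mask for this block
};

// Pack the code for block column j into its 32-bit word (16 blocks per word).
inline void set_block_code(uint32_t * words, uint32_t j, uint32_t code) {
    words[j / 16] |= code << ((j % 16) * 2);
}

// Decode the code for block column j, mirroring the shader's
// (mask_opt >> ((j % 16) * 2)) & 0x3 expression.
inline uint32_t get_block_code(const uint32_t * words, uint32_t j) {
    return (words[j / 16] >> ((j % 16) * 2)) & 0x3;
}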
@@ -402,18 +402,19 @@ enum FaCodePath {
 };

 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}

     uint32_t HSK, HSV;
     bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
+    bool use_mask_opt;

     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
     }
 };

@@ -820,6 +821,8 @@ struct vk_device_struct {

     std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];

+    std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
     vk_pipeline pipeline_count_experts;

@@ -1549,6 +1552,18 @@ struct vk_op_flash_attn_split_k_reduce_push_constants {
     uint32_t sinks;
 };

+struct vk_op_flash_attn_mask_opt_push_constants {
+    uint32_t nem0;
+    uint32_t nem1;
+    uint32_t nem2;
+    uint32_t nbm1;
+    uint32_t nbm2;
+    uint32_t nbm3;
+    uint32_t nbd1;
+    uint32_t nbd2;
+    uint32_t nbd3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}

@@ -1757,6 +1772,7 @@ class vk_perf_logger {
                 " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
                 " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
                 " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
+            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
             return name.str();
         }
         if (node->op == GGML_OP_TOP_K) {

@@ -3177,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };

-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.

@@ -3209,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;

-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
     };

 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \

@@ -3221,18 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
     FaCodePath path = fa.first.path; \
     bool aligned = fa.first.aligned; \
     bool f32acc = fa.first.f32acc; \
+    bool use_mask_opt = fa.first.use_mask_opt; \
     if (path == FAPATH) { \
         if (aligned) { \
             if (f32acc) { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
             } else { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
             } \
         } else { \
             if (f32acc) { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
             } else { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
             } \
         } \
     } \

@@ -4028,6 +4045,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);

+    for (auto &it : device->pipeline_fa_mask_opt) {
+        auto BrBc = it.first;
+        ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
+    }
+
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {

@@ -8400,8 +8422,6 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;

-    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
-
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;

@@ -8418,7 +8438,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co

     const uint32_t slope = Br * acctype;

-    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
+    const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);

@@ -8445,6 +8465,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

+    const uint32_t nem0 = mask ? mask->ne[0] : 0;
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
     const uint32_t nem3 = mask ? mask->ne[3] : 0;

@@ -8574,7 +8595,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;

-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
+    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
+    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);

     vk_pipeline pipeline = nullptr;

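The heuristic above only enables the preprocessing pass for reasonably large masks, so single-token decode (nem1 == 1) never pays for it. A standalone sketch of the same condition, with hypothetical shapes purely for illustration:

#include <cstdint>
#include <cstdio>

// Mirrors the condition used above: require a mask, at least 32 mask rows,
// and more than 32K mask elements overall.
static bool would_use_mask_opt(bool has_mask, uint32_t nem0, uint32_t nem1) {
    return has_mask && nem1 >= 32 && nem0 * nem1 > 32768;
}

int main() {
    printf("%d\n", would_use_mask_opt(true, 4096, 1));    // token generation: disabled
    printf("%d\n", would_use_mask_opt(true, 4096, 4096)); // prompt processing: enabled
    return 0;
}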
@@ -8625,10 +8649,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_preallocate_buffers(ctx, subctx);
}

{
// Request descriptor sets
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];

const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;

vk_pipeline pipeline_fa_mask_opt = nullptr;
if (use_mask_opt) {
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
auto it = pipelines.find({Br, Bc});
if (it != pipelines.end()) {
pipeline_fa_mask_opt = it->second;
} else {
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
}
assert(pipeline_fa_mask_opt);
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);

if (ctx->prealloc_size_y < mask_opt_size) {
ctx->prealloc_size_y = mask_opt_size;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}

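The metadata buffer requested above is small compared to the mask itself: one dword per 16 Bc-wide blocks, per Br-tall tile row, per mask slice. A worked sizing example with hypothetical dimensions (not taken from the patch):

#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main() {
    // Hypothetical prompt-processing shape and flash-attention tile sizes.
    const uint32_t nem0 = 8192, nem1 = 8192, nem2 = 1, nem3 = 1; // mask dims
    const uint32_t Br = 64, Bc = 64;                             // tile rows/cols

    const uint32_t num_dwords = CEIL_DIV(nem0, 16 * Bc);
    const uint64_t bytes = sizeof(uint32_t) * (uint64_t)num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;

    // 8 dwords per tile row, 4 KiB total, versus ~128 MiB for the fp16 mask itself.
    printf("%u dwords per tile row, %llu bytes total\n", num_dwords, (unsigned long long)bytes);
    return 0;
}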
@@ -8655,9 +8701,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
     vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
+    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;

     uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;

+    if (use_mask_opt)
+    {
+        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
+            nem0,
+            nem1,
+            nem2,
+            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
+            mask_opt_num_dwords,
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br),
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
+        };
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
+            { mask_buf, mask_opt_buf }, opt_pc,
+            { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
     const vk_flash_attn_push_constants pc = { N, KV,
         (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
         (uint32_t)neq2, (uint32_t)neq3,

@@ -8672,13 +8739,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         gqa_ratio, split_kv, split_k };

     if (split_k > 1) {
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
         workgroups_x *= pipeline->wg_denoms[0];
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
+            {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
             // We only use split_k when group query attention is enabled, which means
             // there's no more than one tile of rows (i.e. workgroups_x would have been
             // one). We reuse workgroups_x to mean the number of splits, so we need to

@@ -8697,7 +8766,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             workgroups_x *= pipeline->wg_denoms[0];
         }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
+            {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
             pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }

@@ -94,6 +94,10 @@ void main() {
         }
     }

+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;

@@ -104,15 +108,28 @@ void main() {
     uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
         m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }

+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {

-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+            mask_opt_idx = j / 16;
+            mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+        }
+        uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+        if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+            // skip this block
+            continue;
+        }
+        // Only load if the block is not all zeros
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

-            float max_mask = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;

@@ -120,25 +137,12 @@ void main() {
                 if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                     float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                     masksh[c][r] = m;
-                    max_mask = max(max_mask, m);
                 } else {
                     masksh[c][r] = float(0);
                 }
                 }
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
-            }
            }
        }

     float Sf[Br][cols_per_thread];

@@ -185,7 +189,7 @@ void main() {
             }
         }

-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];

@@ -256,9 +260,6 @@ void main() {
         barrier();
     }

-    // prevent race on tmpsh
-    barrier();
-
     // reduce across threads

     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {

@@ -10,6 +10,7 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 layout (constant_id = 7) const uint32_t SubGroupSize = 32;
 layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
+layout (constant_id = 9) const bool USE_MASK_OPT = false;

 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;

@@ -66,6 +67,11 @@ layout (binding = 4) readonly buffer S {float data_s[];};

 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};

+layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
+
 #define BINDING_IDX_K 0
 #define BINDING_IDX_V 1
 #if defined(DATA_A_F32)

@@ -42,8 +42,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }

-shared float tmpsh[row_split];
-
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];

@@ -134,6 +132,10 @@ void main() {
         }
     }

+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;

@@ -144,66 +146,74 @@ void main() {
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
}

uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;

[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
mask_cache[idx] = f16vec4(0);
}

float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
if ((!KV_bounds_check || j * Bc + c < KV)) {
f16vec4 m;
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
} else if (i * Br + r * 4 + 2 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
0.0);
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
} else if (i * Br + r * 4 + 1 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
0.0,
0.0);
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
} else if (i * Br + r * 4 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
0.0,
0.0,
0.0);
max_mask = max(max_mask, float(m[0]));
} else {
m = f16vec4(0.0);
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {

if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
if ((!KV_bounds_check || j * Bc + c < KV)) {
f16vec4 m;
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
} else if (i * Br + r * 4 + 2 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
0.0);
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
} else if (i * Br + r * 4 + 1 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
0.0,
0.0);
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
} else if (i * Br + r * 4 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
0.0,
0.0,
0.0);
max_mask = max(max_mask, float(m[0]));
} else {
m = f16vec4(0.0);
}
mask_cache[idx / WorkGroupSize] = m;
}
mask_cache[idx / WorkGroupSize] = m;
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}

if (K_LOAD_SHMEM != 0) {

@@ -138,48 +138,53 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}

const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
// mo_offset will point to the tile starting at row i*Br and col 0
uint32_t mo_offset = mo_stride * i;

uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
}

uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;

[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t

coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;

coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
}
}
}

@@ -0,0 +1,142 @@
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable

layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
layout (constant_id = 2) const uint Br = 32;
layout (constant_id = 3) const uint Bc = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {float16_t data_a[];};
layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
layout (binding = 1) writeonly buffer D {uint data_d[];};

layout (push_constant) uniform parameter {
    uint nem0;
    uint nem1;
    uint nem2;
    uint nbm1;
    uint nbm2;
    uint nbm3;
    uint nbd1;
    uint nbd2;
    uint nbd3;
};

#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2

shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS];

// For each Br x Bc block of the mask (input) buffer, read all values and check
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
// 32-bit result mask.
//
// TODO: This is a lot of work per workgroup, might make sense to split this into
// more workgroups in the future.
void main() {
    // Each workgroup handles a row
    const uint tid = gl_LocalInvocationIndex;
    const uint i0 = gl_WorkGroupID.x;
    const uint i1 = gl_WorkGroupID.y;
    const uint i2 = gl_WorkGroupID.z % nem2;
    const uint i3 = gl_WorkGroupID.z / nem2;

    float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);

    uint result = 0;

    // Fast path for fully in-bounds blocks where we can do f16vec4 loads
    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;
            [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
                uint j0 = (i + tid) % (Bc / 4);
                uint j1 = (i + tid) / (Bc / 4);

                j0 *= 4;
                j0 += (i0 * 16 + block_x) * Bc;
                j1 += i1 * Br;

                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
                [[unroll]] for (int c = 0; c < 4; ++c) {
                    min_v = min(min_v, f[c]);
                    max_v = max(max_v, f[c]);
                }
            }
            min_v = subgroupMin(min_v);
            max_v = subgroupMax(max_v);
            if (gl_SubgroupInvocationID == 0) {
                minsh[gl_SubgroupID] = min_v;
                maxsh[gl_SubgroupID] = max_v;
            }
            barrier();
            if (tid == 0) {
                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
                    min_v = min(min_v, minsh[i]);
                    max_v = max(max_v, maxsh[i]);
                }
                if (max_v <= -FLT_MAX_OVER_2) {
                    result |= 1 << (2*block_x);
                }
                if (min_v == 0.0f && max_v == 0.0f) {
                    result |= 2 << (2*block_x);
                }
            }
            barrier();
        }
    } else {
        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;
            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
                    continue;
                }
                uint j0 = (i + tid) % Bc;
                uint j1 = (i + tid) / Bc;

                j0 += (i0 * 16 + block_x) * Bc;
                j1 += i1 * Br;

                if (j0 < nem0 && j1 < nem1) {
                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
                    min_v = min(min_v, f);
                    max_v = max(max_v, f);
                }
            }
            min_v = subgroupMin(min_v);
            max_v = subgroupMax(max_v);
            if (gl_SubgroupInvocationID == 0) {
                minsh[gl_SubgroupID] = min_v;
                maxsh[gl_SubgroupID] = max_v;
            }
            barrier();
            if (tid == 0) {
                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
                    min_v = min(min_v, minsh[i]);
                    max_v = max(max_v, maxsh[i]);
                }
                if (max_v <= -FLT_MAX_OVER_2) {
                    result |= 1 << (2*block_x);
                }
                if (min_v == 0.0f && max_v == 0.0f) {
                    result |= 2 << (2*block_x);
                }
            }
            barrier();
        }
    }

    if (tid == 0) {
        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
    }
}

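As a rough CPU reference for what the shader above computes (an illustrative sketch only; classify_blocks is a hypothetical helper, the real pass reads float16 data and reduces per subgroup on the GPU):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Classify each Br x Bc tile of a rows x cols float mask and pack 16 two-bit
// codes per output dword: 1 = all -inf, 2 = all zero, 0 = anything else.
std::vector<uint32_t> classify_blocks(const std::vector<float> & mask,
                                      uint32_t rows, uint32_t cols,
                                      uint32_t Br, uint32_t Bc) {
    // Roughly FLT_MAX/2, approximating the shader's uintBitsToFloat(0x7EFFFFFF).
    const float FLT_MAX_OVER_2 = std::ldexp(1.0f, 127);
    const uint32_t tiles_x = (cols + Bc - 1) / Bc;
    const uint32_t tiles_y = (rows + Br - 1) / Br;
    const uint32_t dwords_per_row = (cols + 16 * Bc - 1) / (16 * Bc);
    std::vector<uint32_t> out(dwords_per_row * tiles_y, 0);

    for (uint32_t ty = 0; ty < tiles_y; ++ty) {
        for (uint32_t tx = 0; tx < tiles_x; ++tx) {
            float min_v = FLT_MAX_OVER_2, max_v = -FLT_MAX_OVER_2;
            for (uint32_t r = ty * Br; r < std::min(rows, (ty + 1) * Br); ++r) {
                for (uint32_t c = tx * Bc; c < std::min(cols, (tx + 1) * Bc); ++c) {
                    min_v = std::min(min_v, mask[r * cols + c]);
                    max_v = std::max(max_v, mask[r * cols + c]);
                }
            }
            uint32_t code = 0;
            if (max_v <= -FLT_MAX_OVER_2) {
                code = 1; // MASK_OPT_ALL_NEG_INF
            } else if (min_v == 0.0f && max_v == 0.0f) {
                code = 2; // MASK_OPT_ALL_ZERO
            }
            out[ty * dwords_per_row + tx / 16] |= code << ((tx % 16) * 2);
        }
    }
    return out;
}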
@@ -790,6 +790,8 @@ void process_shaders() {
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});

+    string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {});
+
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});

@@ -169,20 +169,22 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m
     const int blck0 = 128;
     const int blck1 = 64;

-    // number of INF blocks
-    const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1);
+    // number of INF/zero blocks
+    const int n_inf_zero_blocks = 0.2*(ne0*ne1*ne2*ne3)/(blck0*blck1);

-    for (int b = 0; b < n_inf_blocks; b++) {
+    for (int b = 0; b < n_inf_zero_blocks; b++) {
         const int p3 = (rd() % ne3);
         const int p2 = (rd() % ne2);
         const int p1 = (rd() % ne1);
         const int p0 = (rd() % ne0);

+        bool inf = rd() & 1;
+
         for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) {
             const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;

             for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) {
-                data_f32[idx + i0] = -INFINITY;
+                data_f32[idx + i0] = inf ? -INFINITY : 0.0f;
             }
         }
     }