Merge 32d504cd94 into 05fa625eac
commit f78eee6e4d
@@ -401,21 +401,27 @@ enum FaCodePath {
FA_COOPMAT1,
|
||||
FA_COOPMAT2,
|
||||
};
|
||||
enum FaRows {
|
||||
FA_ROWS_1,
|
||||
FA_ROWS_SMALL,
|
||||
FA_ROWS_LARGE,
|
||||
};
|
||||
|
||||
struct vk_fa_pipeline_state {
|
||||
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
|
||||
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
|
||||
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, FaRows rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
|
||||
: HSK(HSK), HSV(HSV), rows(rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
|
||||
|
||||
uint32_t HSK, HSV;
|
||||
bool small_rows, small_cache;
|
||||
FaRows rows;
|
||||
bool small_cache;
|
||||
FaCodePath path;
|
||||
bool aligned;
|
||||
bool f32acc;
|
||||
uint32_t flags;
|
||||
|
||||
bool operator<(const vk_fa_pipeline_state &b) const {
|
||||
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
|
||||
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
|
||||
return std::tie(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags) <
|
||||
std::tie(b.HSK, b.HSV, b.rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
|
||||
}
|
||||
};
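// Editorial sketch (not part of this change): the std::tie-based operator< above orders states
// lexicographically over all members, which is what lets vk_fa_pipeline_state act as the key of
// the per-type pipeline container that CREATE_FA iterates further down, roughly:
//   std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
// (the exact container type is an assumption inferred from the fa.first/fa.second usage).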
|
||||
|
||||
|
|
@@ -2755,16 +2761,21 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
||||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
|
||||
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
if (hsv >= 192) {
|
||||
return 2;
|
||||
} else if ((hsv | hsk) & 8 || small_cache) {
|
||||
return 4;
|
||||
} else {
|
||||
static uint32_t get_fa_scalar_num_rows(const vk_device& device, uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
|
||||
if (rows == FA_ROWS_1) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (
|
||||
rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache ||
|
||||
(device->architecture == AMD_GCN && hsk <= 64) ||
|
||||
(device->vendor_id == VK_VENDOR_ID_INTEL)
|
||||
) {
|
||||
return 8;
|
||||
}
|
||||
|
||||
return 16;
|
||||
}
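// Editorial examples (illustrative, assuming a non-Intel, non-GCN device):
//   get_fa_scalar_num_rows(dev, 128, 128, FA_ROWS_1,     false) == 1
//   get_fa_scalar_num_rows(dev, 128, 128, FA_ROWS_SMALL, false) == 8
//   get_fa_scalar_num_rows(dev, 128, 128, FA_ROWS_LARGE, false) == 16  // no small-row condition matches
//   get_fa_scalar_num_rows(dev,  72,  72, FA_ROWS_LARGE, false) == 8   // (hsv | hsk) & 8 is nonzero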
|
||||
|
||||
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
||||
|
|
@@ -2774,42 +2785,60 @@ static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
||||
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
||||
|
||||
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
||||
if (path == FA_COOPMAT2) {
|
||||
return flash_attention_num_small_rows;
|
||||
} else {
|
||||
return scalar_flash_attention_num_small_rows;
|
||||
static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) {
|
||||
return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR;
|
||||
}
|
||||
|
||||
static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path, FaRows rows) {
|
||||
if (fa_disable_subgroups(device, path)) {
|
||||
return 0xFFFFFFFF;
|
||||
}
|
||||
|
||||
if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && rows == FA_ROWS_1) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
return device->subgroup_size;
|
||||
}
|
||||
|
||||
static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path, rows);
|
||||
switch (path) {
|
||||
case FA_COOPMAT2:
|
||||
return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
|
||||
case FA_COOPMAT1:
|
||||
return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
|
||||
default:
|
||||
if (subgroup_size > 32 && (Br < 4 || hsk < 64)) {
|
||||
return subgroup_size * 2;
|
||||
} else {
|
||||
return subgroup_size * 4;
|
||||
}
|
||||
}
|
||||
}
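// Editorial examples (illustrative): on a subgroup-size-64 device taking the scalar path with
// Br = 16 and hsk = 128, neither Br < 4 nor hsk < 64 holds, so the workgroup is 64 * 4 = 256
// invocations; with Br = 1 (single-row decode) the condition holds and it shrinks to 64 * 2 = 128.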
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
|
||||
static std::array<uint32_t, 2> fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
|
||||
GGML_UNUSED(clamp);
|
||||
|
||||
if (path == FA_SCALAR) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
if (rows == FA_ROWS_1 || ((hsk|hsv) & 8)) {
|
||||
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
||||
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
||||
// But this only applies to row_split=1, meaning FA_ROWS_1
|
||||
return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 64};
|
||||
} else {
|
||||
if ((hsv | hsk) & 8) {
|
||||
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
||||
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
|
||||
} else {
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
|
||||
}
|
||||
return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 32};
|
||||
}
|
||||
}
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
||||
} else {
|
||||
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
||||
}
|
||||
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
||||
}
|
||||
|
||||
// small rows, large cols
|
||||
if (small_rows) {
|
||||
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
||||
if (rows != FA_ROWS_LARGE) {
|
||||
return {flash_attention_num_small_rows, 32};
|
||||
}
|
||||
|
||||
// small cols to reduce register count
|
||||
|
|
@@ -2823,8 +2852,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
return {64, 64};
|
||||
}
|
||||
|
||||
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
|
||||
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
|
||||
static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) {
|
||||
return fa_rows_cols(device, path, hsk, hsv, 0, type, rows, small_cache)[1];
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
|
|
@@ -3191,76 +3220,75 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
// For scalar, use 128 (arbitrary)
|
||||
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
||||
const uint32_t D = (hsk|hsv);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
|
||||
auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache);
|
||||
|
||||
uint32_t wg_size;
|
||||
switch (path) {
|
||||
case FA_COOPMAT2:
|
||||
wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
break;
|
||||
default:
|
||||
wg_size = scalar_flash_attention_workgroup_size;
|
||||
break;
|
||||
}
|
||||
const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]);
|
||||
const uint32_t subgroup_size = fa_subgroup_size(device, path, rows);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
const uint32_t D_lsb = D ^ (D & (D-1));
|
||||
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
||||
uint32_t D_split = std::min(std::min(subgroup_size, 8u), D_lsb / 4);
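// Editorial worked example (illustrative): for HSK = HSV = 80, D = 80 = 0b1010000, so
// D & (D-1) clears the lowest set bit and D ^ (D & (D-1)) isolates it: D_lsb = 16.
// With a subgroup size of at least 8 this gives D_split = min(8, 16 / 4) = 4, i.e. four
// lanes split the head dimension and are later reduced with subgroup shuffles.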
|
||||
|
||||
// Nvidia prefers shared memory use to load large tiles of K.
|
||||
// Nvidia prefers shared memory use to load large tiles of K/V.
|
||||
// Switch to loading from global memory when it would use too much shared memory.
|
||||
// AMD prefers loading K directly from global memory
|
||||
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
|
||||
const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
|
||||
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags};
|
||||
};
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
|
||||
uint32_t HSK = fa.first.HSK; \
|
||||
uint32_t HSV = fa.first.HSV; \
|
||||
bool small_rows = fa.first.small_rows; \
|
||||
FaRows rows = fa.first.rows; \
|
||||
bool small_cache = fa.first.small_cache; \
|
||||
FaCodePath path = fa.first.path; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
uint32_t flags = fa.first.flags; \
|
||||
bool fa_ds = fa_disable_subgroups(device, path); \
|
||||
uint32_t fa_sgs = fa_subgroup_size(device, path, rows); \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
}
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
|
|
@@ -4531,6 +4559,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
|
||||
static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
|
||||
|
||||
static vk_device ggml_vk_get_device(size_t idx) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
|
||||
|
|
@@ -4747,6 +4776,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->shader_core_count = sm_props.shaderSMCount;
|
||||
} else if (amd_shader_core_properties2) {
|
||||
device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
|
||||
} else {
|
||||
device->shader_core_count = 0;
|
||||
}
|
||||
|
|
@@ -8391,21 +8422,29 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
const std::array<uint32_t, 2> rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
const uint32_t acc_type_size = !fp32acc ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * acc_type_size;
|
||||
|
||||
const uint32_t masksh = Bc * Br * sizeof(float);
|
||||
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
||||
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256;
|
||||
const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
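// Editorial worked example (illustrative): HSK = HSV = 128, FA_ROWS_LARGE, fp16 device, f16 acc,
// NVIDIA (shmem staging on) and a subgroup size of 32 give Br = 16, Bc = 32, wg_size = 128, so
// tmpsh = 512, tmpshv4 = 1024, masksh = 32*17*2 = 1088, Qf = 16*33*4*2 = 4224 and
// kvsh = 32*33*4*2 = 8448, i.e. total_size = 15296 bytes, comfortably under typical 32/48 KB limits.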
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
|
@@ -8416,7 +8455,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
|
||||
const auto rows_cols = fa_rows_cols(device, FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
|
|
@@ -8534,6 +8573,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
const bool small_cache = nek1 < 1024;
|
||||
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
|
|
@@ -8542,10 +8582,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
|
||||
max_gqa = get_fa_scalar_num_rows(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache);
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||
max_gqa = flash_attention_num_small_rows;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
|
|
@@ -8561,23 +8601,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
workgroups_y /= gqa_ratio;
|
||||
}
|
||||
|
||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||
|
||||
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
||||
// So use scalar instead.
|
||||
if (small_rows && path == FA_COOPMAT1) {
|
||||
path = FA_SCALAR;
|
||||
FaRows rows;
|
||||
if (N == 1) {
|
||||
rows = FA_ROWS_1;
|
||||
} else if (N <= (path == FA_COOPMAT2 ? flash_attention_num_small_rows : 8)) {
|
||||
rows = FA_ROWS_SMALL;
|
||||
} else {
|
||||
rows = FA_ROWS_LARGE;
|
||||
}
|
||||
|
||||
// scalar is faster than coopmat2 when N==1
|
||||
if (N == 1 && path == FA_COOPMAT2) {
|
||||
if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
||||
if (path == FA_SCALAR &&
|
||||
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
|
||||
small_rows = true;
|
||||
// with large hsk/hsv, scalar path may need to use small rows to fit in shared memory
|
||||
if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) {
|
||||
rows = FA_ROWS_SMALL;
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
|
|
@@ -8592,7 +8632,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
|
||||
uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, rows, small_cache);
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
// the "aligned" shader variant will forcibly align strides, for performance
|
||||
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
||||
|
|
@@ -8602,8 +8642,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
aligned = false;
|
||||
}
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
|
@@ -8623,7 +8661,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
(mask != nullptr ? 2 : 0) |
|
||||
(logit_softcap != 0 ? 4 : 0);
|
||||
|
||||
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
|
||||
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags);
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
|
|
@@ -8645,22 +8683,36 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t split_kv = KV;
|
||||
uint32_t split_k = 1;
|
||||
|
||||
// Intel Alchemist prefers more workgroups
|
||||
const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
|
||||
|
||||
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
|
||||
|
||||
auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
GGML_ASSERT(Br == pipeline->wg_denoms[0]);
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead.
|
||||
// Must either be a single batch or be using gqa, we can't mix the two.
|
||||
if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
|
||||
// Try to run two workgroups per SM.
|
||||
if (gqa_ratio > 1 && workgroups_x <= Br) {
|
||||
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
} else if (gqa_ratio <= 1) {
|
||||
uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
|
||||
if (total_wgs_no_split < shader_core_count * 2) {
|
||||
split_k = shader_core_count * 2 / total_wgs_no_split;
|
||||
}
|
||||
}
|
||||
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
}
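// Editorial worked example (illustrative numbers): with KV = 4096, alignment = 64, an effective
// shader_core_count of 64, gqa_ratio <= 1 and Tr * workgroups_y * workgroups_z = 8, split_k starts
// as 128 / 8 = 16, split_kv = ROUNDUP_POW2(4096 / 16, 64) = 256 and the recomputed
// split_k = CEIL_DIV(4096, 256) = 16, so each split covers a 256-wide chunk that is a multiple of the alignment.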
|
||||
|
||||
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
||||
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
||||
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
|
||||
|
|
@@ -8674,10 +8726,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
|
||||
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
|
||||
const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
|
||||
|
||||
|
|
@@ -8757,15 +8805,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (ctx->prealloc_split_k_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
workgroups_x *= pipeline->wg_denoms[0];
|
||||
|
||||
// We reuse workgroups_x to mean the number of splits, so we need to
|
||||
// cancel out the divide by wg_denoms[0].
|
||||
uint32_t dispatch_x;
|
||||
if (gqa_ratio > 1) {
|
||||
workgroups_x *= pipeline->wg_denoms[0];
|
||||
dispatch_x = split_k * workgroups_x;
|
||||
} else {
|
||||
dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
|
||||
}
|
||||
|
||||
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
|
||||
// We only use split_k when group query attention is enabled, which means
|
||||
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
||||
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
||||
// cancel out the divide by wg_denoms[0].
|
||||
pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
|
||||
pc, { dispatch_x, workgroups_y, workgroups_z });
|
||||
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
|
||||
|
|
@@ -15392,6 +15446,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
}
|
||||
}
|
||||
|
||||
static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
|
||||
VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
|
||||
|
||||
if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t device_id = props.properties.deviceID;
|
||||
|
||||
switch (device_id) {
|
||||
case 0x56A6: // A310
|
||||
return 6;
|
||||
case 0x5693: // A370M
|
||||
case 0x56A5: // A380
|
||||
case 0x56B1: // Pro A40/A50
|
||||
return 8;
|
||||
case 0x5697: // A530M
|
||||
return 12;
|
||||
case 0x5692: // A550M
|
||||
case 0x56B3: // Pro A60
|
||||
return 16;
|
||||
case 0x56A2: // A580
|
||||
return 24;
|
||||
case 0x5691: // A730M
|
||||
case 0x56A1: // A750
|
||||
return 28;
|
||||
case 0x56A0: // A770
|
||||
case 0x5690: // A770M
|
||||
return 32;
|
||||
case 0xE212: // Pro B50
|
||||
return 16;
|
||||
case 0xE20C: // B570
|
||||
return 18;
|
||||
case 0xE20B: // B580
|
||||
return 20;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// checks
|
||||
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
|
|
@@ -16068,7 +16162,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
ggml_vk_print_graph_origin(tensor, done);
|
||||
}
|
||||
|
||||
if (avg_err > 0.5 || std::isnan(avg_err)) {
|
||||
if (avg_err > 0.01 || std::isnan(avg_err)) {
|
||||
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
|
||||
std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
|
||||
if (src0 != nullptr) {
|
||||
|
@@ -3,9 +3,13 @@
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#ifdef FLOAT16
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
|
|
@@ -15,8 +19,11 @@
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||
const uint32_t row_split = (Br < 4 || HSK <= 64) ? 1 : 4;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
const uint32_t num_subgroups = WorkGroupSize / SubGroupSize;
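// Editorial worked example (illustrative): with Br = 16, Bc = 32, HSK = 128, WorkGroupSize = 128
// and D_split = 4: row_split = 4, rows_per_thread = 4, cols_per_iter = 128 / 4 / 4 = 8 and
// cols_per_thread = 32 / 8 = 4, i.e. each invocation accumulates a 4x4 tile of S.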
|
||||
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
|
|
@@ -27,20 +34,20 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * HSV + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
// If SubGroupSize is set to 0xFFFFFFFF (SUBGROUPS_DISABLED) then only shmem reductions are used
|
||||
const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
|
||||
shared float tmpsh[tmpsh_size];
|
||||
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
|
||||
|
||||
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
||||
shared vec4 tmpshv4[WorkGroupSize];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
shared float masksh[Bc][Br];
|
||||
shared vec4 Qf[Br][HSK / 4];
|
||||
const uint32_t qf_stride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
|
||||
|
||||
const uint32_t D = HSK > HSV ? HSK : HSV;
|
||||
const uint32_t kvsh_stride = D / 4 + 1;
|
||||
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
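// (editorial note: the "+ 1" in masksh_stride, qf_stride and kvsh_stride pads each row of the
// shared arrays by one element, presumably to stagger rows across shared-memory banks and
// reduce bank conflicts)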
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
|
|
@@ -50,8 +57,13 @@ void main() {
init_indices();
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
||||
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
|
||||
|
||||
|
|
@@ -60,37 +72,37 @@ void main() {
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
|
||||
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
vec4 Of[Br][HSV_per_thread / 4];
|
||||
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = vec4(0.0);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = ACC_TYPEV4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
float Lf[Br], Mf[Br];
|
||||
float Lf[rows_per_thread], Mf[rows_per_thread];
|
||||
|
||||
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
||||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = 0;
|
||||
Mf[r] = NEG_FLT_MAX_OVER_2;
|
||||
}
|
||||
|
||||
float slope[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
slope[r] = 1.0;
|
||||
ACC_TYPE slope[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
slope[r] = ACC_TYPE(1.0);
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@@ -113,75 +125,141 @@ void main() {
|
||||
uint32_t mask_opt = 0;
|
||||
uint32_t mask_opt_idx = ~0;
|
||||
uint32_t mask_opt_bits = 0;
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
if (MASK_ENABLE) {
|
||||
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
|
||||
mask_opt_idx = j / 16;
|
||||
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
|
||||
}
|
||||
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
|
||||
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
|
||||
// skip this block
|
||||
continue;
|
||||
}
|
||||
// Only load if the block is not all zeros
|
||||
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
|
||||
mask_opt_idx = j / 16;
|
||||
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c * masksh_stride + r] = m;
|
||||
max_mask = max(max_mask, float(m));
|
||||
} else {
|
||||
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
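// (editorial note: each 32-bit mask_opt word packs sixteen 2-bit codes, one per Bc-wide block
// of the mask along KV, matching mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc) on the host
// side, so j % 16 selects the code for block j within the current word)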
|
||||
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
|
||||
// skip this block
|
||||
continue;
|
||||
}
|
||||
// Only load if the block is not all zeros
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
ACC_TYPE Sf[rows_per_thread][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = ACC_TYPE(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (c < Bc) {
|
||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// With few iterations, the additional registers needed cost more than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 Q_cache[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
} else {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = 0.0;
|
||||
}
|
||||
}
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -189,89 +267,109 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (LOGIT_SOFTCAP) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
|
||||
Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float mvf = masksh[c * cols_per_iter + col_tid][r];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
|
||||
|
||||
Sf[r][c] += slope[r]*mvf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
|
||||
float eMf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
Moldf[r] = Mf[r];
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// P = e^(S - M)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
|
||||
}
|
||||
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||
|
||||
// Compute sum across row of P
|
||||
rowsumf[r] = 0.0;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowsumf[r] += Pf[r][c];
|
||||
}
|
||||
|
||||
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf[r] = exp(Moldf - Mf[r]);
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = eMf[r] * Of[r][d];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
|
||||
}
|
||||
}
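// (editorial note: this is the standard online-softmax update used by flash attention: with the
// new row maximum M = max(rowmax, Mold) and eM = exp(Mold - M), the running sum L and the
// accumulator O are rescaled by eM so that the exp(S - M) terms and P*V contributions added
// below stay consistent with the new maximum)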
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSV / 4);
|
||||
uint32_t c = (idx + tid) / (HSV / 4);
|
||||
if (c < Bc) {
|
||||
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FLOAT_TYPE Pf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
|
||||
Lf[r] += Pf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 Vf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] += Pf[r][c] * Vf;
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += ACC_TYPEV4(Pf[r] * Vf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
// prevent race on tmpsh
|
||||
|
|
@@ -279,58 +377,108 @@ void main() {
|
||||
// reduce across threads
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float rowmaxf, eMf;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = Mf[r];
|
||||
|
||||
tmpsh[tid] = Mf[r];
|
||||
// Compute max across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
|
||||
if (SubGroupSize != SUBGROUPS_DISABLED) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
|
||||
}
|
||||
if (row_split == 1) {
|
||||
// Reduce inside workgroup with shmem
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
|
||||
}
|
||||
barrier();
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpsh[tid] = rowmaxf;
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf = exp(Moldf - Mf[r]);
|
||||
float eMf = exp(Moldf - Mf[r]);
|
||||
|
||||
Lf[r] = eMf*Lf[r];
|
||||
|
||||
tmpsh[tid] = Lf[r];
|
||||
|
||||
// Compute sum across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
|
||||
if (SubGroupSize != SUBGROUPS_DISABLED) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Lf[r] += subgroupShuffleXor(Lf[r], s);
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
|
||||
}
|
||||
barrier();
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Lf[r] += tmpsh[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = eMf * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
||||
tmpsh[tid] = Lf[r];
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
Of[r][d] += tmpshv4[tid + s];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[d_tid];
|
||||
barrier();
|
||||
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
|
||||
|
||||
if (SubGroupSize != SUBGROUPS_DISABLED) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
Of[r][d] = tmpsh_accv4[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpsh_accv4[tid] = Of[r][d];
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
Of[r][d] += tmpsh_accv4[tid ^ s];
|
||||
tmpsh_accv4[tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpsh_accv4[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@@ -338,33 +486,53 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;

[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
}

o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;

if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;

[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}

if (global_row < N && d_tid == 0 && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}

return;
}

if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);

float ms = 1.0f;
float vs = 1.0f;

@ -373,7 +541,7 @@ void main() {
ms = exp(Mf[r] - sink);

[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
Of[r][d] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);

@ -383,39 +551,37 @@ void main() {
}
}

float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}

[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
}
}

uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;

if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (i * Br + row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}
}

@ -9,7 +9,7 @@ layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
layout (constant_id = 8) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 9) const uint32_t Flags = 0;

const bool USE_MASK_OPT = (Flags & 1) != 0;

@ -66,9 +66,12 @@ layout (push_constant) uniform parameter {
#define SINK_ENABLE_BIT (1<<24)
#define N_LOG2_MASK 0xFFFF

#define SUBGROUPS_DISABLED 0xFFFFFFFF

layout (binding = 4) readonly buffer S {float data_s[];};

layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};

layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};

@ -94,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
} else {
return v_packed.v_data_packed[a_offset + ib];
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
}
}
#endif

@ -107,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);

@ -115,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;

return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);

@ -123,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;

return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
}
}
#endif

#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif

@ -189,10 +192,16 @@ void init_indices()
KV = p.KV;

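// note: with split_k active, gl_WorkGroupID.x packs two indices: the split index
// (gl_WorkGroupID.x % p.k_num) plus either gqa_iq1 (when gqa_ratio > 1) or the
// Q block index i (gl_WorkGroupID.x / p.k_num).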
if (p.k_num > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
if (p.gqa_ratio > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else {
gqa_iq1 = 0;
split_k_index = gl_WorkGroupID.x % p.k_num;
i = gl_WorkGroupID.x / p.k_num;
}
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;

@ -244,3 +253,11 @@ void init_indices()
// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
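// (3.0*0.6931 is approximately 3*ln(2), about 2.079, so subtracting it from the
//  running max scales the exp() terms by roughly 1/8, keeping f16 accumulation
//  well below the float16_t limit of 65504.)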

// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
void gqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV / 4 + c;
data_ov4[o_offset + offset] = D_TYPEV4(elems);
}

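For illustration only (not part of this commit; HSV and the thread coordinates below are assumed example values): the vec4 store in gqaStore addresses the same scalar elements that the old per-component perElemOpGqaStore wrote, because data_ov4 views data_o four values at a time.

// Hypothetical standalone check of the scalar-vs-vec4 indexing used above.
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t HSV = 128, D_split = 16;            // assumed example sizes (HSV % 4 == 0)
    const uint32_t iq2 = 3, r = 1, d = 2, d_tid = 5;   // arbitrary row/thread coordinates
    for (uint32_t comp = 0; comp < 4; ++comp) {
        // old path: scalar element index written by perElemOpGqaStore
        uint32_t scalar_idx = (iq2 + r) * HSV + 4 * (d * D_split + d_tid) + comp;
        // new path: vec4 index written by gqaStore, expanded back to a scalar index
        uint32_t vec4_idx = (iq2 + r) * HSV / 4 + d * D_split + d_tid;
        assert(scalar_idx == 4 * vec4_idx + comp);
    }
    return 0;
}
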
@ -33,15 +33,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};

// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}

shared float tmpsh[row_split];

const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4

@ -54,10 +45,11 @@ shared f16vec4 Psh[Bc * psh_stride];
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];

const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];

shared ACC_TYPE slope[Br];

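// note: K and V tiles are staged through the single kvsh buffer; with SHMEM_STAGING
// enabled its stride is sized for D_pad = max(HSK_pad, HSV_pad), so one allocation
// serves both the Q*K^T pass and the P*V pass.
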
@ -78,15 +70,15 @@ void main() {
#define tile_row(r) (row_tid * rows_per_thread + (r))

// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
if ((HSK % 16) != 0) {
if ((HSK % 16) != 0 || (HSV % 16) != 0) {
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
if (i + tid < Br * qstride) {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
[[unroll]] for (uint i = 0; i < Bc * kvsh_stride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kvsh_stride) {
kvsh[i + tid] = f16vec4(0);
}
}
barrier();

@ -153,22 +145,22 @@ void main() {

uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];

[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
mask_cache[idx] = f16vec4(0);
}

if (MASK_ENABLE) {

if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;

@ -231,24 +223,24 @@ void main() {
}
}

if (K_LOAD_SHMEM != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK_pad / 4);
uint32_t c = (idx + tid) / (HSK_pad / 4);
if (c < Bc) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}

ksh[c * kshstride + d] = K_Tf;
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
barrier();

@ -262,7 +254,7 @@ void main() {
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;

[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
if (K_LOAD_SHMEM == 0) {
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif

@ -277,13 +269,13 @@ void main() {
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
}

ksh[row * kshstride + col_vec] = K_Tf;
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
}
barrier();

@ -295,8 +287,8 @@ void main() {
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
uint coord = (gl_SubgroupID * MatBc) * kshstride;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {

@ -305,8 +297,8 @@ void main() {
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}

coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);

@ -329,7 +321,7 @@ void main() {
barrier();
}

if (MASK_ENABLE) {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);

@ -397,6 +389,29 @@ void main() {
}
}

if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV_pad / 4);
uint32_t c = (idx + tid) / (HSV_pad / 4);
if (c < Bc) {
f16vec4 V_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}

kvsh[c * kvsh_stride + d] = V_Tf;
}
}
}
barrier();

const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up

// Each subgroup handles HSV/4 columns

@ -410,6 +425,7 @@ void main() {
const uint v_total = v_rows * v_cols;
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;

if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {

@ -428,43 +444,52 @@ void main() {

if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
ksh[row * vsh_stride + col] = f16vec4(0.0f);
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
}
}

#if BLOCK_SIZE == 1
}
#endif

}
barrier();

[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);

#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}

SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}

// Store SfMat to sfsh and load into Of
const uint osh_stride = row_split * MatBc / 4;
const uint o_offset = gl_SubgroupID * MatBc / 4;
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);

if (hsv_offset < HSV_pad) {
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);

if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
} else {
const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}

SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}

// Store SfMat to sfsh and load into Of
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
}

barrier();

@ -500,27 +525,48 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;

[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}
}

o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;

if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;

[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
}
}

if (global_row < N && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}

@ -564,7 +610,7 @@ void main() {
}
}

uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;

if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {

@ -573,9 +619,7 @@ void main() {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}

@ -586,9 +630,7 @@ void main() {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
}
data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
}
}
}

@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}

// Store O values for non-GQA split_k. Rows are tokens, not heads.
D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c < HSV) {
uint32_t o_off = HSV * p.ne1
* (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
}
return elem;
}

// Store L/M values for non-GQA split_k.
ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c == 0) {
uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+ p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
}
return elem;
}
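// Layout note: all per-split O data comes first; each (split, row) slot holds
// p.ne1 * HSV elements indexed by head (iq2) and channel c. The L/M data starts
// at HSV * p.ne1 * p.k_num * p.ne2 * p.ne3, with L at offset iq2 and M at
// p.ne1 + iq2 inside each split's 2 * p.ne1 slot.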

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);

@ -290,13 +312,19 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
} else {
coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
}
return;
}

@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}

void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};

// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
// No coopmats

@ -622,49 +620,63 @@ void process_shaders() {
}
}

// flash attention
for (const auto& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
}

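// note: each FA shader is now generated per FLOAT_TYPE variant (float and float16_t);
// ACC_TYPE only becomes float16_t when both fp16 and f16acc are set.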
for (const auto& tname : type_names) {
if (tname == "bf16") continue;

#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
// flash attention
for (const bool& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
if (fp16 && f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}

for (const auto& tname : type_names) {
if (tname == "bf16") continue;

if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}

if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
}
}
}
}

std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};

for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);