fix slang issues
This commit is contained in:
parent
a4ac1d903a
commit
e1b40fa53a
|
|
@ -8840,6 +8840,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
|
||||
#ifdef GGML_VULKAN_ENABLE_SLANG
|
||||
if (tuning_params.path != FA_SCALAR) {
|
||||
#endif
|
||||
// For F32, the shader treats the data as blocks of 4 elements (for vec4 loads), so the stride is divided by 4
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k_stride /= 4;
|
||||
|
|
@ -8847,6 +8850,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
if (v->type == GGML_TYPE_F32) {
|
||||
v_stride /= 4;
|
||||
}
|
||||
#ifdef GGML_VULKAN_ENABLE_SLANG
|
||||
}
|
||||
#endif
|
||||
|
||||
const uint32_t alignment = tuning_params.block_cols;
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
import types;
|
||||
import flash_attn_loader;
|
||||
|
||||
[vk::specialization_constant] const uint WorkGroupSize = 128;
|
||||
[vk::specialization_constant] const uint Br = 1;
|
||||
[vk::specialization_constant] const uint Bc = 32;
|
||||
[vk::specialization_constant] const uint HSK = 32;
|
||||
[vk::specialization_constant] const uint HSV = 32;
|
||||
[vk::specialization_constant] const uint Clamp = 0;
|
||||
[vk::specialization_constant] const uint D_split = 16;
|
||||
[vk::specialization_constant] const uint row_split = 1;
|
||||
[vk::specialization_constant] const uint SubGroupSize = 32;
|
||||
[vk::specialization_constant] const uint SHMEM_STAGING = 0;
|
||||
[vk::specialization_constant] const uint Flags = 0;
|
||||
[vk::specialization_constant] const uint LIMIT_OCCUPANCY_SHMEM = 0;
|
||||
[vk::constant_id( 0)] const uint WorkGroupSize = 128;
|
||||
[vk::constant_id( 1)] const uint Br = 1;
|
||||
[vk::constant_id( 2)] const uint Bc = 32;
|
||||
[vk::constant_id( 3)] const uint HSK = 32;
|
||||
[vk::constant_id( 4)] const uint HSV = 32;
|
||||
[vk::constant_id( 5)] const uint Clamp = 0;
|
||||
[vk::constant_id( 6)] const uint D_split = 16;
|
||||
[vk::constant_id( 7)] const uint row_split = 1;
|
||||
[vk::constant_id( 8)] const uint SubGroupSize = 32;
|
||||
[vk::constant_id( 9)] const uint SHMEM_STAGING = 0;
|
||||
[vk::constant_id(10)] const uint Flags = 0;
|
||||
[vk::constant_id(11)] const uint LIMIT_OCCUPANCY_SHMEM = 0;
|
||||
|
||||
static const bool USE_MASK_OPT = (Flags & 1) != 0;
|
||||
static const bool MASK_ENABLE = (Flags & 2) != 0;
|
||||
|
|
@ -131,7 +131,7 @@ T perElemOpStoreCol0<T: __BuiltinFloatingPointType>(const uint r, const uint32_t
|
|||
{
|
||||
if (r < N && c == 0) {
|
||||
uint offset = iq2 + r;
|
||||
data_o[o_offset + offset] = (elem as D_TYPE).value;
|
||||
data_o[o_offset + offset] = floatCast<D_TYPE>(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
|
@ -240,10 +240,9 @@ typealias VLoader = ScalarKVLoader<half>;
|
|||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
|
||||
void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint32_t c, const in vector<T, 4> elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint32_t c, const in vector<T, 4> elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) {
|
||||
uint32_t offset = (iq2 + r) * HSV / 4 + c;
|
||||
data_ov4[o_offset + offset] = (elems as vector<D_TYPE, 4>).value;
|
||||
data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
|
||||
}
|
||||
|
||||
[shader("compute")]
|
||||
|
|
@ -255,7 +254,7 @@ void main(
|
|||
const Indices idcs = init_indices(wgid);
|
||||
|
||||
const uint subgroup_invocation_id = WaveGetLaneIndex();
|
||||
const uint subgroup_id = tid / SubGroupSize;
|
||||
const uint subgroup_id = tid / WaveGetLaneCount();
|
||||
|
||||
const uint threads_per_rowgroup = WorkGroupSize / row_split;
|
||||
const uint row_tid = tid / threads_per_rowgroup;
|
||||
|
|
@ -386,7 +385,7 @@ void main(
|
|||
tmpsh[subgroup_id] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
[unroll] for (uint s = 0; s < num_subgroups; ++s) {
|
||||
[unroll] for (uint s = 0; s < WaveGetNumWaves(); ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
|
|
@ -752,8 +751,8 @@ void main(
|
|||
[unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[unroll] for (uint r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= FLOAT(Lfrcp[r]);
|
||||
#if defined(FLOAT_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -FLOAT_MAX, FLOAT_MAX);
|
||||
#if defined(FLOAT_TYPE_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ public struct ScalarKVLoader<T: __BuiltinFloatingPointType> : IKVLoader {
|
|||
}
|
||||
|
||||
public vector<FLOAT, 4> load(uint element_idx, uint head_dim4_idx) {
|
||||
return (buf[offset + element_idx * stride4 + head_dim4_idx] as vector<FLOAT, 4>).value;
|
||||
return vector<FLOAT, 4>(buf[offset + element_idx * stride4 + head_dim4_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -43,8 +43,8 @@ public struct Q8_0KVLoader : IKVLoader {
|
|||
const uint ib = coord / QUANT_K_Q8_0;
|
||||
const uint iqs = (coord % QUANT_K_Q8_0);
|
||||
|
||||
const vector<FLOAT, 2> v0 = (unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2])).xy as vector<FLOAT, 2>).value; // vec4 used due to #12147
|
||||
const vector<FLOAT, 2> v1 = (unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2 + 1])).xy as vector<FLOAT, 2>).value;
|
||||
const vector<FLOAT, 2> v0 = vector<FLOAT, 2>(unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2])).xy); // vec4 used due to #12147
|
||||
const vector<FLOAT, 2> v1 = vector<FLOAT, 2>(unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2 + 1])).xy);
|
||||
|
||||
return FLOAT(buf[offset + ib].d) * vector<FLOAT, 4>(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue