This commit is contained in:
Ruben Ortlam 2026-03-15 22:04:41 +02:00 committed by GitHub
commit beed370100
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1037 additions and 0 deletions

View File

@ -220,6 +220,7 @@ option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks"
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF)
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_ENABLE_SLANG "ggml: enable Vulkan Slang shader compiler" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_WEBGPU "ggml: use WebGPU" OFF)

View File

@ -112,6 +112,13 @@ if (Vulkan_FOUND)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON)
endif()
if (GGML_VULKAN_ENABLE_SLANG)
add_compile_definitions(GGML_VULKAN_ENABLE_SLANG)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_ENABLE_SLANG=ON)
find_program(Vulkan_SLANGC_EXECUTABLE NAMES slangc REQUIRED)
message(STATUS "slangc found: ${Vulkan_SLANGC_EXECUTABLE}")
endif()
if (GGML_VULKAN_VALIDATE)
add_compile_definitions(GGML_VULKAN_VALIDATE)
endif()
@ -175,6 +182,12 @@ if (Vulkan_FOUND)
file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
if (GGML_VULKAN_ENABLE_SLANG)
list(APPEND _ggml_vk_shader_files
"${_ggml_vk_input_dir}/flash_attn.slang"
)
endif()
# Because external projects do not provide source-level tracking,
# the vulkan-shaders-gen sources need to be explicitly added to
# ensure that changes will cascade into shader re-generation.
@ -194,6 +207,10 @@ if (Vulkan_FOUND)
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})
if (GGML_VULKAN_ENABLE_SLANG)
set(_ggml_vk_slangc_arg --slangc ${Vulkan_SLANGC_EXECUTABLE})
endif()
foreach (file_full ${_ggml_vk_shader_files})
get_filename_component(file ${file_full} NAME)
set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp")
@ -203,6 +220,7 @@ if (Vulkan_FOUND)
DEPFILE ${_ggml_vk_target_cpp}.d
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
${_ggml_vk_slangc_arg}
--source ${file_full}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}

View File

@ -8840,6 +8840,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
#ifdef GGML_VULKAN_ENABLE_SLANG
if (tuning_params.path != FA_SCALAR) {
#endif
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
@ -8847,6 +8850,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
#ifdef GGML_VULKAN_ENABLE_SLANG
}
#endif
const uint32_t alignment = tuning_params.block_cols;
bool aligned = (KV % alignment) == 0 &&

View File

@ -23,6 +23,10 @@ if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
message(STATUS "Enabling shader debug info")
endif()
if (GGML_VULKAN_ENABLE_SLANG)
add_compile_definitions(GGML_VULKAN_ENABLE_SLANG)
message(STATUS "Enabling Slang")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)

View File

@ -0,0 +1,80 @@
module common;

// Index of the current subgroup (wave) within the workgroup.
// Slang does not expose the SPIR-V SubgroupId builtin directly, so read it
// with inline SPIR-V assembly.
[require(spirv, subgroup_basic)]
public uint WaveGetWaveIndex() {
    __target_switch
    {
    case spirv:
        return spirv_asm {
            OpCapability GroupNonUniform;
            result:$$uint = OpLoad builtin(SubgroupId:uint);
        };
    }
}

// Backing storage used by reduce() to exchange values across subgroups.
public interface ISharedMemory<T, uint N> {
    static vector<T, N> get(uint idx);
    static void set(uint idx, vector<T, N> value);
}

// Binary combining operation for reduce().
public interface IReduceOp<T, uint N> {
    static vector<T, N> combine(vector<T, N> a, vector<T, N> b);
}

// Component-wise maximum.
public struct MaxOp<T: __BuiltinFloatingPointType, uint N> : IReduceOp<T, N> {
    static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return max(a, b); }
}

// Component-wise sum.
public struct SumOp<T: __BuiltinArithmeticType, uint N> : IReduceOp<T, N> {
    static vector<T, N> combine(vector<T, N> a, vector<T, N> b) { return a + b; }
}

// Cross-thread reduction: combines `value` across threads using XOR strides in
// [from, min(to, lane count)), then across subgroups through ShMem when `to`
// exceeds the subgroup size.
//   subgroup_size — the SubGroupSize specialization constant; 0 selects the
//                   shared-memory-only path (no subgroup ops), nonzero enables
//                   subgroup shuffles.
//   OLD_AMD_WINDOWS — enables the f16 shuffle workaround (see comment below).
public vector<T, N> reduce<T: __BuiltinType, uint N, Op: IReduceOp<T, N>, ShMem: ISharedMemory<T, N>>(vector<T, N> value, uint from, uint to, uint tid, uint subgroup_size, bool OLD_AMD_WINDOWS = false) {
    if (subgroup_size > 0) {
        const uint subgroup_id = WaveGetWaveIndex();
        const uint lane_id = WaveGetLaneIndex();
        const uint from_id = lane_id % from;
        // NOTE(review): this deliberately shadows the parameter with the real
        // lane count once the `> 0` check above has passed — confirm intended.
        const uint subgroup_size = WaveGetLaneCount();

        // Reduce with subgroup ops first
        [unroll] for (uint s = from; s < min(to, subgroup_size); s *= 2) {
            if (!OLD_AMD_WINDOWS) {
                value = Op::combine(value, WaveReadLaneAt(value, lane_id ^ s));
            } else if (T is half) {
                // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
                // Shuffle full vec4 as workaround.
                // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
                value = Op::combine(value, (WaveReadLaneAt(vector<float, N>((value as vector<half, N>).value), lane_id ^ s) as vector<T, N>).value);
            }
        }

        if (to > subgroup_size) {
            // Reduce inside workgroup with shmem
            GroupMemoryBarrierWithGroupSync();
            if (lane_id < from) {
                ShMem.set(subgroup_id * from + from_id, value);
            }
            GroupMemoryBarrierWithGroupSync();

            value = ShMem.get(from_id);
            [unroll] for (uint s = 1; s < to / subgroup_size; ++s) {
                value = Op::combine(value, ShMem.get(s * from + from_id));
            }
        }
    } else {
        // Shared-memory-only tree reduction (no subgroup ops available).
        const uint group_id = tid / to;
        const uint group_tid = tid % to;
        const uint from_id = tid % from;

        GroupMemoryBarrierWithGroupSync();
        ShMem.set(tid, value);
        GroupMemoryBarrierWithGroupSync();

        [unroll] for (int s = int(to) / 2; s >= from; s >>= 1) {
            if (group_tid < s) {
                ShMem.set(tid, Op::combine(ShMem.get(tid), ShMem.get(tid ^ s)));
            }
            GroupMemoryBarrierWithGroupSync();
        }

        value = ShMem.get(group_id * to + from_id);
    }
    return value;
}

View File

@ -0,0 +1,711 @@
import common;
import types;
import flash_attn_loader;
// Specialization constants, chosen per-pipeline by the host backend.
[vk::constant_id( 0)] const uint WorkGroupSize = 128;
[vk::constant_id( 1)] const uint Br = 1;            // query rows per workgroup tile
[vk::constant_id( 2)] const uint Bc = 32;           // KV columns processed per iteration
[vk::constant_id( 3)] const uint HSK = 32;          // K head size
[vk::constant_id( 4)] const uint HSV = 32;          // V head size
[vk::constant_id( 5)] const uint Clamp = 0;         // nonzero -> bounds-check KV accesses
[vk::constant_id( 6)] const uint D_split = 16;      // lanes cooperating across the head dimension
[vk::constant_id( 7)] const uint row_split = 1;     // row groups within the workgroup
[vk::constant_id( 8)] const uint SubGroupSize = 32; // 0 -> shared-memory-only reductions
[vk::constant_id( 9)] const uint SHMEM_STAGING = 0; // nonzero -> stage K/V blocks through shared memory
[vk::constant_id(10)] const uint Flags = 0;         // feature bits, decoded below
[vk::constant_id(11)] const uint LIMIT_OCCUPANCY_SHMEM = 0; // extra shmem to deliberately cap occupancy

static const bool USE_MASK_OPT = (Flags & 1) != 0;   // use 2-bit per-block mask metadata
static const bool MASK_ENABLE = (Flags & 2) != 0;    // an attention mask is bound
static const bool LOGIT_SOFTCAP = (Flags & 4) != 0;  // apply tanh logit softcap
static const bool OLD_AMD_WINDOWS = (Flags & 8) != 0; // f16 shuffle workaround (see common.reduce)

// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
static const uint HSK_pad = (HSK + 15) & ~15;
static const uint HSV_pad = (HSV + 15) & ~15;

static const bool KV_bounds_check = Clamp != 0;
// Push constants; the layout must match the host-side struct filled in by
// ggml_vk_flash_attn (not visible here — confirm against the backend).
// ne* are tensor extents and nb* are strides, following ggml conventions.
struct PushConstants {
    uint N;    // number of query rows
    uint KV;   // number of key/value rows
    uint ne1;
    uint ne2;
    uint ne3;
    uint neq2; // Q extents (dims 2/3)
    uint neq3;
    uint nek2; // K extents (dims 2/3)
    uint nek3;
    uint nev2; // V extents (dims 2/3)
    uint nev3;
    uint nem1; // mask extents
    uint nem2;
    uint nem3;
    uint nb01; // Q strides
    uint nb02;
    uint nb03;
    uint nb11; // K strides
    uint nb12;
    uint nb13;
    uint nb21; // V strides
    uint nb22;
    uint nb23;
    float scale;           // softmax scale applied to Q
    float max_bias;        // ALiBi max bias (0 = ALiBi disabled)
    float logit_softcap;
    uint mask_n_head_log2; // packed: n_head_log2 (N_LOG2_MASK) | sink enable (SINK_ENABLE_BIT)
    float m0;              // ALiBi slope bases
    float m1;
    uint gqa_ratio;        // query heads per KV head (>1 enables the GQA layout)
    uint split_kv;         // KV rows per split_k chunk
    uint k_num;            // number of split_k chunks (>1 enables split_k)
};
[[vk::push_constant]] ConstantBuffer<PushConstants> p;

// Per-thread work distribution derived from the specialization constants.
static const uint HSK_per_thread = HSK / D_split;
static const uint HSV_per_thread = HSV / D_split;
static const uint rows_per_thread = Br / row_split;
static const uint cols_per_iter = WorkGroupSize / D_split / row_split;
static const uint cols_per_thread = Bc / cols_per_iter;
static const uint num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;

// Buffer bindings. Bindings 1, 2 and 5 are aliased by several typed views;
// only the view matching the compiled DATA_A_* variant is actually used.
[vk::binding(0)] StructuredBuffer<vector<float, 4>> data_qv4;      // Q (f32, vec4 view)
[vk::binding(1)] StructuredBuffer<vector<half, 4>> data_kh;        // K as f16
[vk::binding(1)] StructuredBuffer<vector<float, 4>> data_kf;       // K as f32
[vk::binding(1)] StructuredBuffer<block_q8_0_packed16> data_kq8_0; // K as q8_0
[vk::binding(1)] StructuredBuffer<block_q4_0_packed16> data_kq4_0; // K as q4_0
[vk::binding(2)] StructuredBuffer<vector<half, 4>> data_vh;        // V as f16
[vk::binding(2)] StructuredBuffer<vector<float, 4>> data_vf;       // V as f32
[vk::binding(2)] StructuredBuffer<block_q8_0_packed16> data_vq8_0; // V as q8_0
[vk::binding(2)] StructuredBuffer<block_q4_0_packed16> data_vq4_0; // V as q4_0
[vk::binding(3)] StructuredBuffer<half> data_m;      // attention mask
[vk::binding(4)] StructuredBuffer<float> data_s;     // attention sinks, one per head
[vk::binding(5)] RWStructuredBuffer<D_TYPE> data_o;  // output (scalar view, also L/M for split_k)
[vk::binding(5)] RWStructuredBuffer<vector<D_TYPE, 4>> data_ov4; // output (vec4 view)
[vk::binding(6)] StructuredBuffer<uint> data_mask_opt; // 2-bit per-block mask classification
// If SubGroupSize is set to 0 then only use shmem reductions
static const uint tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;

// Scratch for scalar (float) cross-thread reductions, used via common.reduce().
groupshared float tmpsh[tmpsh_size];
struct ShMemFloat: ISharedMemory<float, 1> {
    static float get(uint idx) {
        return tmpsh[idx];
    }
    static void set(uint idx, float value) {
        tmpsh[idx] = value;
    }
}

// Scratch for vec4 cross-thread reductions (output accumulators).
groupshared vector<FLOAT, 4> tmpshv4[tmpsh_size];
struct ShMemFloat4: ISharedMemory<FLOAT, 4> {
    static vector<FLOAT, 4> get(uint idx) {
        return tmpshv4[idx];
    }
    static void set(uint idx, vector<FLOAT, 4> value) {
        tmpshv4[idx] = value;
    }
}
// Shared-memory tiles. Strides are padded by one element (presumably to avoid
// shared-memory bank conflicts — confirm).
// Mask tile, stored column-major (indexed [col * stride + row]).
static const uint masksh_stride = Br + 1;
groupshared FLOAT masksh[Bc * masksh_stride];

// Pre-scaled Q tile, as vec4s.
static const uint qf_stride = HSK / 4 + 1;
groupshared vector<FLOAT, 4> Qf[Br * qf_stride];

// K/V staging tile, sized for the larger of the two head sizes
// (collapsed to a single element when SHMEM_STAGING is off).
static const uint D = HSK > HSV ? HSK : HSV;
static const uint kvsh_stride = D / 4 + 1;
groupshared vector<FLOAT, 4> kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];

// Dummy allocation used only to deliberately limit occupancy (see main()).
groupshared vector<float, 4> occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
// Per-workgroup indexing state, computed once by init_indices().
struct Indices {
    uint i, N, KV, split_k_index, Tr, start_j, end_j,
        gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
        q_stride, k_stride, v_stride, m_stride;
};

// Bit 24 of p.mask_n_head_log2 signals that attention sinks are present.
#define SINK_ENABLE_BIT (1<<24)
// Low 16 bits of p.mask_n_head_log2 hold n_head_log2 for ALiBi.
#define N_LOG2_MASK 0xFFFF

// 2-bit per-block classification values stored in data_mask_opt.
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2
// Integer division rounding up: bias the numerator by (divisor - 1) first.
uint ceil_div(uint numerator, uint divisor) {
    const uint biased = numerator + (divisor - 1);
    return biased / divisor;
}
// Store column zero. This is used to save per-row m and L values for split_k.
// Only (row r < N, col 0) writes; the element always passes through unchanged.
T perElemOpStoreCol0<T: __BuiltinFloatingPointType>(const uint r, const uint32_t c, const T elem, const uint32_t o_offset, const uint32_t iq2, const uint32_t N)
{
    if (r < N && c == 0) {
        uint offset = iq2 + r;
        data_o[o_offset + offset] = floatCast<D_TYPE>(elem);
    }
    return elem;
}

// Load the slope matrix, indexed by Q's dimension 2.
// Returns the ALiBi slope base^exponent for head h; base/exponent switch at
// n_head_log2, matching ggml's ALiBi formulation.
T perElemOpComputeSlope<T: __BuiltinFloatingPointType>(const uint r, const uint32_t c, const T elem, const uint32_t iq2)
{
    const uint h = iq2 + (r % p.gqa_ratio);

    uint n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;

    const T base = T(h < n_head_log2 ? p.m0 : p.m1);
    const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);

    return T(pow(base, T(exph)));
}

// Load the sink value, indexed by Q's dimension 2.
T perElemOpGetSink<T: __BuiltinFloatingPointType>(const in uint r, const in uint32_t c, const in T elem, const in uint32_t iq2)
{
    const uint h = iq2 + (r % p.gqa_ratio);

    return T(data_s[h]);
}
// Derive all per-workgroup indices and strides from the workgroup id and push
// constants. wgid.x encodes the query-block index i, the GQA batch index and/or
// the split_k index depending on mode; wgid.y selects the head group (iq2) and
// wgid.z the batch (iq3).
Indices init_indices(const vector<uint, 3> wgid)
{
    Indices idcs;
    idcs.N = p.N;
    idcs.KV = p.KV;

    if (p.k_num > 1) {
        if (p.gqa_ratio > 1) {
            idcs.i = 0;
            // batch and split_k share wgid.x
            idcs.gqa_iq1 = wgid.x / p.k_num;
            idcs.split_k_index = wgid.x % p.k_num;
        } else {
            idcs.gqa_iq1 = 0;
            idcs.split_k_index = wgid.x % p.k_num;
            idcs.i = wgid.x / p.k_num;
        }
    } else if (p.gqa_ratio > 1) {
        idcs.i = 0;
        idcs.gqa_iq1 = wgid.x;
        idcs.split_k_index = 0;
    } else {
        idcs.i = wgid.x;
        idcs.gqa_iq1 = 0;
        idcs.split_k_index = 0;
    }

    idcs.Tr = ceil_div(idcs.N, Br);

    // Range [start_j, end_j) of Bc-sized KV blocks handled by this split.
    idcs.start_j = idcs.split_k_index * p.split_kv / Bc;
    idcs.end_j = ceil_div(min(idcs.KV, (idcs.split_k_index + 1) * p.split_kv), Bc);

    // When not using grouped query attention, all rows share the same iq2, equal to wgid.y.
    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
    idcs.iq2 = wgid.y * p.gqa_ratio;
    idcs.iq3 = wgid.z;

    // broadcast factors
    idcs.rk2 = p.neq2/p.nek2;
    idcs.rk3 = p.neq3/p.nek3;

    idcs.rv2 = p.neq2/p.nev2;
    idcs.rv3 = p.neq3/p.nev3;

    // k indices
    idcs.ik3 = idcs.iq3 / idcs.rk3;
    idcs.ik2 = idcs.iq2 / idcs.rk2;

    // v indices
    idcs.iv3 = idcs.iq3 / idcs.rv3;
    idcs.iv2 = idcs.iq2 / idcs.rv2;

    // nb?1 are already divided by the type size and are in units of elements.
    // When using grouped query attention, Q is indexed by iq2, so the stride
    // should be nb02 (which is in bytes).
    idcs.q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
    idcs.k_stride = p.nb11;
    idcs.v_stride = p.nb21;

    // When using grouped query attention, all rows use the same mask (stride 0).
    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
    // that prevents the compiler from folding the "&" through the select
    // and breaking the alignment detection.
    idcs.m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : idcs.KV;

    return idcs;
}
// Select the K/V loader implementations (see flash_attn_loader) based on the
// DATA_A_* define that vulkan-shaders-gen sets per shader variant.
#if defined(DATA_A_F32)
typealias KLoader = ScalarKVLoader<float>;
typealias VLoader = ScalarKVLoader<float>;
#elif defined(DATA_A_Q8_0)
typealias KLoader = Q8_0KVLoader;
typealias VLoader = Q8_0KVLoader;
#elif defined(DATA_A_Q4_0)
typealias KLoader = Q4_0KVLoader;
typealias VLoader = Q4_0KVLoader;
#else //if defined(DATA_A_F16)
typealias KLoader = ScalarKVLoader<half>;
typealias VLoader = ScalarKVLoader<half>;
#endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
// o_offset is in vec4 units; c indexes vec4 chunks within the HSV head dim.
// Callers are responsible for the r < N bounds check.
void gqaStore<T: __BuiltinFloatingPointType>(const in uint32_t r, const in uint32_t c, const in vector<T, 4> elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) {
    uint32_t offset = (iq2 + r) * HSV / 4 + c;

    data_ov4[o_offset + offset] = vector<D_TYPE, 4>(elems);
}
// Scalar (non-coopmat) flash-attention kernel. One workgroup computes a
// Br x HSV output tile for one (head, batch) slice, streaming over KV in
// Bc-column blocks with an online softmax (running max M and sum L per row).
[shader("compute")]
[numthreads(WorkGroupSize, 1, 1)]
void main(
    vector<uint, 3> wgid : SV_GroupID,
    uint tid : SV_GroupIndex,
) {
    const Indices idcs = init_indices(wgid);

    const uint subgroup_invocation_id = WaveGetLaneIndex();
    const uint subgroup_id = tid / WaveGetLaneCount();

    // Threads split into row_split row groups; within a row group, D_split
    // lanes cooperate across the head dimension and the rest cover columns.
    const uint threads_per_rowgroup = WorkGroupSize / row_split;
    const uint row_tid = tid / threads_per_rowgroup;
    const uint rowgroup_tid = tid % threads_per_rowgroup;
    const uint d_tid = tid % D_split;
    const uint col_tid = (tid % threads_per_rowgroup) / D_split;

    if (LIMIT_OCCUPANCY_SHMEM > 0) {
        // This just exists to avoid the occupancy_limiter array getting optimized out
        occupancy_limiter[tid] = vector<float, 4>(tid);
        GroupMemoryBarrierWithGroupSync();
        if (all(occupancy_limiter[tid] == vector<float, 4>(99999.0))) {
            data_ov4[0] = vector<D_TYPE, 4>(occupancy_limiter[tid]);
        }
    }

// Row within this thread's Br tile for local row index r.
#define tile_row(r) (row_tid * rows_per_thread + (r))

    // Cooperative load of the Q tile into shared memory, pre-scaled by p.scale.
    uint q_offset = idcs.gqa_iq1*p.nb01 + (idcs.iq2*p.nb02 + idcs.iq3*p.nb03) / 4;
    [unroll] for (uint idx = 0; idx < Br * HSK / 4; idx += WorkGroupSize) {
        uint d = (idx + tid) % (HSK / 4);
        uint r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            idcs.i * Br + r < idcs.N) {
            Qf[r * qf_stride + d] = vector<FLOAT, 4>(data_qv4[q_offset / 4 + (idcs.i * Br + r) * idcs.q_stride / 4 + d] * p.scale);
        }
    }
    GroupMemoryBarrierWithGroupSync();

    // Per-thread output accumulator and online-softmax state.
    vector<FLOAT, 4> Of[rows_per_thread][HSV_per_thread / 4];
    [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            Of[r][d] = vector<FLOAT, 4>(0.0);
        }
    }
    float Lf[rows_per_thread], Mf[rows_per_thread];

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = asfloat(0xFEFFFFFF);
    [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
        Lf[r] = 0;
        Mf[r] = NEG_FLT_MAX_OVER_2;
    }

    ACC_TYPE slope[rows_per_thread];
    [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
        slope[r] = ACC_TYPE(1.0);
    }
    // ALiBi
    if (p.max_bias > 0.0f) {
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), idcs.iq2);
        }
    }

    // Mask-optimization metadata: 2 bits per Bc block, 16 blocks per uint.
    const uint mo_stride = ceil_div(idcs.KV, 16 * Bc);
    // mo_offset will point to the tile starting at row i*Br and col 0
    uint mo_offset = mo_stride * idcs.i;

    uint k_offset = idcs.ik2*p.nb12 + idcs.ik3*p.nb13;
    uint v_offset = idcs.iv2*p.nb22 + idcs.iv3*p.nb23;
#if defined(DATA_A_F32)
    KLoader kloader = KLoader(data_kf, k_offset, idcs.k_stride);
    VLoader vloader = VLoader(data_vf, v_offset, idcs.v_stride);
#elif defined(DATA_A_Q4_0)
    KLoader kloader = KLoader(data_kq4_0, k_offset, idcs.k_stride);
    VLoader vloader = VLoader(data_vq4_0, v_offset, idcs.v_stride);
#elif defined(DATA_A_Q8_0)
    KLoader kloader = KLoader(data_kq8_0, k_offset, idcs.k_stride);
    VLoader vloader = VLoader(data_vq8_0, v_offset, idcs.v_stride);
#else //if defined(DATA_A_F16)
    KLoader kloader = KLoader(data_kh, k_offset, idcs.k_stride);
    VLoader vloader = VLoader(data_vh, v_offset, idcs.v_stride);
#endif

    uint m_offset = idcs.gqa_iq1*idcs.KV;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset += ((idcs.iq3 % p.nem3) * p.nem2 + (idcs.iq2 % p.nem2)) * p.nem1 * idcs.KV;
        mo_offset += ((idcs.iq3 % p.nem3) * p.nem2 + (idcs.iq2 % p.nem2)) * ceil_div(p.nem1, Br) * mo_stride;
    }

    uint mask_opt = 0;
    uint mask_opt_idx = ~0;
    uint mask_opt_bits = 0;

    // Main loop over this split's KV blocks.
    [loop]
    for (uint j = idcs.start_j; j < idcs.end_j; ++j) {
        if (MASK_ENABLE) {
            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                mask_opt_idx = j / 16;
                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
            }
            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
                // skip this block
                continue;
            }
            // Only load if the block is not all zeros
            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
                float max_mask = NEG_FLT_MAX_OVER_2;
                GroupMemoryBarrierWithGroupSync();
                // Stage the mask block into shared memory (transposed), tracking its max.
                [unroll] for (uint idx = 0; idx < Bc * Br; idx += WorkGroupSize) {
                    uint c = (idx + tid) % Bc;
                    uint r = (idx + tid) / Bc;
                    if (idx + tid < Bc * Br) {
                        if ((!KV_bounds_check || j * Bc + c < idcs.KV) && (!nem1_bounds_check || idcs.i * Br + r < p.nem1)) {
                            FLOAT m = FLOAT(data_m[m_offset + (idcs.i * Br + r) * idcs.m_stride + (j * Bc + c)]);
                            masksh[c * masksh_stride + r] = m;
                            max_mask = max(max_mask, float(m));
                        } else {
                            masksh[c * masksh_stride + r] = FLOAT(0);
                        }
                    }
                }
                // skip the block if the mask is entirely -inf
                bool all_less = WaveActiveAllTrue(max_mask <= NEG_FLT_MAX_OVER_2);
                GroupMemoryBarrierWithGroupSync();
                if (WaveIsFirstLane()) {
                    tmpsh[subgroup_id] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
                }
                GroupMemoryBarrierWithGroupSync();
                [unroll] for (uint s = 0; s < WaveGetNumWaves(); ++s) {
                    max_mask = max(max_mask, tmpsh[s]);
                }
                if (max_mask <= NEG_FLT_MAX_OVER_2) {
                    continue;
                }
            }
        }

        // S = Q * K^T accumulator for this KV block.
        ACC_TYPE Sf[rows_per_thread][cols_per_thread];
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
                Sf[r][c] = ACC_TYPE(0.0);
            }
        }

        if (SHMEM_STAGING != 0) {
            // Stage the K block through shared memory.
            GroupMemoryBarrierWithGroupSync();
            [unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) {
                uint d = (idx + tid) % (HSK / 4);
                uint c = (idx + tid) / (HSK / 4);
                if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) {
                    vector<FLOAT, 4> K_Tf = vector<FLOAT, 4>(0);
                    if (!KV_bounds_check || j * Bc + c < idcs.KV) {
                        K_Tf = kloader.load(j * Bc + c, d);
                    }
                    kvsh[c * kvsh_stride + d] = K_Tf;
                }
            }
            GroupMemoryBarrierWithGroupSync();
        }

        // More d iterations means Q register caching becomes relevant
        // Few iterations means the additional registers needed are worse than the speed-up from caching
        if (HSK_per_thread / 4 > 4) {
            [unroll] for (uint d = 0; d < HSK_per_thread / 4; ++d) {
                vector<FLOAT, 4> Q_cache[rows_per_thread];
                [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                    Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
                }
                [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
                    if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= idcs.KV) {
                        continue;
                    }
                    vector<FLOAT, 4> K_Tf;
                    if (SHMEM_STAGING != 0) {
                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
                    } else {
                        K_Tf = kloader.load(j * Bc + c * cols_per_iter + col_tid, d * D_split + d_tid);
                    }
                    [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
                    }
                }
            }
        } else {
            [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= idcs.KV) {
                    continue;
                }
                [unroll] for (uint d = 0; d < HSK_per_thread / 4; ++d) {
                    vector<FLOAT, 4> K_Tf;
                    if (SHMEM_STAGING != 0) {
                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
                    } else {
                        K_Tf = kloader.load(j * Bc + c * cols_per_iter + col_tid, d * D_split + d_tid);
                    }
                    [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
                    }
                }
            }
        }

        [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
            // Compute sum across the D_split
            [unroll] for (uint s = D_split / 2; s > 0; s >>= 1) {
                [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                    Sf[r][c] += WaveReadLaneAt(Sf[r][c], subgroup_invocation_id ^ s);
                }
            }
        }

        if (LOGIT_SOFTCAP) {
            [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
                    Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
                }
            }
        }

        // Apply the (slope-scaled) mask staged earlier this iteration.
        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
            [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
                [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                    FLOAT mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
                    Sf[r][c] += slope[r]*mvf;
                }
            }
        }

        // Online-softmax update: new row max M, rescale factor eM, running sum L.
        float eMf[rows_per_thread];
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            float rowmaxf = NEG_FLT_MAX_OVER_2;
            [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= idcs.KV) {
                    continue;
                }
                rowmaxf = max(rowmaxf, float(Sf[r][c]));
            }
            float Moldf = Mf[r];

            // M = max(rowmax, Mold)
            // P = e^(S - M)
            // eM = e^(Mold - M)
            Mf[r] = max(rowmaxf, Moldf);
            eMf[r] = exp(Moldf - Mf[r]);
            Lf[r] = eMf[r]*Lf[r];
        }

        // Rescale the output accumulator by eM.
        [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
            [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                Of[r][d] = FLOAT(eMf[r]) * Of[r][d];
            }
        }

        if (SHMEM_STAGING != 0) {
            // Stage the V block through shared memory (reusing kvsh).
            GroupMemoryBarrierWithGroupSync();
            [unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) {
                uint d = (idx + tid) % (HSV / 4);
                uint c = (idx + tid) / (HSV / 4);
                if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) {
                    vector<FLOAT, 4> V_Tf = vector<FLOAT, 4>(0);
                    if (!KV_bounds_check || j * Bc + c < idcs.KV) {
                        V_Tf = vloader.load(j * Bc + c, d);
                    }
                    kvsh[c * kvsh_stride + d] = V_Tf;
                }
            }
            GroupMemoryBarrierWithGroupSync();
        }

        // O += P * V, with P = e^(S - M); also accumulate P into L.
        [unroll] for (uint c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= idcs.KV) {
                continue;
            }
            FLOAT Pf[rows_per_thread];
            [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                Pf[r] = FLOAT(exp(float(Sf[r][c]) - Mf[r]));
                Lf[r] += Pf[r];
            }
            [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                vector<FLOAT, 4> Vf;
                if (SHMEM_STAGING != 0) {
                    Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
                } else {
                    Vf = vloader.load(j * Bc + c * cols_per_iter + col_tid, d * D_split + d_tid);
                }
                [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                    Of[r][d] += vector<FLOAT, 4>(Pf[r] * Vf);
                }
            }
        }
    }

    // prevent race on tmpsh
    GroupMemoryBarrierWithGroupSync();

    // reduce across threads
    [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
        float rowmaxf = Mf[r];

        // Compute max across the row
        rowmaxf = reduce<float, 1, MaxOp<float, 1>, ShMemFloat>(rowmaxf, D_split, threads_per_rowgroup, tid, SubGroupSize);

        float Moldf = Mf[r];

        // M = max(rowmax, Mold)
        // eM = e^(Mold - M)
        Mf[r] = max(rowmaxf, Moldf);
        float eMf = exp(Moldf - Mf[r]);

        Lf[r] = eMf*Lf[r];

        // Compute sum across the row
        Lf[r] = reduce<float, 1, SumOp<float, 1>, ShMemFloat>(Lf[r], D_split, threads_per_rowgroup, tid, SubGroupSize);

        [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
            Of[r][d] = FLOAT(eMf) * Of[r][d];
            Of[r][d] = reduce<FLOAT, 4, SumOp<FLOAT, 4>, ShMemFloat4>(Of[r][d], D_split, threads_per_rowgroup, tid, SubGroupSize, OLD_AMD_WINDOWS);
        }
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        if (p.gqa_ratio > 1) {
            // note: O and Q have swapped coord 1,2.
            uint o_offset = HSV * p.ne1 * (idcs.split_k_index + p.k_num * (idcs.gqa_iq1 + p.ne2 * idcs.iq3)) / 4;
            [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                const uint row = tile_row(r);
                if (row < idcs.N) {
                    [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                        gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, idcs.iq2, idcs.N);
                    }
                }
            }

            // L and M are stored after all the split O tiles.
            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (idcs.split_k_index + p.k_num * (idcs.gqa_iq1 + p.ne2 * idcs.iq3));
            [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                const uint row = tile_row(r);
                if (row < idcs.N) {
                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, idcs.iq2, idcs.N);
                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, idcs.iq2, idcs.N);
                }
            }
        } else {
            [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
                const uint row = tile_row(r);
                const uint global_row = idcs.i * Br + row;
                if (global_row < idcs.N) {
                    uint o_offset = HSV * p.ne1 * (idcs.split_k_index + p.k_num * (global_row + p.ne2 * idcs.iq3)) / 4;
                    [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                        data_ov4[o_offset + idcs.iq2 * HSV/4 + d * D_split + d_tid] = vector<D_TYPE, 4>(Of[r][d]);
                    }
                }
                if (global_row < idcs.N && d_tid == 0 && col_tid == 0) {
                    uint lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (idcs.split_k_index + p.k_num * (global_row + p.ne2 * idcs.iq3));
                    data_o[lm_offset + idcs.iq2] = D_TYPE(Lf[r]);
                    data_o[lm_offset + p.ne1 + idcs.iq2] = D_TYPE(Mf[r]);
                }
            }
        }
        return;
    }

    // Attention sinks: fold the per-head sink logit into M and L.
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), idcs.iq2);

            float ms = 1.0f;
            float vs = 1.0f;

            if (sink > Mf[r]) {
                ms = exp(Mf[r] - sink);
                [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                    Of[r][d] *= FLOAT(ms);
                }
            } else {
                vs = exp(sink - Mf[r]);
            }

            Lf[r] = Lf[r]*ms + vs;
        }
    }

    // Final normalization by L (L == 0 maps to 0 to avoid dividing by zero on
    // fully-skipped rows).
    float Lfrcp[rows_per_thread];
    [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
        Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
    }

    [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            Of[r][d] *= FLOAT(Lfrcp[r]);
#if defined(FLOAT_TYPE_MAX)
            Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
#endif
        }
    }

    // Store the final output.
    uint o_offset = (idcs.gqa_iq1*p.ne1*HSV + idcs.iq3*p.ne2*p.ne1*HSV) / 4;
    if (p.gqa_ratio > 1) {
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            const uint row = tile_row(r);
            if (row < idcs.N) {
                [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                    gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, idcs.iq2, idcs.N);
                }
            }
        }
    } else {
        [unroll] for (uint r = 0; r < rows_per_thread; ++r) {
            const uint row = tile_row(r);
            if (idcs.i * Br + row < idcs.N) {
                [unroll] for (uint d = 0; d < HSV_per_thread / 4; ++d) {
                    data_ov4[o_offset + (idcs.iq2 * HSV + (idcs.i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = vector<D_TYPE, 4>(Of[r][d]);
                }
            }
        }
    }
}

View File

@ -0,0 +1,80 @@
module flash_attn_loader;

import types;

// Abstraction over K/V tensor reads so the flash-attention kernel can consume
// f16/f32 and quantized (q8_0/q4_0) K/V through one interface.
public interface IKVLoader {
    // Byte size of one indexable element of the backing buffer; used to turn a
    // byte offset into a buffer index.
    static const uint BYTE_SIZE;

    // Load 4 consecutive head-dim values of row element_idx, starting at
    // head-dim component 4 * head_dim4_idx.
    public vector<FLOAT, 4> load(uint element_idx, uint head_dim4_idx);
}

// Loader for plain f16/f32 K/V stored as vec4s.
public struct ScalarKVLoader<T: __BuiltinFloatingPointType> : IKVLoader {
    static const uint BYTE_SIZE = sizeof(T) * 4;

    StructuredBuffer<vector<T, 4>> buf;
    uint offset;   // base offset, in vec4 units (o arrives in bytes — confirm against ggml_vk_flash_attn)
    uint stride4;  // row stride, in vec4 units (s arrives in scalar elements)

    public __init(StructuredBuffer<vector<T, 4>> b, uint o, uint s) {
        buf = b;
        offset = o / BYTE_SIZE;
        stride4 = s / 4;
    }

    public vector<FLOAT, 4> load(uint element_idx, uint head_dim4_idx) {
        return vector<FLOAT, 4>(buf[offset + element_idx * stride4 + head_dim4_idx]);
    }
}

// Loader for q8_0 K/V: blocks of QUANT_K_Q8_0 int8 values with one f16 scale d;
// dequantized value = d * q.
public struct Q8_0KVLoader : IKVLoader {
    static const uint BYTE_SIZE = sizeof(block_q8_0_packed16);

    StructuredBuffer<block_q8_0_packed16> buf;
    uint offset;  // base offset, in blocks
    uint stride;  // row stride, converted to scalar elements (s presumably arrives in blocks — confirm)

    public __init(StructuredBuffer<block_q8_0_packed16> b, uint o, uint s) {
        buf = b;
        offset = o / BYTE_SIZE;
        stride = s * QUANT_K_Q8_0;
    }

    public vector<FLOAT, 4> load(uint element_idx, uint head_dim4_idx) {
        // coord is the scalar element index; split into block and in-block index.
        const uint coord = element_idx * stride + 4 * head_dim4_idx;
        const uint ib = coord / QUANT_K_Q8_0;
        const uint iqs = (coord % QUANT_K_Q8_0);
        // qs is stored as packed int16 words; unpack two signed bytes (.xy) per word.
        const vector<FLOAT, 2> v0 = vector<FLOAT, 2>(unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2])).xy); // vec4 used due to #12147
        const vector<FLOAT, 2> v1 = vector<FLOAT, 2>(unpack_s8s32(int32_t(buf[offset + ib].qs[iqs / 2 + 1])).xy);
        return FLOAT(buf[offset + ib].d) * vector<FLOAT, 4>(v0.x, v0.y, v1.x, v1.y);
    }
}

// Loader for q4_0 K/V: blocks of QUANT_K_Q4_0 4-bit values with one f16 scale d;
// dequantized value = d * (nibble - 8).
public struct Q4_0KVLoader : IKVLoader {
    static const uint BYTE_SIZE = sizeof(block_q4_0_packed16);

    StructuredBuffer<block_q4_0_packed16> buf;
    uint offset;  // base offset, in blocks
    uint stride;  // row stride, converted to scalar elements (s presumably arrives in blocks — confirm)

    public __init(StructuredBuffer<block_q4_0_packed16> b, uint o, uint s) {
        buf = b;
        offset = o / BYTE_SIZE;
        stride = s * QUANT_K_Q4_0;
    }

    public vector<FLOAT, 4> load(uint element_idx, uint head_dim4_idx) {
        const uint coord = element_idx * stride + 4 * head_dim4_idx;
        const uint ib = coord / QUANT_K_Q4_0;
        const uint iqs = (coord % QUANT_K_Q4_0);
        // Each uint16 holds four nibbles; elements 0..15 use the low nibbles,
        // elements 16..31 the high nibbles (selected by the shift below).
        uint vui_lo = uint(buf[offset + ib].qs[(iqs & 0xF) / 2 + 0]);
        uint vui_hi = uint(buf[offset + ib].qs[(iqs & 0xF) / 2 + 1]);
        const uint shift = (iqs & 0x10) >> 2;
        vui_lo >>= shift;
        vui_hi >>= shift;
        return FLOAT(buf[offset + ib].d) * (vector<FLOAT, 4>(FLOAT(vui_lo & 0xF), FLOAT((vui_lo >> 8) & 0xF),
                                                             FLOAT(vui_hi & 0xF), FLOAT((vui_hi >> 8) & 0xF)) - FLOAT(8.0f));
    }
}

View File

@ -0,0 +1,30 @@
module types;

// Working precision for staging/accumulation: half when the shader variant is
// compiled with -DFLOAT16, float otherwise.
#ifdef FLOAT16
public typealias FLOAT = half;
#else
public typealias FLOAT = float;
#endif

// q4_0: blocks of 32 4-bit values sharing one f16 scale
// (dequantized as d * (nibble - 8), see Q4_0KVLoader).
public static const uint32_t QUANT_K_Q4_0 = 32; // elements per block
public static const uint32_t QUANT_R_Q4_0 = 2;  // values packed per byte

public struct block_q4_0 {
    public float16_t d;
    public uint8_t qs[16];
};

// Same block with qs viewed as 16-bit words, for 16-bit packed loads.
public struct block_q4_0_packed16 {
    public float16_t d;
    public uint16_t qs[16/2];
};

// q8_0: blocks of 32 int8 values sharing one f16 scale (dequantized as d * q).
public static const uint32_t QUANT_K_Q8_0 = 32;
public static const uint32_t QUANT_R_Q8_0 = 1;

public struct block_q8_0 {
    public float16_t d;
    public int8_t qs[32];
};

// Same block with qs viewed as 16-bit words, for 16-bit packed loads.
public struct block_q8_0_packed16 {
    public float16_t d;
    public int16_t qs[32/2];
};

View File

@ -37,6 +37,7 @@ std::vector<std::pair<std::string, std::string>> shader_fnames;
std::locale c_locale("C");
std::string GLSLC = "glslc";
std::string SLANGC = "slangc";
std::string input_filepath = "";
std::string output_dir = "/tmp";
std::string target_hpp = "";
@ -397,6 +398,79 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
}
}
// Compile one Slang shader variant to SPIR-V by invoking slangc.
//   name     - shader variant name, recorded in shader_fnames on success
//   in_path  - path to the .slang source
//   out_path - path for the generated .spv
//   defines  - preprocessor defines passed as -DKEY=VALUE
//   coopmat  - disables -O2 for coopmat shaders (spirv-opt miscompiles them)
//   dep_file - emit a depfile next to target_cpp for CMake dependency tracking
//   slot     - RAII guard limiting the number of concurrent compiles
// On compiler failure the error is printed and the variant is skipped.
void string_to_spv_slang_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
#ifdef _WIN32
    // Quote paths on Windows since they may contain spaces.
    std::vector<std::string> cmd = {SLANGC, "-target", "spirv", "-Wno-39001", "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
#else
    std::vector<std::string> cmd = {SLANGC, "-target", "spirv", "-Wno-39001", in_path, "-o", out_path};
#endif

    // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
    // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
    // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
    if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
        cmd.push_back("-O2");
    }

    if (dep_file) {
        cmd.push_back("-depfile");
#ifdef _WIN32
        cmd.push_back("\"" + target_cpp + ".d\"");
#else
        cmd.push_back(target_cpp + ".d");
#endif
    }

#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
    cmd.push_back("-g");
#endif

    for (const auto& define : defines) {
        cmd.push_back("-D" + define.first + "=" + define.second);
    }

    std::string stdout_str, stderr_str;
    try {
        execute_command(cmd, stdout_str, stderr_str);
        if (!stderr_str.empty()) {
            std::cerr << "cannot compile " << name << "\n\n";
            for (const auto& part : cmd) {
                std::cerr << part << " ";
            }
            std::cerr << "\n\n" << stderr_str << std::endl;
            return;
        }

        if (dep_file) {
            // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt
            std::string dep = read_binary_file(target_cpp + ".d", true);
            if (!dep.empty()) {
                size_t pos = dep.find(out_path);
                if (pos != std::string::npos) {
                    dep.replace(pos, out_path.length(), target_cpp);
                }
                write_binary_file(target_cpp + ".d", dep);
            }
        }

        std::lock_guard<std::mutex> guard(lock);
        shader_fnames.push_back(std::make_pair(name, out_path));
    } catch (const std::exception& e) {
        std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
    }
}
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
std::map<std::string, std::string> result = a;
result.insert(b.begin(), b.end());
@ -423,6 +497,25 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s
// Don't write the same dep file from multiple processes
generate_dep_file = false;
}
// Queue an asynchronous slangc compile for one shader variant (Slang analogue
// of string_to_spv). The suffix flags mirror the GLSL path so generated symbol
// names stay consistent between the two backends.
//   name    - base variant name; f16acc/coopmat/fp32 suffixes are appended
//   source  - source filename this variant belongs to (e.g. "flash_attn.slang")
//   defines - preprocessor defines forwarded to slangc
// With no --source argument, only the header entry is recorded (no compile).
void string_to_spv_slang(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
    name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
    std::string out_path = join_paths(output_dir, name + ".spv");

    if (input_filepath == "") {
        // No input source to compile, only generate header for all shaders
        shader_fnames.push_back(std::pair(name, out_path));
        return;
    } else if (basename(input_filepath) != source) {
        // Only compile shader variants matching the input filename
        return;
    }

    compile_count_guard slot = acquire_compile_slot();
    compiles.push_back(std::async(
        string_to_spv_slang_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
    // Don't write the same dep file from multiple processes
    generate_dep_file = false;
}
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
@ -663,6 +756,7 @@ void process_shaders() {
#endif
}
#ifndef GGML_VULKAN_ENABLE_SLANG
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
@ -671,6 +765,16 @@ void process_shaders() {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
}
#else
if (tname == "f16") {
string_to_spv_slang("flash_attn_f32_f16_" + tname, "flash_attn.slang",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv_slang("flash_attn_f32_f16_" + tname, "flash_attn.slang",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
}
#endif
}
}
}
@ -1194,6 +1298,9 @@ int main(int argc, char** argv) {
if (args.find("--glslc") != args.end()) {
GLSLC = args["--glslc"]; // Path to glslc
}
if (args.find("--slangc") != args.end()) {
SLANGC = args["--slangc"];
}
if (args.find("--source") != args.end()) {
input_filepath = args["--source"]; // The shader source file to compile
}