454 lines
17 KiB
Plaintext
454 lines
17 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : enable
|
|
#extension GL_EXT_shader_16bit_storage : require
|
|
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
|
|
#ifdef FLOAT16
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
|
#endif
|
|
|
|
#extension GL_KHR_shader_subgroup_shuffle : enable
|
|
#extension GL_KHR_shader_subgroup_vote : enable
|
|
|
|
#include "types.glsl"
|
|
#include "flash_attn_base.glsl"
|
|
|
|
// HSK and HSV are each divided evenly among D_split threads.
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;

// Tiles with at least 4 rows are divided among 4 row groups; smaller tiles
// use a single row group.
const uint32_t row_split = (Br >= 4) ? 4 : 1;
// Rows of the Br-row tile owned by each thread.
const uint32_t rows_per_thread = Br / row_split;
// Columns covered in one step by the threads of a row group
// (workgroup size divided by both D_split and row_split).
const uint32_t cols_per_iter = WorkGroupSize / (D_split * row_split);
// Steps each thread needs to cover the Bc columns of a block.
const uint32_t cols_per_thread = Bc / cols_per_iter;
// Compile-time subgroup count for the workgroup.
const uint32_t num_subgroups = WorkGroupSize / SubGroupSize;
|
|
|
|
|
|
// Each of bindings 0..2 is declared twice: once as a scalar array and once as
// a 4-wide vector array over the same binding.

// Binding 0: Q as float / vec4.
layout (binding = 0) readonly buffer Q {
    float data_q[];
};
layout (binding = 0) readonly buffer QV4 {
    vec4 data_qv4[];
};

// Binding 1: K as float16_t / f16vec4.
layout (binding = 1) readonly buffer K {
    float16_t data_k[];
};
layout (binding = 1) readonly buffer KV4 {
    f16vec4 data_kv4[];
};

// Binding 2: V as float16_t / f16vec4.
layout (binding = 2) readonly buffer V {
    float16_t data_v[];
};
layout (binding = 2) readonly buffer VV4 {
    f16vec4 data_vv4[];
};

// Binding 3: mask values as float16_t.
layout (binding = 3) readonly buffer M {
    float16_t data_m[];
};
|
|
|
|
// Store one output element when doing grouped query attention.
// Rows are indexed by Q's dimension 2; only the first N rows are valid.
//
//   r, c      row and column of the element
//   elem      value to store; converted to D_TYPE
//   o_offset  base offset into data_o
//   iq2       added to r to form the destination row
//   N         number of valid rows; not used here (main() checks row < N
//             before calling)
//
// Writes data_o[o_offset + (iq2 + r) * HSV + c] and returns the stored value.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    // Convert explicitly once: the declared return type is D_TYPE, so
    // returning `elem` directly would rely on an implicit ACC_TYPE -> D_TYPE
    // conversion. This also guarantees the returned value equals what was stored.
    const D_TYPE converted = D_TYPE(elem);
    const uint32_t offset = (iq2 + r) * HSV + c;
    data_o[o_offset + offset] = converted;
    return converted;
}
|
|
|
|
// Scratch space for reductions across subgroups.
// With row_split == 1 the final reduction in main() uses one slot per
// (subgroup, d_tid) pair; otherwise that reduction does not touch shared memory.
const uint32_t tmpsh_reduction_size = row_split == 1 ? num_subgroups * D_split : 0;
// main() also writes tmpsh[gl_SubgroupID] and reads tmpsh[0 .. gl_NumSubgroups)
// in the mask-skip vote regardless of row_split, so the array needs at least
// one slot per subgroup. A fixed floor of 4 would be overrun with more than 4
// subgroups. Assumes num_subgroups matches gl_NumSubgroups at run time.
const uint32_t tmpsh_min_size = num_subgroups > 4 ? num_subgroups : 4;
const uint32_t tmpsh_size = tmpsh_reduction_size > tmpsh_min_size ? tmpsh_reduction_size : tmpsh_min_size;
shared float tmpsh[tmpsh_size];
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];

// Mask tile for the current block, indexed [column][row].
shared FLOAT_TYPE masksh[Bc][Br];

// Row stride of Qf in 4-wide elements: HSK/4 data elements plus one extra
// element per row.
const uint qfstride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qfstride];
|
|
|
|
// Shader entry point.
// Outline, as visible in this function:
//   1. Load a Br-row tile of Q (multiplied by p.scale) into shared Qf.
//   2. Loop over blocks j in [start_j, end_j), keeping a per-row running max
//      (Mf), running sum (Lf) and output accumulator (Of).
//   3. Reduce Mf / Lf / Of across the threads that share a row.
//   4. If p.k_num > 1, store the un-normalized O plus L and M and return;
//      otherwise apply the optional sink, scale by 1/L and store to data_o.
// i, N, KV, start_j, end_j, iq2, iq3, ik2, ik3, iv2, iv3, gqa_iq1,
// split_k_index and the *_stride values are not defined in this file;
// presumably they come from flash_attn_base.glsl / init_indices() -- confirm.
void main() {
|
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
init_iq_shmem(gl_WorkGroupSize);
|
|
#endif
|
|
|
|
init_indices();
|
|
|
|
// Split the flat invocation index:
//   row_tid - which of the row_split row groups this thread belongs to
//   d_tid   - which of the D_split interleaved slices of HSK / HSV it handles
//   col_tid - which column it handles within a cols_per_iter-wide step
const uint32_t tid = gl_LocalInvocationIndex;
|
|
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
|
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
|
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
|
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
|
|
|
// Maps a thread-local row index r to a row of the Br-row tile.
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
|
|
|
// NOTE(review): the "/ 4" presumably converts byte strides to float elements -- confirm.
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
|
|
|
|
// Cooperative load of the Q tile: the workgroup strides over the
// Br x (HSK / 4) 4-wide elements. Rows at or beyond N are skipped, so the
// corresponding Qf entries are left unwritten.
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
uint32_t d = (idx + tid) % (HSK / 4);
|
|
uint32_t r = (idx + tid) / (HSK / 4);
|
|
if (r < Br && d < HSK / 4 &&
|
|
i * Br + r < N) {
|
|
Qf[r * qfstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
|
}
|
|
}
|
|
// Qf is read by other invocations below; synchronize after the shared writes.
barrier();
|
|
|
|
// Per-thread output accumulators, zero-initialized.
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Of[r][d] = ACC_TYPEV4(0.0);
|
|
}
|
|
}
|
|
|
|
// Per-row running sum (Lf) and running max (Mf).
float Lf[rows_per_thread], Mf[rows_per_thread];
|
|
|
|
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
|
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
|
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Lf[r] = 0;
|
|
Mf[r] = NEG_FLT_MAX_OVER_2;
|
|
}
|
|
|
|
// Per-row multiplier applied to the mask value further down; stays 1.0
// unless p.max_bias > 0.
ACC_TYPE slope[rows_per_thread];
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
slope[r] = ACC_TYPE(1.0);
|
|
}
|
|
|
|
// ALiBi
|
|
if (p.max_bias > 0.0f) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
|
|
}
|
|
}
|
|
|
|
// Number of data_mask_opt words per row tile: one word per 16 blocks of Bc columns.
const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
|
|
// mo_offset will point to the tile starting at row i*Br and col 0
|
|
uint32_t mo_offset = mo_stride * i;
|
|
|
|
#if BLOCK_SIZE > 1
|
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
|
#else
|
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
|
#endif
|
|
uint32_t m_offset = gqa_iq1*KV;
|
|
// Only when the mask has more than one slice in dims 2/3: select the slice
// by iq2 / iq3 (taken modulo the mask extents).
if (p.nem2 != 1 || p.nem3 != 1) {
|
|
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
|
mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
|
|
}
|
|
|
|
// Cached data_mask_opt word and the index it was loaded from (~0 = nothing cached yet).
uint32_t mask_opt = 0;
|
|
uint32_t mask_opt_idx = ~0;
|
|
|
|
[[dont_unroll]]
|
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
|
|
// Each 32-bit word of data_mask_opt carries 2 bits for each of 16
// consecutive blocks j; reload only when j moves on to a new word.
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
|
|
mask_opt_idx = j / 16;
|
|
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
|
|
}
|
|
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
|
|
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
|
|
// skip this block
|
|
continue;
|
|
}
|
|
// Only load if the block is not all zeros
|
|
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
|
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
|
|
|
// Cooperatively copy the Bc x Br mask tile into shared masksh (stored as
// [column][row]), writing 0 for elements that fail the bounds checks, and
// track the largest in-bounds mask value seen by this thread.
float max_mask = NEG_FLT_MAX_OVER_2;
|
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
uint32_t c = (idx + tid) % Bc;
|
|
uint32_t r = (idx + tid) / Bc;
|
|
if (idx + tid < Bc * Br) {
|
|
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
|
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
|
masksh[c][r] = m;
|
|
max_mask = max(max_mask, float(m));
|
|
} else {
|
|
masksh[c][r] = FLOAT_TYPE(0);
|
|
}
|
|
}
|
|
}
|
|
// skip the block if the mask is entirely -inf
|
|
// Vote within the subgroup, publish one flag per subgroup through tmpsh,
// then fold every subgroup's flag into max_mask so all invocations take
// the same branch below.
// NOTE(review): tmpsh is indexed by gl_SubgroupID here, so its declared
// size must be at least gl_NumSubgroups.
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
|
barrier();
|
|
if (gl_SubgroupInvocationID == 0) {
|
|
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
|
}
|
|
barrier();
|
|
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
|
max_mask = max(max_mask, tmpsh[s]);
|
|
}
|
|
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Sf[r][c]: this thread's partial dot product of Q row tile_row(r) with
// K column j*Bc + c*cols_per_iter + col_tid, over its d_tid slice.
ACC_TYPE Sf[rows_per_thread][cols_per_thread];
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
Sf[r][c] = ACC_TYPE(0.0);
|
|
}
|
|
}
|
|
|
|
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
|
continue;
|
|
}
|
|
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
|
#if BLOCK_SIZE > 1
|
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
uint ib = coord / BLOCK_SIZE;
|
|
uint iqs = (coord % BLOCK_SIZE);
|
|
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
|
#else
|
|
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
|
#endif
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf));
|
|
}
|
|
}
|
|
}
|
|
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
// Compute sum across the D_split
|
|
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Optional soft-capping: S = p.logit_softcap * tanh(S).
if (LOGIT_SOFTCAP) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add the mask tile loaded above, scaled by the per-row slope.
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)];
|
|
|
|
Sf[r][c] += slope[r]*mvf;
|
|
}
|
|
}
|
|
// NOTE(review): presumably ensures all reads of masksh finish before a
// later iteration overwrites it -- confirm.
barrier();
|
|
}
|
|
|
|
// Running-softmax update for this block (per row, per thread).
FLOAT_TYPE Pf[rows_per_thread][cols_per_thread];
|
|
float eMf[rows_per_thread];
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
float rowmaxf = NEG_FLT_MAX_OVER_2;
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
|
continue;
|
|
}
|
|
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
|
}
|
|
float Moldf = Mf[r];
|
|
|
|
// M = max(rowmax, Mold)
|
|
// P = e^(S - M)
|
|
// eM = e^(Mold - M)
|
|
Mf[r] = max(rowmaxf, Moldf);
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
Pf[r][c] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
|
|
}
|
|
eMf[r] = exp(Moldf - Mf[r]);
|
|
|
|
// Compute sum across row of P
|
|
float rowsumf = 0.0;
|
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
|
continue;
|
|
}
|
|
rowsumf += Pf[r][c];
|
|
}
|
|
|
|
Lf[r] = eMf[r]*Lf[r] + rowsumf;
|
|
}
|
|
|
|
// Rescale the existing accumulator by eM before adding this block's P*V.
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
|
|
}
|
|
}
|
|
|
|
// O += P * V for the columns and the V slice owned by this thread.
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
|
continue;
|
|
}
|
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
#if BLOCK_SIZE > 1
|
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
uint ib = coord / BLOCK_SIZE;
|
|
uint iqs = (coord % BLOCK_SIZE);
|
|
FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
|
|
#else
|
|
FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
|
#endif
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Of[r][d] += ACC_TYPEV4(Pf[r][c] * Vf);
|
|
}
|
|
}
|
|
}
|
|
|
|
// NOTE(review): end-of-iteration barrier; presumably keeps iterations in
// step with respect to shared-memory reuse -- confirm.
barrier();
|
|
}
|
|
|
|
// prevent race on tmpsh
|
|
barrier();
|
|
|
|
// reduce across threads
|
|
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
float rowmaxf = Mf[r];
|
|
|
|
// Compute max across the row
|
|
// XOR-shuffle over lane offsets D_split, 2*D_split, ... below SubGroupSize.
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
|
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
|
|
}
|
|
if (row_split == 1) {
|
|
// Reduce inside workgroup with shmem
|
|
// Lanes whose gl_SubgroupInvocationID equals their d_tid publish one value
// per (subgroup, d_tid); every invocation then combines the num_subgroups
// entries for its own d_tid.
barrier();
|
|
if (gl_SubgroupInvocationID == d_tid) {
|
|
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
|
|
}
|
|
barrier();
|
|
rowmaxf = tmpsh[d_tid];
|
|
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
|
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
|
}
|
|
}
|
|
|
|
float Moldf = Mf[r];
|
|
|
|
// M = max(rowmax, Mold)
|
|
// eM = e^(Mold - M)
|
|
Mf[r] = max(rowmaxf, Moldf);
|
|
float eMf = exp(Moldf - Mf[r]);
|
|
|
|
// Rescale this thread's partial sum to the shared max before summing.
Lf[r] = eMf*Lf[r];
|
|
|
|
// Compute sum across the row
|
|
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
|
Lf[r] += subgroupShuffleXor(Lf[r], s);
|
|
}
|
|
if (row_split == 1) {
|
|
barrier();
|
|
if (gl_SubgroupInvocationID == d_tid) {
|
|
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
|
|
}
|
|
barrier();
|
|
Lf[r] = tmpsh[d_tid];
|
|
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
|
Lf[r] += tmpsh[s * D_split + d_tid];
|
|
}
|
|
}
|
|
|
|
// Same rescale-then-sum pattern for each 4-wide slice of the output accumulator.
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
|
|
|
|
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
|
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
|
}
|
|
if (row_split == 1) {
|
|
barrier();
|
|
if (gl_SubgroupInvocationID == d_tid) {
|
|
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
|
|
}
|
|
barrier();
|
|
Of[r][d] = tmpsh_accv4[d_tid];
|
|
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
|
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
// If there is split_k, then the split_k resolve shader does the final
|
|
// division by L. Store the intermediate O value and per-row m and L values.
|
|
if (p.k_num > 1) {
|
|
// note: O and Q have swapped coord 1,2.
|
|
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
|
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
const uint row = tile_row(r);
|
|
if (row < N) {
|
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// L and M are stored after the O data: L at o_offset, M at o_offset + p.ne1.
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
const uint row = tile_row(r);
|
|
if (row < N) {
|
|
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
|
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
|
}
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
// Optional sink term: fold one extra value per row into the normalizer.
// If the sink exceeds the row max, O and L are rescaled by exp(M - sink) and
// the sink contributes 1; otherwise it contributes exp(sink - M).
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
|
|
|
|
float ms = 1.0f;
|
|
float vs = 1.0f;
|
|
|
|
if (sink > Mf[r]) {
|
|
ms = exp(Mf[r] - sink);
|
|
|
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
Of[r][d] *= ACC_TYPE(ms);
|
|
}
|
|
} else {
|
|
vs = exp(sink - Mf[r]);
|
|
}
|
|
|
|
Lf[r] = Lf[r]*ms + vs;
|
|
}
|
|
}
|
|
|
|
// Reciprocal of L per row, with 0 substituted when L is exactly 0.
float Lfrcp[rows_per_thread];
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
|
}
|
|
|
|
// Normalize O by 1/L; when ACC_TYPE_MAX is defined, also clamp to +/-ACC_TYPE_MAX.
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
|
|
#if defined(ACC_TYPE_MAX)
|
|
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
|
#endif
|
|
}
|
|
}
|
|
|
|
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
|
|
|
|
// Final store. With p.gqa_ratio > 1 the tile row is offset by iq2 inside
// perElemOpGqaStore and bounded by row < N; otherwise the row is i*Br + row
// and elements are written to data_o directly.
if (p.gqa_ratio > 1) {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
const uint row = tile_row(r);
|
|
if (row < N) {
|
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
const uint row = tile_row(r);
|
|
if (i * Br + row < N) {
|
|
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
data_o[o_offset + iq2 * HSV + (i * Br + row) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|