538 lines
20 KiB
C++
538 lines
20 KiB
C++
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
|
#define GGML_WEBGPU_SHADER_LIB_HPP
|
|
|
|
#include "ggml.h"
|
|
#include "pre_wgsl.hpp"
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#define GGML_WEBGPU_F16_SIZE_BYTES 2
|
|
#define GGML_WEBGPU_F32_SIZE_BYTES 4
|
|
#define GGML_WEBGPU_I32_SIZE_BYTES 4
|
|
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
|
|
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
|
|
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
|
|
#define GGML_WEBGPU_KV_SEQ_PAD 256u
|
|
|
|
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
|
|
|
|
struct ggml_webgpu_processed_shader {
|
|
std::string wgsl;
|
|
std::string variant;
|
|
void * decisions;
|
|
};
|
|
|
|
// Same hash combine function as in boost
|
|
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
|
|
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
|
}
|
|
|
|
/** FlashAttention */
|
|
|
|
struct ggml_webgpu_flash_attn_pipeline_key {
|
|
ggml_type kv_type;
|
|
uint32_t head_dim_qk;
|
|
uint32_t head_dim_v;
|
|
bool kv_direct;
|
|
bool has_mask;
|
|
bool has_sinks;
|
|
bool uses_logit_softcap;
|
|
|
|
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
|
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
|
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
|
uses_logit_softcap == other.uses_logit_softcap;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
|
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
|
|
size_t seed = 0;
|
|
ggml_webgpu_hash_combine(seed, key.kv_type);
|
|
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
|
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
|
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
|
ggml_webgpu_hash_combine(seed, key.has_mask);
|
|
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
|
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
|
return seed;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_flash_attn_shader_lib_context {
|
|
ggml_webgpu_flash_attn_pipeline_key key;
|
|
uint32_t sg_mat_m;
|
|
uint32_t sg_mat_n;
|
|
uint32_t sg_mat_k;
|
|
size_t wg_mem_limit_bytes;
|
|
uint32_t max_subgroup_size;
|
|
};
|
|
|
|
struct ggml_webgpu_flash_attn_shader_decisions {
|
|
uint32_t q_tile = 0;
|
|
uint32_t kv_tile = 0;
|
|
uint32_t wg_size = 0;
|
|
};
|
|
|
|
// This is exposed because it's necessary in supports_op
|
|
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|
uint32_t kv_tile,
|
|
uint32_t head_dim_qk,
|
|
uint32_t head_dim_v,
|
|
bool has_mask,
|
|
bool kv_direct) {
|
|
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
|
size_t f16_elems = 0;
|
|
size_t f32_elems = 0;
|
|
f16_elems += q_tile * head_dim_qk; // q_shmem
|
|
if (!kv_direct) {
|
|
f16_elems += kv_tile * max_head_dim; // kv_shmem
|
|
}
|
|
f16_elems += q_tile * head_dim_v; // o_shmem
|
|
if (has_mask) {
|
|
f16_elems += q_tile * kv_tile; // mask_shmem
|
|
}
|
|
f16_elems += q_tile * kv_tile; // inter_shmem
|
|
f32_elems += q_tile; // row_max_shmem
|
|
f32_elems += q_tile; // exp_sum_shmem
|
|
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
|
}
|
|
|
|
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
|
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
|
const size_t q_tile = context.sg_mat_m;
|
|
const size_t base_q_bytes =
|
|
(context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
|
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
|
size_t bytes_per_kv = 0;
|
|
if (!context.key.kv_direct) {
|
|
bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
|
|
}
|
|
if (context.key.has_mask) {
|
|
bytes_per_kv += q_tile;
|
|
}
|
|
bytes_per_kv += q_tile;
|
|
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
|
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
|
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
|
}
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = "flash_attn";
|
|
|
|
switch (context.key.kv_type) {
|
|
case GGML_TYPE_F32:
|
|
defines.push_back("KV_F32");
|
|
break;
|
|
case GGML_TYPE_F16:
|
|
defines.push_back("KV_F16");
|
|
break;
|
|
case GGML_TYPE_Q4_0:
|
|
defines.push_back("KV_Q4_0");
|
|
break;
|
|
case GGML_TYPE_Q8_0:
|
|
defines.push_back("KV_Q8_0");
|
|
break;
|
|
default:
|
|
GGML_ABORT("Unsupported KV type for flash attention shader");
|
|
}
|
|
variant += std::string("_") + ggml_type_name(context.key.kv_type);
|
|
|
|
if (context.key.has_mask) {
|
|
defines.push_back("MASK");
|
|
variant += "_mask";
|
|
}
|
|
if (context.key.has_sinks) {
|
|
defines.push_back("SINKS");
|
|
variant += "_sinks";
|
|
}
|
|
if (context.key.uses_logit_softcap) {
|
|
defines.push_back("LOGIT_SOFTCAP");
|
|
variant += "_lgsc";
|
|
}
|
|
|
|
if (context.key.kv_direct) {
|
|
defines.push_back("KV_DIRECT");
|
|
variant += "_kvdirect";
|
|
}
|
|
|
|
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
|
|
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
|
|
|
|
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
|
|
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
|
|
// For now these are not part of the variant name
|
|
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
|
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
|
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
|
|
|
// Add chosen Q/KV tile sizes
|
|
uint32_t q_tile = context.sg_mat_m;
|
|
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
|
|
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
|
if (context.key.kv_direct) {
|
|
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
|
// Avoids having to use bounds-checks and decreasing performance for direct KV loads
|
|
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
|
kv_tile -= context.sg_mat_n;
|
|
}
|
|
}
|
|
|
|
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
|
|
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
|
|
|
|
// workgroup size
|
|
uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
|
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
|
|
decisions->q_tile = q_tile;
|
|
decisions->kv_tile = kv_tile;
|
|
decisions->wg_size = wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
|
|
/** Generic **/
|
|
|
|
struct ggml_webgpu_generic_shader_lib_context {
|
|
int vec4;
|
|
uint32_t max_wg_size;
|
|
};
|
|
|
|
struct ggml_webgpu_generic_shader_decisions {
|
|
uint32_t wg_size;
|
|
};
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_generic_shader_lib_context & context,
|
|
const std::string & base_variant) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = base_variant;
|
|
|
|
if (context.vec4) {
|
|
defines.push_back("VEC4");
|
|
variant += "_vec";
|
|
}
|
|
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
return result;
|
|
}
|
|
|
|
/** Pad **/
|
|
|
|
struct ggml_webgpu_pad_pipeline_key {
|
|
bool circular;
|
|
|
|
bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
|
|
};
|
|
|
|
struct ggml_webgpu_pad_pipeline_key_hash {
|
|
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
|
|
size_t seed = 0;
|
|
ggml_webgpu_hash_combine(seed, key.circular);
|
|
return seed;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_pad_shader_lib_context {
|
|
ggml_webgpu_pad_pipeline_key key;
|
|
uint32_t max_wg_size;
|
|
};
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_pad_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = "pad";
|
|
|
|
if (context.key.circular) {
|
|
defines.push_back("CIRCULAR");
|
|
variant += "_circular";
|
|
}
|
|
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
|
|
decisions->wg_size = context.max_wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
|
|
/** Argsort **/
|
|
|
|
struct ggml_webgpu_argsort_shader_lib_context {
|
|
uint32_t max_wg_size;
|
|
size_t wg_mem_limit_bytes;
|
|
int32_t order;
|
|
};
|
|
|
|
struct ggml_webgpu_argsort_shader_decisions {
|
|
uint32_t wg_size = 0;
|
|
};
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_argsort_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = "argsort";
|
|
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
|
|
variant += std::string("_order") + std::to_string(context.order);
|
|
uint32_t wg_size = 1;
|
|
while (wg_size * 2 <= context.max_wg_size &&
|
|
wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
|
|
wg_size *= 2;
|
|
}
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
|
|
decisions->wg_size = wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_argsort_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = "argsort_merge";
|
|
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
|
|
variant += std::string("_order") + std::to_string(context.order);
|
|
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
|
|
decisions->wg_size = wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
|
|
/** Set Rows **/
|
|
|
|
struct ggml_webgpu_set_rows_pipeline_key {
|
|
int dst_type;
|
|
int vec4;
|
|
int i64_idx;
|
|
|
|
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
|
|
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
|
|
size_t seed = 0;
|
|
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
ggml_webgpu_hash_combine(seed, key.vec4);
|
|
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
|
return seed;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_set_rows_shader_lib_context {
|
|
ggml_webgpu_set_rows_pipeline_key key;
|
|
uint32_t max_wg_size;
|
|
};
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_set_rows_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = "set_rows";
|
|
|
|
switch (context.key.dst_type) {
|
|
case GGML_TYPE_F32:
|
|
defines.push_back("DST_F32");
|
|
variant += "_dstf32";
|
|
break;
|
|
case GGML_TYPE_F16:
|
|
defines.push_back("DST_F16");
|
|
variant += "_dstf16";
|
|
break;
|
|
default:
|
|
GGML_ABORT("Unsupported dst type for set_rows shader");
|
|
}
|
|
|
|
if (context.key.vec4) {
|
|
defines.push_back("VEC4");
|
|
variant += "_vec";
|
|
}
|
|
if (context.key.i64_idx) {
|
|
defines.push_back("I64_IDX");
|
|
variant += "_i64idx";
|
|
}
|
|
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
|
|
decisions->wg_size = context.max_wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
|
|
struct ggml_webgpu_unary_pipeline_key {
|
|
int type;
|
|
int op;
|
|
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
|
bool inplace;
|
|
|
|
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
|
|
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_unary_pipeline_key_hash {
|
|
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
|
|
size_t seed = 0;
|
|
ggml_webgpu_hash_combine(seed, key.type);
|
|
ggml_webgpu_hash_combine(seed, key.op);
|
|
ggml_webgpu_hash_combine(seed, key.is_unary);
|
|
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
return seed;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_unary_shader_lib_context {
|
|
ggml_webgpu_unary_pipeline_key key;
|
|
uint32_t max_wg_size;
|
|
};
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_unary_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
|
|
ggml_op_name((ggml_op) context.key.op);
|
|
// Operation-specific behavior
|
|
defines.push_back(variant);
|
|
|
|
switch (context.key.type) {
|
|
case GGML_TYPE_F32:
|
|
defines.push_back("TYPE_F32");
|
|
variant += "_f32";
|
|
break;
|
|
case GGML_TYPE_F16:
|
|
defines.push_back("TYPE_F16");
|
|
variant += "_f16";
|
|
break;
|
|
default:
|
|
GGML_ABORT("Unsupported type for unary shader");
|
|
}
|
|
|
|
if (context.key.inplace) {
|
|
defines.push_back("INPLACE");
|
|
variant += "_inplace";
|
|
}
|
|
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
|
|
decisions->wg_size = context.max_wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
|
|
/** Binary **/
|
|
|
|
struct ggml_webgpu_binary_pipeline_key {
|
|
int type;
|
|
int op;
|
|
bool inplace;
|
|
bool overlap;
|
|
|
|
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
|
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_binary_pipeline_key_hash {
|
|
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
|
|
size_t seed = 0;
|
|
ggml_webgpu_hash_combine(seed, key.type);
|
|
ggml_webgpu_hash_combine(seed, key.op);
|
|
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
ggml_webgpu_hash_combine(seed, key.overlap);
|
|
return seed;
|
|
}
|
|
};
|
|
|
|
struct ggml_webgpu_binary_shader_lib_context {
|
|
ggml_webgpu_binary_pipeline_key key;
|
|
uint32_t max_wg_size;
|
|
};
|
|
|
|
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
|
|
pre_wgsl::Preprocessor & preprocessor,
|
|
const char * shader_src,
|
|
const ggml_webgpu_binary_shader_lib_context & context) {
|
|
std::vector<std::string> defines;
|
|
std::string op_name = ggml_op_name((ggml_op) context.key.op);
|
|
std::string variant = op_name;
|
|
|
|
defines.push_back(std::string("OP_") + op_name);
|
|
|
|
switch (context.key.type) {
|
|
case GGML_TYPE_F32:
|
|
defines.push_back("TYPE_F32");
|
|
variant += "_f32";
|
|
break;
|
|
case GGML_TYPE_F16:
|
|
defines.push_back("TYPE_F16");
|
|
variant += "_f16";
|
|
break;
|
|
default:
|
|
GGML_ABORT("Unsupported type for binary shader");
|
|
}
|
|
|
|
if (context.key.inplace) {
|
|
defines.push_back("INPLACE");
|
|
variant += "_inplace";
|
|
} else if (context.key.overlap) {
|
|
defines.push_back("OVERLAP");
|
|
variant += "_overlap";
|
|
}
|
|
|
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
ggml_webgpu_processed_shader result;
|
|
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
|
result.variant = variant;
|
|
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
|
|
decisions->wg_size = context.max_wg_size;
|
|
result.decisions = decisions;
|
|
return result;
|
|
}
|
|
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|