Compare commits

...

5 Commits

Author SHA1 Message Date
Dan Hoffman 237b876332
Merge 9f7ce433aa into 4951250235 2026-03-31 23:30:03 -07:00
Ed Addario 4951250235
llama : refactor llama_model_quantize_params to expose a pure C interface (#20346)
* Refactor llama_model_quantize_params to expose a pure C interface

* Restore comment and cleanup struct def

* Code review refactoring

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Code review refactoring

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-04-01 08:43:00 +03:00
Reese Levine 82764c341a
ggml webgpu: quantized buffers to u32 + wider browser/device support (#21046)
* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs
2026-04-01 08:38:24 +03:00
Abhijit Ramesh 825eb91a66
ggml-webgpu: port all AOT operators to JIT (#20728)
* port cpy pipeline to shader lib with JIT compilation
* port glu pipeline to shader lib with JIT compilation
* port rope pipeline to shader lib with JIT compilation
* port soft_max pipeline to shader lib with JIT compilation
* removed unused functions from embed_wgsl.py which were used for old AOT template expansion
2026-03-31 15:38:16 -07:00
Dan Hoffman 9f7ce433aa Fix undefined timing measurement errors in server context 2026-03-30 21:09:31 -07:00
17 changed files with 1044 additions and 1282 deletions

View File

@@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions {
uint32_t mul_mat_wg_size;
};
/** Cpy **/
struct ggml_webgpu_cpy_pipeline_key {
ggml_type src_type;
ggml_type dst_type;
bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
return src_type == other.src_type && dst_type == other.dst_type;
}
};
struct ggml_webgpu_cpy_pipeline_key_hash {
size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
return seed;
}
};
/** Glu **/
struct ggml_webgpu_glu_pipeline_key {
ggml_glu_op glu_op;
ggml_type type;
bool split;
bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
return glu_op == other.glu_op && type == other.type && split == other.split;
}
};
struct ggml_webgpu_glu_pipeline_key_hash {
size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.glu_op);
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.split);
return seed;
}
};
/** Rope **/
struct ggml_webgpu_rope_pipeline_key {
ggml_type type;
bool inplace;
bool has_ff;
bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
}
};
struct ggml_webgpu_rope_pipeline_key_hash {
size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.has_ff);
return seed;
}
};
/** SoftMax **/
struct ggml_webgpu_soft_max_pipeline_key {
ggml_type mask_type;
bool has_mask;
bool has_sink;
bool inplace;
bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
inplace == other.inplace;
}
};
struct ggml_webgpu_soft_max_pipeline_key_hash {
size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.mask_type);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sink);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
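Each key above is hashed by folding its fields into a running seed with ggml_webgpu_hash_combine, whose definition is not part of this diff. A minimal boost-style sketch of what such a combiner typically looks like, under that assumption:

#include <cstddef>
#include <functional>

// Hypothetical stand-in for ggml_webgpu_hash_combine (not the file's actual
// definition): mix each field's std::hash into the seed with the usual
// golden-ratio constant so that field order matters.
template <typename T>
static void hash_combine_sketch(size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}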
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
@@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
rope_pipelines;
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
soft_max_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -1124,9 +1219,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
break;
}
}
@@ -1239,9 +1333,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
// Use f16 inside the shader for quantized types
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
variant += std::string("_") + src0_name;
break;
@@ -1679,6 +1772,236 @@ class ggml_webgpu_shader_lib {
return flash_attn_pipelines[key];
}
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_cpy_pipeline_key key = {
.src_type = context.src0->type,
.dst_type = context.dst->type,
};
auto it = cpy_pipelines.find(key);
if (it != cpy_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "cpy";
switch (key.src_type) {
case GGML_TYPE_F32:
defines.push_back("SRC_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported src type for cpy shader");
}
switch (key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("DST_I32");
variant += "_i32";
break;
default:
GGML_ABORT("Unsupported dst type for cpy shader");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_cpy, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
cpy_pipelines[key] = pipeline;
return cpy_pipelines[key];
}
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_glu_pipeline_key key = {
.glu_op = ggml_get_glu_op(context.dst),
.type = context.dst->type,
.split = (context.src1 != nullptr),
};
auto it = glu_pipelines.find(key);
if (it != glu_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "glu";
switch (key.glu_op) {
case GGML_GLU_OP_REGLU:
defines.push_back("OP_REGLU");
variant += "_reglu";
break;
case GGML_GLU_OP_GEGLU:
defines.push_back("OP_GEGLU");
variant += "_geglu";
break;
case GGML_GLU_OP_SWIGLU:
defines.push_back("OP_SWIGLU");
variant += "_swiglu";
break;
case GGML_GLU_OP_SWIGLU_OAI:
defines.push_back("OP_SWIGLU_OAI");
variant += "_swiglu_oai";
break;
case GGML_GLU_OP_GEGLU_ERF:
defines.push_back("OP_GEGLU_ERF");
variant += "_geglu_erf";
break;
case GGML_GLU_OP_GEGLU_QUICK:
defines.push_back("OP_GEGLU_QUICK");
variant += "_geglu_quick";
break;
default:
GGML_ABORT("Unsupported GLU op");
}
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for GLU shader");
}
if (key.split) {
variant += "_split";
} else {
defines.push_back("NO_SPLIT");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_glu, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
glu_pipelines[key] = pipeline;
return glu_pipelines[key];
}
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rope_pipeline_key key = {
.type = context.dst->type,
.inplace = context.inplace,
.has_ff = (context.src2 != nullptr),
};
auto it = rope_pipelines.find(key);
if (it != rope_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "rope";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for ROPE shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
if (key.has_ff) {
defines.push_back("FF_FUNC");
variant += "_ff";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_rope, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
rope_pipelines[key] = pipeline;
return rope_pipelines[key];
}
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_soft_max_pipeline_key key = {
.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
.has_mask = (context.src1 != nullptr),
.has_sink = (context.src2 != nullptr),
.inplace = context.inplace,
};
auto it = soft_max_pipelines.find(key);
if (it != soft_max_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "soft_max";
if (key.has_mask) {
defines.push_back("HAS_MASK");
switch (key.mask_type) {
case GGML_TYPE_F32:
defines.push_back("MASK_F32");
variant += "_mask_f32";
break;
case GGML_TYPE_F16:
defines.push_back("MASK_F16");
variant += "_mask_f16";
break;
default:
GGML_ABORT("Unsupported type for SOFT_MAX shader");
}
}
if (key.has_sink) {
defines.push_back("HAS_SINK");
variant += "_sink";
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
soft_max_pipelines[key] = pipeline;
return soft_max_pipelines[key];
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,

View File

@@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
@@ -171,6 +171,7 @@ struct webgpu_buf_pool {
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
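The unlock above is the fix for the free_bufs deadlock called out in the commit message: buffer creation can fire Dawn callbacks that re-enter this pool's mutex. A self-contained sketch of the pattern, with hypothetical names:

#include <mutex>
#include <vector>

static std::mutex       pool_mutex;
static std::vector<int> free_bufs;     // stand-in for the wgpu buffer list
static size_t           cur_pool_size = 0;
static const size_t     max_pool_size = 8;

static int alloc_buf() { return 42; }  // stand-in for ggml_webgpu_create_buffer

static void grow_pool_if_needed() {
    std::unique_lock<std::mutex> lock(pool_mutex);
    if (free_bufs.empty() && cur_pool_size < max_pool_size) {
        cur_pool_size++;       // reserve the slot while still holding the lock
        lock.unlock();         // drop the lock around the external call,
                               // whose callbacks may take pool_mutex
        int buf = alloc_buf();
        lock.lock();           // re-acquire to publish the new buffer
        free_bufs.push_back(buf);
    }
}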
@@ -364,13 +365,6 @@ struct webgpu_context_struct {
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
size_t memset_bytes_per_thread;
};
@@ -514,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
while (blocking_wait) {
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
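The multiplier in the new WaitAny call converts the millisecond constant into the nanosecond timeout that WaitAny takes (an assumption about Dawn's API inferred from the code). A worked check:

// 100 ms * 1e6 = 1.0e8 ns, so each loop iteration now blocks for up to
// 100 ms rather than spinning with the old zero timeout.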
@@ -735,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
@@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
@@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
params, entries, wg_x);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = ggml_webgpu_tensor_equal(src0, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_freq_factor = (src2 != nullptr);
@@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int split = (src1 != nullptr);
std::vector<uint32_t> params = {
@@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
const int has_sink = (src2 != nullptr);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = ggml_webgpu_tensor_equal(src0, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_mask = (src1 != nullptr);
const int has_sink = (src2 != nullptr);
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
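This rounds the head count down to a power of two for the ALiBi slope formula in the shader. A worked example with an assumed head count:

// Assumed src0->ne[2] == 12 heads: log2(12) ~ 3.58, floor -> 3,
// so n_head_log2 = float(1u << 3) = 8.0f, the largest power of two
// not exceeding the head count.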
@@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
@@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
has_mask ? (uint32_t) src1->ne[2] : 0,
has_mask ? (uint32_t) src1->ne[3] : 0,
*(uint32_t *) dst->op_params, // scale
*(uint32_t *) &max_bias,
*(uint32_t *) &n_head_log2,
@@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, src0) }
};
uint32_t binding_num = 1;
if (mask_type < 2) {
if (has_mask) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
@@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
ggml_nrows(dst));
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst));
}
static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -2661,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
total_offset + (size - remaining_size), remaining_size);
} else {
// wait for WriteBuffer to complete
buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
}
WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
}
@@ -2885,139 +2907,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
// REGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
// GEGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
// SWIGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
// SWIGLU_OAI
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
// GEGLU_ERF
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
// GEGLU_QUICK
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
// f32 (no mask)
webgpu_ctx->soft_max_pipelines[2][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
// f32 mask (mask_type = 0)
webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
webgpu_ctx->soft_max_pipelines[0][1][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
"soft_max_f32_mask_f32_sink_inplace", constants);
// f16 mask (mask_type = 1)
webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
webgpu_ctx->soft_max_pipelines[1][1][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
"soft_max_f32_mask_f16_sink_inplace", constants);
}
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
wgpu::RequestAdapterOptions options = {};
@@ -3183,10 +3072,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,

View File

@@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
}
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_src0_u16_at(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_src0_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = src0[word_idx];
if (shift == 0u) {
return lo;
}
let hi = src0[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn load_src0_f16_at(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
return f16(packed[0]);
}
#endif
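These helpers emulate 16-bit and 32-bit loads at byte granularity on top of a u32-typed storage buffer, which is what lets quantized blocks bind as u32 arrays on devices without f16 storage support; load_src0_f16_at then decodes the low 16 bits via unpack2x16float. A host-side C++ reference of the same logic, assuming a little-endian word view and 2-byte-aligned u16 offsets (which the block layouts guarantee):

#include <cstdint>
#include <vector>

// byte_offset must be 2-byte aligned, as it is for the f16 scales in each block.
static uint32_t load_u16_at(const std::vector<uint32_t> & buf, uint32_t byte_offset) {
    uint32_t word  = buf[byte_offset / 4u];
    uint32_t shift = (byte_offset & 2u) * 8u;  // 0 or 16
    return (word >> shift) & 0xFFFFu;
}

// Handles any alignment by stitching two adjacent words together.
static uint32_t load_u32_at(const std::vector<uint32_t> & buf, uint32_t byte_offset) {
    uint32_t idx   = byte_offset / 4u;
    uint32_t shift = (byte_offset & 3u) * 8u;
    uint32_t lo    = buf[idx];
    if (shift == 0u) {
        return lo;
    }
    uint32_t hi = buf[idx + 1u];
    return (lo >> shift) | (hi << (32u - shift));
}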
#ifdef Q4_0_T
struct q4_0 {
d: f16,

View File

@@ -1,66 +1,41 @@
#define(VARIANTS)
[
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "i32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f32"
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#ifdef SRC_F32
#define SRC_TYPE f32
#elif defined(SRC_F16)
#define SRC_TYPE f16
#endif
#ifdef DST_F32
#define DST_TYPE f32
#elif defined(DST_F16)
#define DST_TYPE f16
#elif defined(DST_I32)
#define DST_TYPE i32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
struct Params {
ne: u32,
offset_src: u32,
offset_dst: u32,
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
@@ -73,8 +48,7 @@ struct Params {
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
@@ -102,6 +76,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
}
#end(SHADER)

View File

@@ -1,41 +1,8 @@
import os
import re
import ast
import argparse
def extract_block(text, name):
pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
match = re.search(pattern, text, re.DOTALL)
if not match:
raise ValueError(f"Missing block: {name}")
return match.group(1).strip()
def parse_decls(decls_text):
decls = {}
for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
decls[name.strip()] = code.strip()
return decls
def replace_repl_placeholders(variant, template_map):
for repl, code in variant["REPLS"].items():
for key, val in template_map.items():
# Match "key" and avoid matching subsequences using by using \b
code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
variant["REPLS"][repl] = code
return variant
def replace_placeholders(shader_text, replacements):
for key, val in replacements.items():
# Match {{KEY}} literally, where KEY is escaped
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
shader_text = re.sub(pattern, str(val), shader_text)
return shader_text
def expand_includes(shader, input_dir):
"""
Replace #include "file" lines in the text with the contents of that file.
@@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
def generate_variants(fname, input_dir, output_dir, outfile):
shader_path = os.path.join(input_dir, fname)
shader_base_name = fname.split(".")[0]
with open(shader_path, "r", encoding="utf-8") as f:
text = f.read()
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
except ValueError:
decls_map = {}
try:
templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
except ValueError:
templates_map = {}
for fname in sorted(os.listdir(input_dir)):
if fname.endswith(".tmpl"):
tmpl_path = os.path.join(input_dir, fname)
with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
decls = f_tmpl.read()
decls_map.update(parse_decls(decls))
shader_template = extract_block(text, "SHADER")
for variant in variants:
if "DECLS" in variant:
decls = variant["DECLS"]
else:
decls = []
decls_code = ""
for key in decls:
if key not in decls_map:
raise ValueError(f"DECLS key '{key}' not found.")
decls_code += decls_map[key] + "\n\n"
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
if "REPLS" in variant:
variant = replace_repl_placeholders(variant, templates_map)
final_shader = replace_placeholders(final_shader, variant["REPLS"])
# second run to expand placeholders in repl_template
final_shader = replace_placeholders(final_shader, variant["REPLS"])
final_shader = expand_includes(final_shader, input_dir)
if "SHADER_NAME" in variant:
output_name = variant["SHADER_NAME"]
elif "SHADER_SUFFIX" in variant:
output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True)
parser.add_argument("--output_file", required=True)
parser.add_argument("--output_dir")
args = parser.parse_args()
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)
with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n")
out.write("#include <string>\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(fname, args.input_dir, args.output_dir, out)
shader_path = os.path.join(args.input_dir, fname)
shader_name = fname.replace(".wgsl", "")
with open(shader_path, "r", encoding="utf-8") as f:
shader_code = f.read()
write_shader(shader_name, shader_code, None, out, args.input_dir)
if __name__ == "__main__":

View File

@@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
#define KV_TYPE u32
#else
#define KV_TYPE f16
#endif
@@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix;
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
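The new BLOCK_SIZE_BYTES constants are the raw byte footprints that the byte-addressed loads below walk over. A host-side sanity check, assuming the standard ggml packed block layouts:

#include <cstdint>

// Assumed mirrors of ggml's packed block definitions.
struct block_q4_0 { uint16_t d; uint8_t qs[16]; };  // f16 scale + 32 4-bit weights
struct block_q8_0 { uint16_t d; int8_t  qs[32]; };  // f16 scale + 32 8-bit weights

static_assert(sizeof(block_q4_0) == 18, "matches BLOCK_SIZE_BYTES 18u for Q4_0");
static_assert(sizeof(block_q8_0) == 34, "matches BLOCK_SIZE_BYTES 34u for Q8_0");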
@@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#if defined(KV_Q4_0) || defined(KV_Q8_0)
fn load_k_u16_at(byte_offset: u32) -> u32 {
let word = K[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_k_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = K[word_idx];
if (shift == 0u) {
return lo;
}
let hi = K[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn load_v_u16_at(byte_offset: u32) -> u32 {
let word = V[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_v_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = V[word_idx];
if (shift == 0u) {
return lo;
}
let hi = V[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn f16_from_u16(bits: u32) -> f16 {
let packed = unpack2x16float(bits);
return f16(packed[0]);
}
#endif
struct Params {
offset_q: u32,
offset_k: u32,
@@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
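Each packed byte holds two 4-bit weights; the decode above recentres them by subtracting 8 before scaling by the block's f16 scale d. The same decode as scalar C++, for reference:

#include <cstdint>

// Low and high nibbles of a Q4_0 byte, recentred to [-8, 7] and scaled by d.
static float q4_0_lo(uint8_t b, float d) { return (float(b & 0xF) - 8.0f) * d; }
static float q4_0_hi(uint8_t b, float d) { return (float(b >> 4)  - 8.0f) * d; }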
@@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
@@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;

View File

@@ -1,323 +0,0 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "reglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "geglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "swiglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_oai_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "swiglu_oai_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "geglu_erf_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_quick_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(REGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return max(a, 0) * b;
}
#enddecl(REGLU)
#decl(GEGLU)
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#enddecl(GEGLU)
#decl(SWIGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a / (1.0 + exp(-a)) * b;
}
#enddecl(SWIGLU)
#decl(SWIGLU_OAI)
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#enddecl(SWIGLU_OAI)
#decl(GEGLU_ERF)
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let a_div_sqr2 = a * SQRT_2_INV;
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#enddecl(GEGLU_ERF)
#decl(GEGLU_QUICK)
const GELU_QUICK_COEF: {{TYPE}} = -1.702;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#enddecl(GEGLU_QUICK)
#decl(NO_SPLIT)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#enddecl(NO_SPLIT)
#decl(SPLIT)
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
return src0[base];
}
fn b_value(base: u32) -> {{TYPE}} {
return src1[base];
}
#enddecl(SPLIT)
#end(DECLS)
#define(SHADER)
enable f16;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
swapped: u32,
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}
#end(SHADER)

View File

@@ -0,0 +1,155 @@
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
#ifdef OP_REGLU
fn op(a: DataType, b: DataType) -> DataType {
return max(a, 0) * b;
}
#endif
#ifdef OP_GEGLU
const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
const GELU_COEF_A: DataType = 0.044715;
fn op(a: DataType, b: DataType) -> DataType {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#endif
#ifdef OP_SWIGLU
fn op(a: DataType, b: DataType) -> DataType {
return a / (1.0 + exp(-a)) * b;
}
#endif
#ifdef OP_SWIGLU_OAI
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#endif
#ifdef OP_GEGLU_ERF
const p_erf: DataType = 0.3275911;
const a1_erf: DataType = 0.254829592;
const a2_erf: DataType = -0.284496736;
const a3_erf: DataType = 1.421413741;
const a4_erf: DataType = -1.453152027;
const a5_erf: DataType = 1.061405429;
const SQRT_2_INV: DataType = 0.7071067811865476;
fn op(a: DataType, b: DataType) -> DataType {
let a_div_sqr2 = a * SQRT_2_INV;
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#endif
#ifdef OP_GEGLU_QUICK
const GELU_QUICK_COEF: DataType = -1.702;
fn op(a: DataType, b: DataType) -> DataType {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#endif
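For reference, two of the gates above as scalar C++ (float shown; the shader can also instantiate them at f16):

#include <cmath>

// SWIGLU: silu(a) * b
static float swiglu_ref(float a, float b) { return a / (1.0f + std::exp(-a)) * b; }

// REGLU: relu(a) * b
static float reglu_ref(float a, float b) { return std::fmax(a, 0.0f) * b; }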
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
swapped: u32,
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
#ifdef NO_SPLIT
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> DataType {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#else
@group(0) @binding(1)
var<storage, read_write> src1: array<DataType>;
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
return src0[base];
}
fn b_value(base: u32) -> DataType {
return src1[base];
}
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}
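The flat-index unravelling at the top of main() is the usual row-major decomposition; the same logic in C++ for reference:

#include <array>
#include <cstdint>

// Decompose a flat element index into (i0, i1, i2, i3) for a tensor of
// shape (ne0, ne1, ne2, *), matching the loop structure in the shader.
static std::array<uint32_t, 4> unflatten(uint32_t i, uint32_t ne0, uint32_t ne1, uint32_t ne2) {
    const uint32_t i3 = i / (ne2 * ne1 * ne0);
    i %= ne2 * ne1 * ne0;
    const uint32_t i2 = i / (ne1 * ne0);
    i %= ne1 * ne0;
    const uint32_t i1 = i / ne0;
    const uint32_t i0 = i % ne0;
    return { i0, i1, i2, i3 };
}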

View File

@@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 20u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block, each at 4 bits = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 22u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = load_src0_f16_at(block_byte_base);
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
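Q5_0 follows the same byte map at 22 bytes per block: the f16 scale at byte 0, the 32 high bits packed in a u32 at byte 2, and 16 bytes of low nibbles from byte 6, which is exactly where the new + 2u and + 6u offsets come from. A host-side C++ sketch of the decode (illustrative name; scale pre-converted to float):

    #include <cstdint>
    #include <cstring>

    // Q5_0 block = 22 bytes: f16 scale + 4-byte qh (one high bit per weight) + 16 bytes of low nibbles.
    void dequant_q5_0_block(const uint8_t * block22, float d, float * out /* 32 floats */) {
        uint32_t qh;
        std::memcpy(&qh, block22 + 2, 4);  // the shader's load_src0_u32_at(base + 2u)
        const uint8_t * qs = block22 + 6;  // low nibbles start at byte 6
        for (int k = 0; k < 16; ++k) {
            const uint8_t b = qs[k];
            const uint32_t hi0 = (qh >> k)        & 1u;  // high bit for element k
            const uint32_t hi1 = (qh >> (k + 16)) & 1u;  // high bit for element k + 16
            out[k]      = (static_cast<float>((b & 0x0F) | (hi0 << 4)) - 16.0f) * d;
            out[k + 16] = (static_cast<float>((b >> 4)   | (hi1 << 4)) - 16.0f) * d;
        }
    }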
@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block, 4 low bits each = 32 * 4 = 128 bits / 16 bits per f16 = 8 f16s per block
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 24u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// TILE_K is defined as 32u, so BLOCKS_K is always 1
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread; each thread handles 4 f16s * 4 weights per f16 = 16 weights
@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 34u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
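Q8_0 blocks are 34 bytes (f16 scale plus 32 signed bytes), so the only twist relative to the 4-bit formats is sign extension when a byte is pulled out of a packed u32. A C++ sketch of what get_byte_i32 computes (the name matches the shader; the host-side signature is an assumption):

    #include <cstdint>

    // Extract byte k of a packed u32 and sign-extend it, since Q8_0 stores
    // signed 8-bit weights. Dequantization of one weight is then
    //   out = float(get_byte_i32(q_packed, k)) * d;
    int32_t get_byte_i32(uint32_t packed, uint32_t k) {
        return static_cast<int8_t>((packed >> (8 * k)) & 0xFF);
    }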
@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 36u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 42u;
const BLOCK_SIZE_BYTES = 84u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx + 40u];
let dmin = src0[scale_idx + 41u];
let d = load_src0_f16_at(block_byte_base + 80u);
let dmin = load_src0_f16_at(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
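For Q2_K the new offsets spell out the 84-byte super-block layout: 16 scale bytes at offset 0, 64 bytes of 2-bit weights at 16, and the f16 d/dmin pair at 80/82, with each scale byte carrying a 4-bit scale in the low nibble and a 4-bit min in the high nibble. A compile-checkable C++ sketch of that byte map (struct name illustrative):

    #include <cstdint>

    // Byte map of one Q2_K super-block (256 weights, 84 bytes), as used above.
    struct q2_k_layout {
        static constexpr uint32_t scales = 0;   // 16 bytes: low nibble = scale, high nibble = min (per 16 weights)
        static constexpr uint32_t qs     = 16;  // 64 bytes: 2-bit weights, 4 per byte
        static constexpr uint32_t d      = 80;  // f16 super-block scale
        static constexpr uint32_t dmin   = 82;  // f16 super-block min
        static constexpr uint32_t bytes  = 84;
    };
    // Per-element decode, matching the shader: value = (d * (sc & 0xF)) * q2 - (dmin * (sc >> 4)).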
@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 55u;
const BLOCK_SIZE_BYTES = 110u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx + 54u];
let d = load_src0_f16_at(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
let scale_0 = src0[scale_idx + 48u + (2u*i)];
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
let hmask_0 = src0[scale_idx + (2u*i)];
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
let qs_0 = src0[scale_idx + 16u + (2u*i)];
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 72u;
const BLOCK_SIZE_BYTES = 144u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// Map k_in_block to loop structure:
@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
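Q4_K packs eight 6-bit (scale, min) pairs for its 32-weight groups into the 12 bytes at offset 4, which the shader reads as three u32s. The unpacking follows ggml's CPU-side get_scale_min_k4; a C++ sketch for reference:

    #include <cstdint>

    // Unpack the 6-bit scale/min for group j (0..7) from the 12 packed bytes.
    void unpack_scale_min_k4(const uint8_t * q /* 12 bytes */, int j, uint8_t & sc, uint8_t & mn) {
        if (j < 4) {
            sc = q[j]     & 63;
            mn = q[j + 4] & 63;
        } else {
            sc = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            mn = (q[j + 4] >>   4) | ((q[j]     >> 6) << 4);
        }
    }
    // element value = d * sc * q4 - dmin * mn, with q4 the 4-bit weight.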
@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 88u;
const BLOCK_SIZE_BYTES = 176u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// The original loop processes elements in groups of 64
@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
const BLOCK_SIZE_BYTES = 210u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13_word = ql13_flat / 4u;
let ql13 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql13_word],
src0[scale_idx + 2u * ql13_word + 1u]
));
let ql13_b = get_byte(ql13, ql13_flat % 4u);
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24_word = ql24_flat / 4u;
let ql24 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql24_word],
src0[scale_idx + 2u * ql24_word + 1u]
));
let ql24_b = get_byte(ql24, ql24_flat % 4u);
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh_word = qh_flat / 4u;
let qh = bitcast<u32>(vec2(
src0[scale_idx + 64u + 2u * qh_word],
src0[scale_idx + 64u + 2u * qh_word + 1u]
));
let qh_b = get_byte(qh, qh_flat % 4u);
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc_word = sc_idx / 4u;
let sc = bitcast<u32>(vec2(
src0[scale_idx + 96u + 2u * sc_word],
src0[scale_idx + 96u + 2u * sc_word + 1u]
));
let sc_val = get_byte_i32(sc, sc_idx % 4u);
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = src0[scale_idx + 104u];
let d = load_src0_f16_at(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View File

@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q4_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 20u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = f32(src0[scale_idx + 1u]);
let d = f32(load_src0_f16_at(block_byte_base));
let m = f32(load_src0_f16_at(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 22u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = f32(load_src0_f16_at(block_byte_base));
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 24u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 34u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 36u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
let aligned = byte_offset & ~3u;
let idx = bbase + aligned / 2u;
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
}
const BLOCK_SIZE_BYTES = 210u;
fn byte_of(v: u32, b: u32) -> u32 {
return (v >> (b * 8u)) & 0xFFu;
@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d_raw = load_u32_at(bbase, 208u);
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
let d = f32(load_src0_f16_at(bbase + 208u));
let ql1_u32 = load_u32_at(bbase, q_offset_l);
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
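The local load_u32_at helper above is dropped in favor of the shared load_src0_u32_at, consistent with the shmem shaders. A C++ sketch of how a byte-addressed load over a u32-backed buffer can handle the odd offsets that 18/34/210-byte blocks produce (the stitching logic here is an assumption about the shared helper, not a copy of it):

    #include <cstdint>
    #include <vector>

    // Load a little-endian u32 at an arbitrary byte offset from a u32-typed buffer.
    uint32_t load_u32_at(const std::vector<uint32_t> & buf, uint32_t byte_offset) {
        const uint32_t word = byte_offset / 4;        // containing 32-bit word
        const uint32_t rem  = (byte_offset % 4) * 8;  // bit shift for misaligned offsets
        if (rem == 0) {
            return buf[word];
        }
        // Stitch the value from two adjacent words when the offset is not
        // 4-byte aligned; quantized block sizes make such offsets unavoidable.
        return (buf[word] >> rem) | (buf[word + 1] << (32 - rem));
    }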

View File

@ -1,138 +1,12 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f32_ff",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_ff_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f16_ff",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_ff_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(ROTATE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = {{TYPE}}(out0);
dst[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE)
#decl(ROTATE_INPLACE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = {{TYPE}}(out0);
src0[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE_INPLACE)
#decl(NO_FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#enddecl(NO_FF_FUNC)
#decl(FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#enddecl(FF_FUNC)
#decl(NO_FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS)
#decl(NO_FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS_INPLACE)
#decl(FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(FF_BINDINGS)
#decl(FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(FF_BINDINGS_INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
@ -168,12 +42,69 @@ struct Params {
};
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;
DECLS
#ifdef INPLACE
#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#else
#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(4)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
#ifdef FF_FUNC
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#else
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#endif
#ifdef INPLACE
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = DataType(out0);
src0[i_dst1] = DataType(out1);
}
#else
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = DataType(out0);
dst[i_dst1] = DataType(out1);
}
#endif
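With the VARIANTS table gone, one shader body covers all eight former rope variants, and the JIT path selects behavior purely through the preprocessor flags above (TYPE_F32/TYPE_F16, INPLACE, FF_FUNC). A hedged C++ sketch of how a pipeline key could be turned into a define prelude (the composer function and the WG_SIZE value are illustrative, not the backend's actual code):

    #include <string>

    // Compose the shader source for one rope variant; the flag names match
    // the #ifdef blocks in the shader above.
    std::string make_rope_shader(const std::string & body,
                                 bool f16, bool inplace, bool has_ff) {
        std::string src = "enable f16;\n";
        src += f16 ? "#define TYPE_F16\n" : "#define TYPE_F32\n";
        if (inplace) src += "#define INPLACE\n";
        if (has_ff)  src += "#define FF_FUNC\n";
        src += "#define WG_SIZE 256\n";  // assumed workgroup size constant
        return src + body;               // run through the WGSL preprocessor before pipeline creation
    }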
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
let y = (f32(i / 2) - low) / max(0.001f, high - low);
@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
// TODO: check performance of instantiating once on the CPU and passing it in as a buffer, since it is recomputed per row
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
var mscale = params.attn_factor;
var theta = params.freq_scale * theta_extrap;
if (params.ext_factor != 0.0f) {
let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
@ -211,10 +142,9 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
}
}
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// two elements per thread
// each of the n_threads threads processes two elements
if (gid.x >= params.n_threads) {
return;
}
@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x0 = f32(src0[i_src]);
let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}
#end(SHADER)
}

View File

@ -1,215 +1,12 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "soft_max_f32",
"DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_inplace",
"DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink",
"DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink_inplace",
"DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(BASE_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)
#decl(BASE_BINDINGS_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)
#decl(SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)
#decl(SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)
#decl(MASK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)
#decl(MASK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)
#decl(MASK_SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)
#decl(MASK_SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)
#decl(NOT_INPLACE)
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#enddecl(INPLACE)
#decl(NO_MASK)
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#enddecl(NO_MASK)
#decl(MASK)
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#enddecl(MASK)
#decl(NO_SINK)
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#enddecl(NO_SINK)
#decl(SINK)
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#enddecl(SINK)
#end(DECLS)
#define(SHADER)
enable f16;
#ifdef MASK_F32
#define MaskType f32
#endif
#ifdef MASK_F16
#define MaskType f16
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
@ -249,14 +46,117 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
#ifdef HAS_MASK
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
#ifdef INPLACE
@group(0) @binding(3)
var<uniform> params: Params;
#else
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#endif
#else
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
#else
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#else
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#endif
#endif
#ifdef INPLACE
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#else
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#endif
#ifdef HAS_MASK
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#else
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#endif
#ifdef HAS_SINK
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#else
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#endif
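The nested #ifdefs enumerate the optional soft_max bindings in a fixed order: src at 0, then mask, then sinks, then dst (unless in-place), then params. The host side has to agree on those indices when creating the bind group; a small C++ sketch of the same arithmetic (illustrative, not the backend code):

    #include <cstdint>

    // Compute the binding index of the params uniform for a soft_max variant.
    uint32_t soft_max_params_binding(bool has_mask, bool has_sink, bool inplace) {
        uint32_t binding = 1;     // binding 0 is always src
        if (has_mask) binding++;  // optional mask buffer
        if (has_sink) binding++;  // optional sinks buffer
        if (!inplace) binding++;  // separate dst buffer unless in-place
        return binding;           // params uniform comes last
    }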
const CACHE_SIZE: u32 = 16;
var<workgroup> scratch: array<f32, WG_SIZE>;
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
let head = f32(i2);
let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
if (col < CACHE_SIZE) {
cache[col] = val;
}
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = max_val;
workgroupBarrier();
var offset = wg_size / 2;
var offset: u32 = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
} else {
update(i_dst_row + col, ex);
}
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
offset = wg_size / 2;
offset = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
break;
}
update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
col += wg_size;
col += WG_SIZE;
}
}
#end(SHADER)

View File

@ -380,22 +380,33 @@ extern "C" {
size_t n_samplers;
};
struct llama_model_tensor_override {
const char * pattern;
enum ggml_type type;
};
struct llama_model_imatrix_data {
const char * name;
const float * data;
size_t size;
};
// model quantization parameters
typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
bool dry_run; // calculate and show the final quantization size without performing quantization
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
bool dry_run; // calculate and show the final quantization size without performing quantization
const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data
const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides
const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides
const int32_t * prune_layers; // pointer to layer indices to prune
} llama_model_quantize_params;
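Every pointer field in the refactored struct is now a plain C array with a sentinel terminator: imatrix entries end at a NULL name, tensor overrides at a NULL pattern, kv_overrides at an empty key, and prune_layers at -1. A sketch of a caller under those conventions (the regex pattern and layer indices are made up for illustration):

    #include "llama.h"

    static void quantize_with_overrides(const char * fin, const char * fout) {
        llama_model_quantize_params p = llama_model_quantize_default_params();
        // tensor-type overrides, terminated by a NULL pattern
        static const struct llama_model_tensor_override tt[] = {
            { "blk\\..*\\.attn_v\\.weight", GGML_TYPE_Q8_0 },
            { NULL, GGML_TYPE_COUNT },
        };
        // layer indices to prune, terminated by -1
        static const int32_t prune[] = { 30, 31, -1 };
        p.ftype        = LLAMA_FTYPE_MOSTLY_Q4_K_M;
        p.tt_overrides = tt;
        p.prune_layers = prune;
        llama_model_quantize(fin, fout, &p);
    }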
typedef struct llama_logit_bias {

View File

@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::map<i
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
@ -188,10 +187,9 @@ struct quantize_state_impl {
model(model), params(params)
{
// compile regex patterns once - they are expensive
if (params->tensor_types) {
const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
for (const auto & [tname, qtype] : tensor_types) {
tensor_type_patterns.emplace_back(std::regex(tname), qtype);
if (params->tt_overrides) {
for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) {
tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type);
}
}
}
@ -857,12 +855,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
constexpr bool use_mmap = false;
#endif
llama_model_kv_override * kv_overrides = nullptr;
if (params->kv_overrides) {
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}
const llama_model_kv_override * kv_overrides = params->kv_overrides;
std::vector<std::string> splits = {};
llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr,
fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
@ -879,9 +872,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (params->only_copy) {
ftype = ml.ftype;
}
std::unordered_map<std::string, std::vector<float>> i_data;
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
if (params->imatrix) {
imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) {
i_data.emplace(p->name, std::vector<float>(p->data, p->data + p->size));
}
imatrix_data = &i_data;
if (imatrix_data) {
LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n",
__func__, (int)imatrix_data->size());
@ -902,7 +899,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
for (const int32_t * p = params->prune_layers; *p != -1; p++) {
prune_list.push_back(*p);
}
}
// copy the KV pairs from the input file
@ -916,20 +915,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
if (params->kv_overrides) {
const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
for (const auto & o : overrides) {
if (o.key[0] == 0) break;
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) {
if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64);
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
// Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64));
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64));
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool);
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out.get(), o->key, o->val_str);
} else {
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key);
}
}
}

View File

@ -13,13 +13,10 @@
#include <unordered_map>
#include <map>
#include <fstream>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <filesystem>
// result of parsing --tensor-type option
// (changes to this struct must be reflected in src/llama-quant.cpp)
// changes to this struct must also be reflected in src/llama-quant.cpp
struct tensor_type_option {
std::string name;
ggml_type type = GGML_TYPE_COUNT;
@ -491,7 +488,6 @@ static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
if (argc < 3) {
usage(argv[0]);
}
@ -584,8 +580,16 @@ int main(int argc, char ** argv) {
std::vector<std::string> imatrix_datasets;
std::unordered_map<std::string, std::vector<float>> imatrix_data;
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
std::vector<llama_model_imatrix_data> i_data;
std::vector<llama_model_tensor_override> t_override;
if (!imatrix_data.empty()) {
params.imatrix = &imatrix_data;
i_data.reserve(imatrix_data.size() + 1);
for (const auto & kv : imatrix_data) {
i_data.push_back({kv.first.c_str(), kv.second.data(), kv.second.size()});
}
i_data.push_back({nullptr, nullptr, 0}); // array terminator
params.imatrix = i_data.data();
{
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
@ -603,7 +607,6 @@ int main(int argc, char ** argv) {
kvo.val_str[127] = '\0';
kv_overrides.emplace_back(std::move(kvo));
}
{
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
@ -611,7 +614,6 @@ int main(int argc, char ** argv) {
kvo.val_i64 = imatrix_data.size();
kv_overrides.emplace_back(std::move(kvo));
}
if (m_last_call > 0) {
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
@ -623,13 +625,19 @@ int main(int argc, char ** argv) {
if (!kv_overrides.empty()) {
kv_overrides.emplace_back();
kv_overrides.back().key[0] = 0;
params.kv_overrides = &kv_overrides;
params.kv_overrides = kv_overrides.data();
}
if (!tensor_type_opts.empty()) {
params.tensor_types = &tensor_type_opts;
t_override.reserve(tensor_type_opts.size() + 1);
for (const auto & tt : tensor_type_opts) {
t_override.push_back({tt.name.c_str(), tt.type});
}
t_override.push_back({nullptr, GGML_TYPE_COUNT}); // array terminator
params.tt_overrides = t_override.data();
}
if (!prune_layers.empty()) {
params.prune_layers = &prune_layers;
prune_layers.push_back(-1); // array terminator
params.prune_layers = prune_layers.data();
}
llama_backend_init();

View File

@ -155,8 +155,8 @@ struct server_slot {
int64_t t_start_process_prompt;
int64_t t_start_generation;
double t_prompt_processing; // ms
double t_token_generation; // ms
double t_prompt_processing = 0.0; // ms
double t_token_generation = 0.0; // ms
std::function<void(int /* id_slot */)> callback_on_release;
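The in-class initializers matter because a slot's timings can be serialized before the slot has processed anything; without them, the first read is of an indeterminate double. A minimal C++ illustration (struct names hypothetical):

    #include <cstdio>

    struct timings_bad  { double prompt_ms; };        // indeterminate until first assignment
    struct timings_good { double prompt_ms = 0.0; };  // value starts well-defined

    int main() {
        timings_bad b;   // reading b.prompt_ms now would be undefined behavior
        (void) b;
        timings_good g;  // g.prompt_ms is a guaranteed 0.0
        std::printf("%f\n", g.prompt_ms);
    }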

View File

@ -261,14 +261,14 @@ struct result_timings {
int32_t cache_n = -1;
int32_t prompt_n = -1;
double prompt_ms;
double prompt_per_token_ms;
double prompt_per_second;
double prompt_ms = 0.0;
double prompt_per_token_ms = 0.0;
double prompt_per_second = 0.0;
int32_t predicted_n = -1;
double predicted_ms;
double predicted_per_token_ms;
double predicted_per_second;
double predicted_ms = 0.0;
double predicted_per_token_ms = 0.0;
double predicted_per_second = 0.0;
// Optional speculative metrics - only included when > 0
int32_t draft_n = 0;