Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for get_rows and mul_mat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
neha-ha 2026-02-10 19:27:33 -08:00 committed by GitHub
parent 57487a64c8
commit b3927f807d
11 changed files with 993 additions and 1333 deletions
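
For reference, each of the three ops wired up here (mul_mat, get_rows, scale) follows the same lookup-or-compile pattern against a per-op pipeline cache. A minimal sketch of that pattern, using get_rows as the example (names match the structures added in the diff below; this is illustrative, not verbatim code):

    ggml_webgpu_get_rows_pipeline_key key = { .src_type = src->type, .vectorized = vectorized };
    auto it = ctx->get_rows_pipelines.find(key);
    if (it == ctx->get_rows_pipelines.end()) {
        // cache miss: preprocess the WGSL template with per-variant defines,
        // build the pipeline, and memoize it under the key
        ggml_webgpu_processed_shader processed =
            ggml_webgpu_preprocess_get_rows_shader(ctx->p, wgsl_get_rows, shader_lib_ctx);
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(
            ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
        pipeline.context = processed.decisions;
        it = ctx->get_rows_pipelines.emplace(key, pipeline).first;
    }
    webgpu_pipeline pipeline = it->second;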

View File

@ -4,6 +4,7 @@
#include "ggml.h"
#include "pre_wgsl.hpp"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
@ -18,6 +19,46 @@
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
// Matrix multiplication parameters
// Register tiling parameters
#define WEBGPU_MUL_MAT_TILE_M 8
#define WEBGPU_MUL_MAT_TILE_N 8
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
#define WEBGPU_MUL_MAT_TILE_K 32
// Subgroup matrix parameters
// The number of subgroups in the M dimension
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
// The number of subgroups in the N dimension
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
// The number of subgroup matrices each subgroup accumulates over
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
#define WEBGPU_MAX_WG_SIZE 288
#define WEBGPU_MUL_MAT_WG_SIZE 256
// helper function for replacing {{PLACEHOLDERS}}
inline void ggml_webgpu_replace_placeholder(std::string & shader_code,
const std::string & key,
const std::string & value) {
std::string pattern = "{{" + key + "}}";
size_t pos = 0;
while ((pos = shader_code.find(pattern, pos)) != std::string::npos) {
shader_code.replace(pos, pattern.length(), value);
pos += value.length();
}
}
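// Illustrative use of the helper above (the snippet and values are just an example):
//
//   std::string wgsl = "const WG_SIZE: u32 = {{WG_SIZE}}u;";
//   ggml_webgpu_replace_placeholder(wgsl, "WG_SIZE", "256");
//   // wgsl is now: "const WG_SIZE: u32 = 256u;"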
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
@ -178,7 +219,8 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
// Avoids having to use bounds-checks and decreasing performance for direct KV loads
// Avoids having to use bounds-checks and decreasing performance for direct
// KV loads
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
@ -466,6 +508,22 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
return result;
}
/** Scale **/
struct ggml_webgpu_scale_pipeline_key {
int inplace;
bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
};
struct ggml_webgpu_scale_pipeline_key_hash {
size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
@ -490,6 +548,34 @@ struct ggml_webgpu_binary_pipeline_key_hash {
}
};
struct ggml_webgpu_scale_shader_lib_context {
ggml_webgpu_scale_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_scale_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_scale_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "scale";
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
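// Example: for an in-place scale on a device reporting a 256-invocation workgroup limit
// (the 256 is illustrative), the function above produces
//   defines = { "INPLACE", "WG_SIZE=256" }, variant = "scale_inplace",
// and records decisions->wg_size = 256 for the later dispatch-size computation.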
struct ggml_webgpu_binary_shader_lib_context {
ggml_webgpu_binary_pipeline_key key;
uint32_t max_wg_size;
@ -535,4 +621,366 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
result.decisions = decisions;
return result;
}
/** get_rows */
struct ggml_webgpu_get_rows_pipeline_key {
ggml_type src_type;
int vectorized;
bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
return src_type == other.src_type && vectorized == other.vectorized;
}
};
struct ggml_webgpu_get_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
struct ggml_webgpu_get_rows_shader_lib_context {
ggml_webgpu_get_rows_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_get_rows_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_get_rows_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "get_rows";
    // Look up the source type name; it is used for the quantized-type defines and the variant suffix
    const struct ggml_type_traits * type_traits = ggml_get_type_traits(context.key.src_type);
    const char * type_str = type_traits->type_name;
switch (context.key.src_type) {
case GGML_TYPE_F32:
if (context.key.vectorized) {
defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>");
defines.push_back("DST_TYPE=vec4<f32>");
defines.push_back("BLOCK_SIZE=4u");
} else {
defines.push_back("F32");
defines.push_back("SRC_TYPE=f32");
defines.push_back("DST_TYPE=f32");
defines.push_back("BLOCK_SIZE=1u");
}
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("F16");
defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32");
defines.push_back("BLOCK_SIZE=1u");
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("I32");
defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32");
defines.push_back("BLOCK_SIZE=1u");
variant += "_i32";
break;
default:
// convert name to upper case for other defines
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
// push back defines for quantized types
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
// for q4_k and q5_k
defines.push_back(type_upper + "_SCALE_MIN");
// defines for i-quants
defines.push_back(type_upper + "_TABLES");
defines.push_back(type_upper + "_GRID");
// add variant
variant += "_";
variant += type_str;
// add define for quantized src0 type
defines.push_back(std::string("SRC_TYPE=") + type_str);
defines.push_back("DST_TYPE=f32");
break;
}
    // Determine BLOCK_SIZE for the remaining types; I32 is re-checked here so that it does not
    // fall into the >= GGML_TYPE_Q2_K catch-all below
if (context.key.src_type == GGML_TYPE_I32) {
defines.push_back("BLOCK_SIZE=1u");
} else if ((context.key.src_type >= GGML_TYPE_Q4_0 && context.key.src_type <= GGML_TYPE_Q8_1) ||
context.key.src_type == GGML_TYPE_IQ4_NL) {
// Non-K quants use 32
defines.push_back("BLOCK_SIZE=32u");
} else if (context.key.src_type >= GGML_TYPE_Q2_K) {
// K-quants and IQ variants all use 256
defines.push_back("BLOCK_SIZE=256u");
}
// Vectorized suffix
if (context.key.vectorized) {
variant += "_vec";
}
defines.push_back("WORKGROUP_SIZE=" + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
// Create decisions structure to store workgroup size
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
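// Worked example: for src_type == GGML_TYPE_Q4_0, non-vectorized, the logic above yields
//   BYTE_HELPERS, Q4_0_T, Q4_0, Q4_0_SCALE_MIN, Q4_0_TABLES, Q4_0_GRID,
//   SRC_TYPE=q4_0, DST_TYPE=f32, BLOCK_SIZE=32u, WORKGROUP_SIZE=<max_wg_size>
// with variant "get_rows_q4_0". Defines that a given template never references
// (e.g. the *_TABLES/*_GRID ones for non-IQ types) are assumed to simply have no effect.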
/** Matrix Multiplication **/
struct ggml_webgpu_mul_mat_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
int is_vec;
int use_subgroup_matrix;
int register_tile;
bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
is_vec == other.is_vec && use_subgroup_matrix == other.use_subgroup_matrix &&
register_tile == other.register_tile;
}
};
struct ggml_webgpu_mul_mat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.is_vec);
ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
ggml_webgpu_hash_combine(seed, key.register_tile);
return seed;
}
};
struct ggml_webgpu_mul_mat_shader_lib_context {
ggml_webgpu_mul_mat_pipeline_key key;
// For subgroup matrix paths
uint32_t max_subgroup_size;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
};
struct ggml_webgpu_mul_mat_shader_decisions {
uint32_t tile_k;
uint32_t wg_size_m;
uint32_t wg_size_n;
uint32_t wg_size;
uint32_t outputs_per_wg;
int is_vec;
int use_subgroup_matrix;
    // Register tiling parameters
uint32_t tile_m;
uint32_t tile_n;
// Subgroup matrix parameters
uint32_t subgroup_m;
uint32_t subgroup_n;
uint32_t subgroup_matrix_m;
uint32_t subgroup_matrix_n;
uint32_t mul_mat_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_mul_mat_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "mul_mat";
// Determine base variant name based on kernel type
if (context.key.is_vec) {
variant = "mul_mat_vec";
} else if (context.key.use_subgroup_matrix) {
variant = "mul_mat_subgroup_matrix";
} else if (context.key.register_tile) {
variant = "mul_mat_reg_tile";
}
// Determine src0/src1 type strings
const char * src0_type_str = nullptr;
bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile;
// src1 type
switch (context.key.src1_type) {
case GGML_TYPE_F32:
defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
break;
case GGML_TYPE_F16:
defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
break;
default:
break;
}
// same for all types
defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
// src0 type
const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type);
src0_type_str = src0_type_traits->type_name;
// for f32 and f16, account for vectorized src0 types
switch (context.key.src0_type) {
case GGML_TYPE_F32:
src0_type_str = context.key.vectorized ? "vec4<f32>" : "f32";
defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
defines.push_back("FLOAT");
defines.push_back("MUL_ACC_FLOAT");
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
src0_type_str = context.key.vectorized ? "vec4<f16>" : "f16";
defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
defines.push_back("FLOAT");
defines.push_back("MUL_ACC_FLOAT");
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
variant += "_f16";
break;
default:
// convert name to upper case for other defines
std::string type_upper = src0_type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
// push back defines for quantized types
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
// for q4_k and q5_k
defines.push_back(type_upper + "_SCALE_MIN");
// defines for i-quants
defines.push_back(type_upper + "_TABLES");
defines.push_back(type_upper + "_GRID");
// add variant
variant += "_";
variant += src0_type_str;
        // add define for the non-fast-path quantized src0 type; overridden below if a fast path is used
defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
break;
}
// Add VEC/SCALAR defines
if (is_fast_path) {
// if quantized type and using fast path, need to use f16 instead of the
// quantized type
if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) {
src0_type_str = "f16";
defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
}
// all fast paths need VEC vs SCALAR
defines.push_back(context.key.vectorized ? "VEC" : "SCALAR");
// add vec_size define too
defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
// reg_tile and subgroup_matrix need these extra defines
if (!context.key.is_vec) {
defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
}
}
// Append src1 type
variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? "f32" : "f16");
// printf("DEBUG: After appending src1 type: variant='%s'\n",
// variant.c_str());
// Vectorized suffix
if (context.key.vectorized) {
variant += "_vec";
}
// Add defines for TILE_M and TILE_N
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
// Add subgroup matrix defines if using subgroup_matrix
if (context.key.use_subgroup_matrix) {
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
}
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K;
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
decisions->is_vec = context.key.is_vec;
decisions->use_subgroup_matrix = context.key.use_subgroup_matrix;
result.decisions = decisions;
return result;
}
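// Worked example: src0 = F16, src1 = F32, vectorized, register-tile path. The logic above
// produces variant "mul_mat_reg_tile_f16_f32_vec" with (roughly) the defines
//   SRC1_TYPE=vec4<f32>, DST_TYPE=vec4<f32>, SHMEM_TYPE=vec4<f16>, SRC0_TYPE=vec4<f16>,
//   FLOAT, MUL_ACC_FLOAT, INIT_SRC0_SHMEM_FLOAT, INIT_SRC1_SHMEM_FLOAT,
//   VEC, VEC_SIZE=4u, SHMEM_VEC, TILE_M=8u, TILE_N=8u
// and fills the decisions struct with the tile/workgroup constants used for dispatch.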
#endif // GGML_WEBGPU_SHADER_LIB_HPP

View File

@ -69,21 +69,24 @@
/* Constants */
// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
// implementations so this can be removed.
#define WEBGPU_MAX_WG_SIZE 288
#define WEBGPU_MUL_MAT_WG_SIZE 256
#define WEBGPU_NUM_PARAM_BUFS 16u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable default
// For operations which process a row in parallel, this seems like a reasonable
// default
#define WEBGPU_ROW_SPLIT_WG_SIZE 64
// Matrix multiplication parameters
@ -106,13 +109,15 @@
// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
// their locations.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
// Always returns the base offset of a tensor, regardless of views.
@ -358,9 +363,12 @@ struct webgpu_context_struct {
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
@ -372,10 +380,13 @@ struct webgpu_context_struct {
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
set_rows_pipelines;
std::unordered_map<ggml_webgpu_get_rows_pipeline_key,
webgpu_pipeline,
ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines;
@ -383,7 +394,11 @@ struct webgpu_context_struct {
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
std::map<int, webgpu_pipeline> scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key,
webgpu_pipeline,
ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines;
@ -495,8 +510,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission_futures> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
// inflight_max may be 0, meaning that we must wait on all futures.
// If we have too many in-flight submissions, wait on the oldest one first. If
// there are many threads, inflight_max may be 0, meaning that we must wait on
// all futures.
uint64_t timeout_ms = block ? UINT64_MAX : 0;
uint32_t inflight_threads = ctx->inflight_threads;
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
@ -681,7 +697,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
}
// If there are SET_ROWS operations in this submission, copy their error buffers to the host.
// If there are SET_ROWS operations in this submission, copy their error
// buffers to the host.
if (set_rows_error_bufs) {
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
set_rows_error_bufs->host_buf.GetSize());
@ -971,7 +988,8 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * dst) {
// For set rows specifically, we need to check if src and idx are empty tensors.
// For set rows specifically, we need to check if src and idx are empty
// tensors.
if (ggml_is_empty(src) || ggml_is_empty(idx)) {
return std::nullopt;
}
@ -1054,25 +1072,68 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
error_bufs);
}
// Workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = wg_size;
return constants;
}
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * dst) {
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of dst
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
};
uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
// Create pipeline key
ggml_webgpu_get_rows_pipeline_key key = { .src_type = src->type, .vectorized = (int) vectorized };
// Get or create pipeline
webgpu_pipeline pipeline;
auto it = ctx->get_rows_pipelines.find(key);
if (it != ctx->get_rows_pipelines.end()) {
pipeline = it->second;
} else {
// JIT compile the shader
const char * shader_src = wgsl_get_rows;
// Build shader context
ggml_webgpu_get_rows_shader_lib_context shader_lib_ctx = {
.key = key,
.max_wg_size = WEBGPU_MAX_WG_SIZE,
};
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_get_rows_shader(ctx->p, shader_src, shader_lib_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(shader_lib_ctx.max_wg_size);
pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
processed.variant.c_str(), constants);
pipeline.context = processed.decisions;
ctx->get_rows_pipelines.emplace(key, pipeline);
}
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2]) };
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@ -1089,10 +1150,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized];
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
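// Dispatch math for get_rows, for concreteness: the shader is launched over the dst rows
// (ne[1] * ne[2] * ne[3]), so gathering e.g. 4096 rows with wg_size = 256 gives
// wg_x = CEIL_DIV(4096, 256) = 16 workgroups (the row count is illustrative).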
@ -1100,45 +1159,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) dst->ne[0], // number of rows in result (M, transposed)
(uint32_t) dst->ne[1], // number of columns in result (N)
(uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
(uint32_t) src0->ne[2], // batch size in dimension 2
(uint32_t) src0->ne[3], // batch size in dimension 3
(uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
(uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
};
webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
uint32_t wg_y = 1;
// Determine if this is a mat-vec operation
bool is_vec = (dst->ne[1] == 1);
// Determine if we should use fast path
bool use_fast = false;
switch (src1->type) {
case GGML_TYPE_F16:
@ -1159,43 +1183,158 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
break;
}
int vectorized = 0;
if (use_fast) {
int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
if (dst->ne[1] == 1) {
vectorized = (src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0);
if (is_vec) {
// We don't support vectorized mul_mat_vec for quantized types
vectorized = vectorized && (src0->type < 2);
pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
uint32_t total_wg = output_groups * batches;
wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
} else {
pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
uint32_t wg_m;
uint32_t wg_n;
#ifndef __EMSCRIPTEN__
if (ctx->global_ctx->capabilities.supports_subgroup_matrix) {
// The total number of subgroups/workgroups needed per matrix.
uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M *
ctx->global_ctx->capabilities.sg_mat_m;
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N *
ctx->global_ctx->capabilities.sg_mat_n;
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
} else {
#endif
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
#ifndef __EMSCRIPTEN__
}
#endif
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
vectorized = vectorized && (src0->type < 2);
}
}
// Create pipeline key
bool supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = src0->type,
.src1_type = src1->type,
.vectorized = use_fast ? vectorized : 0,
.is_vec = (use_fast && is_vec) ? 1 : 0,
.use_subgroup_matrix =
(use_fast && !is_vec && supports_subgroup_matrix) ? 1 : 0,
.register_tile =
(use_fast && !is_vec && !supports_subgroup_matrix) ? 1 : 0 };
// Build shader context
ggml_webgpu_mul_mat_shader_lib_context shader_lib_ctx = { .key = key,
.max_subgroup_size =
ctx->global_ctx->capabilities.max_subgroup_size,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k };
// Get or create pipeline
webgpu_pipeline pipeline;
const char * shader_src = nullptr;
auto it = ctx->mul_mat_pipelines.find(key);
if (it != ctx->mul_mat_pipelines.end()) {
pipeline = it->second;
} else {
// Select appropriate shader source based on key
if (!use_fast) {
// Use precompiled quantized shaders (mul_mat.tmpl.wgsl)
// These are the fallback for quantized types not supported by fast
// paths
shader_src = wgsl_mul_mat;
} else {
// Use JIT-compiled shader
if (is_vec) {
shader_src = wgsl_mul_mat_vec;
} else if (key.use_subgroup_matrix) {
shader_src = wgsl_mul_mat_subgroup_matrix;
} else {
shader_src = wgsl_mul_mat_reg_tile;
}
}
if (shader_src) {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_mul_mat_shader(ctx->p, shader_src, shader_lib_ctx);
std::vector<wgpu::ConstantEntry> constants;
if (shader_lib_ctx.key.is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast<double>(decisions->wg_size) });
constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
constants.push_back({ nullptr, "OUTPUTS_PER_WG", static_cast<double>(decisions->outputs_per_wg) });
} else if (shader_lib_ctx.key.register_tile) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
constants.push_back({ nullptr, "WORKGROUP_SIZE_M", static_cast<double>(decisions->wg_size_m) });
constants.push_back({ nullptr, "WORKGROUP_SIZE_N", static_cast<double>(decisions->wg_size_n) });
constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
}
// printf("DEBUG: Creating pipeline with variant='%s', "
// "constants.size()=%zu\n",
// processed.variant.c_str(), constants.size());
pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
processed.variant.c_str(), constants);
pipeline.context = processed.decisions;
ctx->mul_mat_pipelines.emplace(key, pipeline);
}
}
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Build params
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) src0->ne[0],
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src0->ne[2],
(uint32_t) src0->ne[3],
(uint32_t) (src1->ne[2] / src0->ne[2]),
(uint32_t) (src1->ne[3] / src0->ne[3])
};
// Build bind group entries
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
};
// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
if (decisions->is_vec) {
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
} else if (use_fast) {
// Fast-path tiled/subgroup calculations
uint32_t wg_m, wg_n;
if (decisions->use_subgroup_matrix) {
uint32_t wg_m_sg_tile =
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
uint32_t wg_n_sg_tile =
decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
} else {
uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
}
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
} else {
        // Non-fast-path quantized shaders (Q2_K, Q4_K, etc.): flat dispatch over all output
        // elements, using the workgroup size recorded in the shader decisions
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->mul_mat_wg_size);
wg_y = 1;
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
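// Dispatch math for the fast mul_mat paths, with the default parameters above (numbers illustrative):
//  - register tile: each workgroup covers TILE_M*WG_SIZE_M x TILE_N*WG_SIZE_N = 64 x 64 outputs,
//    so a 512 x 1024 dst with a single batch gives wg_m = 8, wg_n = 16, wg_x = 128.
//  - mat-vec: with OUTPUTS_PER_WG = 64 and dst->ne[0] = 4096 outputs in one batch,
//    CEIL_DIV(4096, 64) = 64 workgroups, split across x/y only when the total exceeds
//    the per-dimension workgroup limit.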
@ -1671,6 +1810,29 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
ggml_webgpu_scale_pipeline_key key = { .inplace = inplace };
ggml_webgpu_scale_shader_lib_context shader_lib_ctx = {
.key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
// TODO: remove guard once pipeline caches are per-thread
auto it = ctx->scale_pipelines.find(key);
if (it != ctx->scale_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_scale_shader(ctx->p, wgsl_scale, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->scale_pipelines.emplace(key, pipeline);
}
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@ -1688,12 +1850,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
*(uint32_t *) &dst->op_params[1] // bias
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src),
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) }
};
if (!inplace) {
entries.push_back({ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
@ -1702,8 +1866,7 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
}
    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params,
entries, wg_x);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
@ -2233,7 +2396,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
size_t offset,
size_t size) {
if (size == 0) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
WEBGPU_LOG_DEBUG(
"ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
"nothing to do.");
return;
}
@ -2310,7 +2475,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
size_t final_size = size;
if (size % 4 != 0) {
// If size is not a multiple of 4, we need to round it up to the next multiple of 4
// If size is not a multiple of 4, we need to round it up to the next
// multiple of 4
final_size = size + (4 - (size % 4));
}
@ -2364,7 +2530,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
/* .clear = */ ggml_backend_webgpu_buffer_clear,
/* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
/* .reset = */ NULL, // TODO: optional, think it coordinates with
// .init_tensor
};
/* End GGML Backend Buffer Interface */
@ -2401,7 +2568,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_
return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
}
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
// maxBufferSize might be larger, but you can't bind more than
// maxStorageBufferBindingSize to a single binding.
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
ggml_backend_webgpu_device_context * dev_ctx =
static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
@ -2487,14 +2655,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
// Workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = wg_size;
return constants;
}
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
@ -2509,207 +2669,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
}
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
// Q4/Q5/Q8 classic quantizations
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
// K-quantizations
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
// IQ quantizations (2-, 3-, 4-bit variants)
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
// 1-bit and 4-bit IQ variants
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
std::string proc_mul_mat_f32_f32;
std::string proc_mul_mat_f32_f32_vec;
std::string proc_mul_mat_f16_f32;
std::string proc_mul_mat_f16_f32_vec;
std::string proc_mul_mat_f16_f16;
std::string proc_mul_mat_f16_f16_vec;
std::string proc_mul_mat_q4_0_f32;
std::string proc_mul_mat_q4_0_f32_vec;
std::vector<wgpu::ConstantEntry> mul_mat_constants;
#ifndef __EMSCRIPTEN__
if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
std::map<std::string, std::string> sg_matrix_repls;
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
proc_mul_mat_f32_f32_vec =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
proc_mul_mat_f16_f32_vec =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
proc_mul_mat_f16_f16_vec =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
proc_mul_mat_q4_0_f32 =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
proc_mul_mat_q4_0_f32_vec =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
} else {
#endif
mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
std::map<std::string, std::string> reg_repls;
reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
#ifndef __EMSCRIPTEN__
}
#endif
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
mul_mat_vec_constants[1].key = "TILE_K";
mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
@ -2816,15 +2775,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->scale_pipelines[0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
"scale_f32_inplace", constants);
}
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
@ -3023,13 +2973,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_scale_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
@ -3071,11 +3018,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
/* .iface = */ {
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
/* .is_host = */ NULL, // defaults to false
/* .alloc_buffer = */
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,

View File

@ -1,5 +1,4 @@
#decl(BYTE_HELPERS)
#ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}
@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#endif
#enddecl(BYTE_HELPERS)
#decl(Q4_0_T)
#ifdef Q4_0_T
struct q4_0 {
d: f16,
qs: array<f16, 8>
};
#enddecl(Q4_0_T)
#endif
#decl(Q4_1_T)
#ifdef Q4_1_T
struct q4_1 {
d: f16,
m: f16,
qs: array<u32, 4>
};
#enddecl(Q4_1_T)
#endif
#decl(Q5_0_T)
#ifdef Q5_0_T
struct q5_0 {
d: f16,
qh: array<f16, 2>,
qs: array<f16, 8>
};
#enddecl(Q5_0_T)
#endif
#decl(Q5_1_T)
#ifdef Q5_1_T
struct q5_1 {
d: f16,
m: f16,
qh: u32,
qs: array<u32, 4>
};
#enddecl(Q5_1_T)
#endif
#decl(Q8_0_T)
#ifdef Q8_0_T
struct q8_0 {
d: f16,
qs: array<f16, 16>
};
#enddecl(Q8_0_T)
#endif
#decl(Q8_1_T)
#ifdef Q8_1_T
struct q8_1 {
d: f16,
m: f16,
qs: array<u32, 8>
};
#enddecl(Q8_1_T)
#endif
#decl(Q2_K_T)
struct q2_k {
#ifdef Q2_K_T
struct q2_K {
scales: array<u32, 4>,
qs: array<u32, 16>,
d: f16,
dmin: f16
};
#enddecl(Q2_K_T)
#endif
#decl(Q3_K_T)
struct q3_k {
#ifdef Q3_K_T
struct q3_K {
hmask: array<f16, 16>,
qs: array<f16, 32>,
scales: array<f16, 6>,
d: f16
};
#enddecl(Q3_K_T)
#decl(Q45_K_SCALE_MIN)
#endif
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
if (is < 4) {
let sc_byte = get_byte(scales[is / 4], is % 4);
@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
return vec2(f32(sc), f32(m));
}
}
#enddecl(Q45_K_SCALE_MIN)
#decl(Q4_K_T)
struct q4_k {
#endif
#ifdef Q4_K_T
struct q4_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qs: array<u32, 32>
};
#enddecl(Q4_K_T)
#endif
#decl(Q5_K_T)
struct q5_k {
#ifdef Q5_K_T
struct q5_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qh: array<u32, 8>,
qs: array<u32, 32>
};
#enddecl(Q5_K_T)
#endif
#decl(Q6_K_T)
struct q6_k {
#ifdef Q6_K_T
struct q6_K {
ql: array<f16, 64>,
qh: array<f16, 32>,
scales: array<f16, 8>,
d: f16
};
#enddecl(Q6_K_T)
#endif
#decl(IQ2_XXS_T)
#ifdef IQ2_XXS_T
struct iq2_xxs {
d: f16,
qs: array<f16, 32>
};
#enddecl(IQ2_XXS_T)
#endif
#decl(IQ2_XS_T)
#ifdef IQ2_XS_T
struct iq2_xs {
d: f16,
qs: array<f16, 32>,
scales: array<f16, 4>
};
#enddecl(IQ2_XS_T)
#endif
#decl(IQ2_S_T)
#ifdef IQ2_S_T
struct iq2_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
scales: array<f16, 4>
};
#enddecl(IQ2_S_T)
#endif
#decl(IQ3_XSS_T)
#ifdef IQ3_XXS_T
struct iq3_xxs {
d: f16,
qs: array<f16, 48>
};
#enddecl(IQ3_XSS_T)
#endif
#decl(IQ3_S_T)
#ifdef IQ3_S_T
struct iq3_s {
d: f16,
qs: array<f16, 32>,
@ -161,41 +156,41 @@ struct iq3_s {
signs: array<f16, 16>,
scales: array<f16, 2>
};
#enddecl(IQ3_S_T)
#endif
#decl(IQ1_S_T)
#ifdef IQ1_S_T
struct iq1_s {
d: f16,
qs: array<f16, 16>,
qh: array<f16, 8>
};
#enddecl(IQ1_S_T)
#endif
#decl(IQ1_M_T)
#ifdef IQ1_M_T
struct iq1_m {
qs: array<u32, 8>,
qh: array<u32, 4>,
scales: array<u32, 2>
};
#enddecl(IQ1_M_T)
#endif
#decl(IQ4_NL_T)
#ifdef IQ4_NL_T
struct iq4_nl {
d: f16,
qs: array<f16, 8>,
};
#enddecl(IQ4_NL_T)
#endif
#decl(IQ4_XS_T)
#ifdef IQ4_XS_T
struct iq4_xs {
d: f16,
scales_h: f16,
scales_l: u32,
qs: array<u32, 32>
};
#enddecl(IQ4_XS_T)
#endif
#decl(IQ23_TABLES)
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
0x08040201u, // 1, 2, 4, 8
0x80402010u // 16, 32, 64, 128
@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
);
#enddecl(IQ23_TABLES)
#endif
#decl(IQ2_XXS_GRID)
#ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
);
#enddecl(IQ2_XXS_GRID)
#endif
#decl(IQ2_XS_GRID)
#ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_XS_GRID)
#endif
#decl(IQ2_S_GRID)
#ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_S_GRID)
#decl(IQ3_XSS_GRID)
#endif
#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>(
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
);
#enddecl(IQ3_XSS_GRID)
#decl(IQ3_S_GRID)
#endif
#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>(
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
);
#enddecl(IQ3_S_GRID)
#endif
#decl(IQ1_GRID)
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
const IQ1_DELTA: f32 = 0.125;
@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
);
#enddecl(IQ1_GRID)
#endif
#decl(IQ4_GRID)
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
const kvalues_iq4nl = array<i32, 16>(
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
);
#enddecl(IQ4_GRID)
#endif

View File

@ -56,7 +56,9 @@ def expand_includes(shader, input_dir):
return include_pattern.sub(replacer, shader)
def write_shader(shader_name, shader_code, output_dir, outfile):
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
shader_code = expand_includes(shader_code, input_dir)
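# expand_includes() substitutes each `#include "<name>.tmpl"` directive with the contents of
# that template, resolved relative to input_dir, so the emitted WGSL is self-contained; this
# is why write_shader() now takes input_dir as an extra argument.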
if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
@ -74,7 +76,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
@ -123,7 +125,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():

View File

@ -1,222 +1,31 @@
#define(VARIANTS)
enable f16;
#include "common_decls.tmpl"
[
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"BLOCK_SIZE": 4
},
"DECLS": ["F32_VEC"]
},
{
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F32"]
},
{
"REPLS": {
"TYPE" : "f16",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F16"]
},
{
"REPLS": {
"TYPE" : "i32",
"DST_TYPE": "i32",
"BLOCK_SIZE": 1
},
"DECLS": ["I32"]
},
{
"REPLS": {
"TYPE" : "q4_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"TYPE" : "q4_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"TYPE" : "q5_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"TYPE" : "q5_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"TYPE" : "q8_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"TYPE" : "q2_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"TYPE" : "q3_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"TYPE" : "q4_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"TYPE" : "q5_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"TYPE" : "q6_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"TYPE" : "iq2_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"TYPE" : "iq2_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"TYPE": "iq2_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"TYPE": "iq3_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"TYPE": "iq3_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"TYPE": "iq1_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"TYPE": "iq1_m",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"TYPE": "iq4_nl",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"TYPE": "iq4_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(F32_VEC)
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
}
#enddecl(F32_VEC)
#endif
#decl(F32)
#ifdef F32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(F32)
#endif
#decl(F16)
#ifdef F16
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = f32(src[src_base + offset]);
}
#enddecl(F16)
#endif
#decl(I32)
#ifdef I32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(I32)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_0 = src[src_base + offset];
let d = f32(block_q4_0.d);
@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_1 = src[src_base + offset];
let d = f32(block_q4_1.d);
@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_0 = src[src_base + offset];
let d = f32(block_q5_0.d);
@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(Q5_0)
#decl(Q5_1)
#ifdef Q5_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_1 = src[src_base + offset];
let d = f32(block_q5_1.d);
@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q8_0 = src[src_base + offset];
let d = f32(block_q8_0.d);
@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q8_0)
#endif
#decl(Q2_K)
#ifdef Q2_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q2_K)
#endif
#decl(Q3_K)
#ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q3_K)
#endif
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_K)
#endif
#decl(Q5_K)
#ifdef Q5_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_K)
#endif
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
sc_b_idx += 8;
}
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XXS)
#endif
#decl(IQ2_XS)
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XS)
#endif
#decl(IQ2_S)
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_XSS)
#endif
#decl(IQ3_S)
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i++;
}
}
#enddecl(IQ4_NL)
#endif
#decl(IQ4_XS)
#ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i += 16;
}
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> idx: array<i32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
offset_src: u32, // in elements
@ -866,9 +662,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
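// Each invocation handles one destination row: idx_val (read from the i32 index tensor) selects
// the source row, and copy_elements() expands/dequantizes one block of BLOCK_SIZE elements per call.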
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
}
#end(SHADER)

View File

@ -1,195 +1,24 @@
#define(VARIANTS)
enable f16;
[
{
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q8_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q2_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q3_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q6_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_m",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_nl",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#include "common_decls.tmpl"
#end(VARIANTS)
#ifdef FLOAT
const BLOCK_SIZE = 1u;
#define(DECLS)
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;
#decl(FLOAT)
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif
#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
}
#enddecl(FLOAT)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_0 = src0[src0_idx_base + offset];
let d = f32(block_q4_0.d);
@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_1 = src0[src0_idx_base + offset];
let d = f32(block_q4_1.d);
@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_0 = src0[src0_idx_base + offset];
let d = f32(block_q5_0.d);
@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_0)
#endif
#decl(Q5_1)
#ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_1 = src0[src0_idx_base + offset];
let d = f32(block_q5_1.d);
@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_0 = src0[src0_idx_base + offset];
let d = f32(block_q8_0.d);
@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_0)
#endif
#decl(Q8_1)
#ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_1 = src0[src0_idx_base + offset];
let d = f32(block_q8_1.d);
@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_1)
#endif
#decl(Q2_K)
#ifdef Q2_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q2_K)
#decl(Q3_K)
#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q3_K)
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q4_K)
#decl(Q5_K)
#ifdef Q5_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q5_K)
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XXS)
#decl(IQ2_XS)
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XS)
#decl(IQ2_S)
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ3_XSS)
#decl(IQ3_S)
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ4_NL)
#decl(IQ4_XS)
#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
struct MulMatParams {
offset_src0: u32, // in elements/blocks
@ -864,8 +672,8 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
var sum = 0.0;
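// Each invocation computes one dst element as a dot product over K: multiply_add() handles one
// block of BLOCK_SIZE weights (a single element for the float variants) per iteration.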
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i);
}
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}
#end(SHADER)

View File

@ -1,58 +1,53 @@
#decl(SHMEM_VEC)
#ifdef SHMEM_VEC
fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#enddecl(SHMEM_VEC)
#endif
#decl(SHMEM_SCALAR)
#ifdef SHMEM_SCALAR
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)
#decl(INIT_SRC0_SHMEM_FLOAT)
#endif
#ifdef INIT_SRC0_SHMEM_FLOAT
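// Cooperative tile load: the whole workgroup strides over the src0 tile in steps of
// TOTAL_WORKGROUP_SIZE * VEC_SIZE elements, converting each value to shared memory via
// store_shmem(); out-of-range rows/columns are zero-filled by the select().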
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let src0_val = select( // taking a slight performance hit to avoid out-of-bounds loads
{{SRC0_TYPE}}(0.0),
src0[src0_idx/{{VEC_SIZE}}],
SRC0_TYPE(0.0),
src0[src0_idx/VEC_SIZE],
global_m < params.m && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
#endif
#enddecl(INIT_SRC0_SHMEM_FLOAT)
#decl(INIT_SRC1_SHMEM)
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_n = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_n = offset_n + tile_n;
let global_k = k_outer + tile_k;
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
let src1_val = select(
{{SRC1_TYPE}}(0.0),
src1[src1_idx/{{VEC_SIZE}}],
SRC1_TYPE(0.0),
src1[src1_idx/VEC_SIZE],
global_n < params.n && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
#endif
#enddecl(INIT_SRC1_SHMEM)
#decl(INIT_SRC0_SHMEM_Q4_0)
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
@ -93,5 +88,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
}
#enddecl(INIT_SRC0_SHMEM_Q4_0)
#endif

View File

@ -1,115 +1,19 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
enable f16;
#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS)
#decl(VEC)
#ifdef VEC
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
}
#enddecl(VEC)
#endif
#decl(SCALAR)
#ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return f32(acc[tm][tn]);
}
#enddecl(SCALAR)
#end(DECLS)
#define(SHADER)
enable f16;
#endif
struct MulMatParams {
offset_src0: u32,
@ -130,14 +34,13 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
fn get_local_n(thread_id: u32) -> u32 {
return thread_id / WORKGROUP_SIZE_M;
}
@ -145,10 +48,6 @@ fn get_local_m(thread_id: u32) -> u32 {
return thread_id % WORKGROUP_SIZE_M;
}
// TILE_M must be a multiple of 4 for vec4 loads
const TILE_M = {{WEBGPU_TILE_M}}u;
const TILE_N = {{WEBGPU_TILE_N}}u;
override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;
@ -233,15 +132,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
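// Write back the per-thread TILE_M x TILE_N accumulator, guarding against partial tiles at the
// matrix edges; with VEC_SIZE == 4, store_val() packs four consecutive rows into one vec4<f32> store.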
for (var tn = 0u; tn < TILE_N; tn++) {
let global_col = output_col_base + tn;
if (global_col < params.n) {
for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
let global_row = output_row_base + tm;
if (global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
}
}
}
}
}
#end(SHADER)

View File

@ -1,100 +1,12 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS)
#decl(VEC)
#ifdef VEC
fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = vec4<f32>(
f32(shmem[shmem_idx]),
@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
f32(shmem[shmem_idx + 3])
);
}
#enddecl(VEC)
#endif
#decl(SCALAR)
#ifdef SCALAR
fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = f32(shmem[shmem_idx]);
}
#enddecl(SCALAR)
#end(DECLS)
#define(SHADER)
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#endif
struct MulMatParams {
offset_src0: u32,
@ -138,36 +42,19 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
// Note: These are string-interpolated at build time; they cannot be override constants because the
// current Dawn version requires constant memory sizes in these type definitions and matrix loads.
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
const TILE_K = {{WEBGPU_TILE_K}}u;
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
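// Per workgroup, the output tile therefore spans WG_M_SG_TILE_SIZE rows by WG_N_SG_TILE_SIZE
// columns, and each k-step stages a TILE_K-wide slice of src0/src1 in shared memory
// (TILE_SRC0_SHMEM / TILE_SRC1_SHMEM elements).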
@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let local_row = idx % WG_TILE_STRIDE;
let local_col = idx / WG_TILE_STRIDE;
@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_col < params.n && global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
store_dst(idx, dst_idx/{{VEC_SIZE}});
store_dst(idx, dst_idx/VEC_SIZE);
}
}
}
#end(SHADER)

View File

@ -1,84 +1,11 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
}
]
#end(VARIANTS)
enable f16;
#define(DECLS)
#include "common_decls.tmpl"
#decl(VEC)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
#ifdef VEC
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(dot(SRC1_TYPE(src0_val), src1_val));
}
fn store_val(group_base: u32) -> vec4<f32> {
@ -87,33 +14,31 @@ fn store_val(group_base: u32) -> vec4<f32> {
partial_sums[group_base + THREADS_PER_OUTPUT * 2],
partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
}
#enddecl(VEC)
#endif
#decl(SCALAR)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
#ifdef SCALAR
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(src0_val) * f32(src1_val);
}
fn store_val(group_base: u32) -> f32 {
return partial_sums[group_base];
}
#enddecl(SCALAR)
#decl(MUL_ACC_FLOAT)
#endif
#ifdef MUL_ACC_FLOAT
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
let b = shared_vector[i / {{VEC_SIZE}}];
for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
let b = shared_vector[i / VEC_SIZE];
local_sum += inner_dot(a, b);
}
return local_sum;
}
#endif
#enddecl(MUL_ACC_FLOAT)
#decl(MUL_ACC_Q4_0)
#ifdef MUL_ACC_Q4_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
@ -145,15 +70,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
}
return local_sum;
}
#enddecl(MUL_ACC_Q4_0)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
struct MulMatParams {
offset_src0: u32,
@ -174,9 +91,10 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
@ -186,7 +104,7 @@ override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
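// Each workgroup produces OUTPUTS_PER_WG output rows: threads are split into groups of
// THREADS_PER_OUTPUT that share a row, the vector tile is cached cooperatively in shared_vector,
// and per-thread partial dot products are reduced through partial_sums.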
@compute @workgroup_size(WORKGROUP_SIZE)
@ -232,8 +150,8 @@ fn main(
let tile_size = min(TILE_K, params.k - k_tile);
// Cooperatively load vector tile into shared memory (all threads)
for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
for (var i = thread_id * VEC_SIZE; i < tile_size; i += WORKGROUP_SIZE * VEC_SIZE) {
shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
}
workgroupBarrier();
@ -260,8 +178,8 @@ fn main(
}
// Store back to global memory
if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
dst[dst_idx / VEC_SIZE] = store_val(group_base);
}
}
#end(SHADER)

View File

@ -1,21 +1,11 @@
#define(VARIANTS)
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
[
{
"SHADER_NAME": "scale_f32",
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "scale_f32_inplace",
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
fn store_scale(val: f32, offset: u32) {
src[offset] = val;
}
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@ -25,20 +15,7 @@ var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
dst[offset] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
src[offset] = val;
}
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
#endif
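// With INPLACE defined the scaled value is written back into src and only src + params are bound;
// otherwise dst is a separate buffer at @binding(1) and store_scale() picks the destination.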
struct Params {
offset_src: u32,
@ -65,10 +42,7 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
store_scale(src[i_src] * params.scale + params.bias, i_dst);
}
#end(SHADER)