Basic JIT compilation for mul_mat, get_rows, and scale (#17)
* scale jit working
* preliminary working jit for get_rows and mul_mat, needs refining
* simplified mul_mat preprocessing switch statement
* get_rows fixes, mul_mat refinement
* formatted + last edits
* removed some extraneous prints
* fixed get_rows, fixed workgroup dispatch in mul_mat, no gibberish
* small fix
* some changes, working
* get_rows and mul_mat jit fixed and working
* Update formatting
* formatting
* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
parent 57487a64c8
commit b3927f807d
@@ -4,6 +4,7 @@
#include "ggml.h"
#include "pre_wgsl.hpp"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>
@@ -18,6 +19,46 @@

#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u

// Matrix multiplication parameters

// Register tiling parameters
#define WEBGPU_MUL_MAT_TILE_M    8
#define WEBGPU_MUL_MAT_TILE_N    8
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
#define WEBGPU_MUL_MAT_TILE_K    32

// Subgroup matrix parameters
// The number of subgroups in the M dimension
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
// The number of subgroups in the N dimension
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
// The number of subgroup matrices each subgroup accumulates over
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be a multiple of 4 to work with vectorized paths, and must divide the
// mul_mat_vec workgroup size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K         256

#define WEBGPU_MAX_WG_SIZE     288
#define WEBGPU_MUL_MAT_WG_SIZE 256

// Helper function for replacing {{PLACEHOLDER}} tokens in shader source
inline void ggml_webgpu_replace_placeholder(std::string &       shader_code,
                                            const std::string & key,
                                            const std::string & value) {
    std::string pattern = "{{" + key + "}}";
    size_t      pos     = 0;
    while ((pos = shader_code.find(pattern, pos)) != std::string::npos) {
        shader_code.replace(pos, pattern.length(), value);
        pos += value.length();
    }
}
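The helper above is a plain find-and-replace over the shader source; advancing pos past the inserted value keeps the loop from re-scanning text it just wrote, and from looping forever if the value itself contains the pattern. A minimal usage sketch (the WGSL line is illustrative, not from the tree):

    std::string shader = "@compute @workgroup_size({{WG_SIZE}}) fn main() { }";
    ggml_webgpu_replace_placeholder(shader, "WG_SIZE", "256");
    // shader == "@compute @workgroup_size(256) fn main() { }"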

struct ggml_webgpu_processed_shader {
    std::string wgsl;
    std::string variant;
@@ -178,7 +219,8 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
        context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    if (context.key.kv_direct) {
        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        // Avoids having to use bounds-checks and decreasing performance for direct
        // KV loads
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            kv_tile -= context.sg_mat_n;
        }
@@ -466,6 +508,22 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
    return result;
}

/** Scale **/

struct ggml_webgpu_scale_pipeline_key {
    int inplace;

    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
};

struct ggml_webgpu_scale_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};

/** Binary **/

struct ggml_webgpu_binary_pipeline_key {
@@ -490,6 +548,34 @@ struct ggml_webgpu_binary_pipeline_key_hash {
    }
};

struct ggml_webgpu_scale_shader_lib_context {
    ggml_webgpu_scale_pipeline_key key;
    uint32_t                       max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_scale_shader(
    pre_wgsl::Preprocessor &                     preprocessor,
    const char *                                 shader_src,
    const ggml_webgpu_scale_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "scale";

    if (context.key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    result.decisions   = decisions;
    return result;
}

struct ggml_webgpu_binary_shader_lib_context {
    ggml_webgpu_binary_pipeline_key key;
    uint32_t                        max_wg_size;

@@ -535,4 +621,366 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
    result.decisions = decisions;
    return result;
}

/** get_rows **/

struct ggml_webgpu_get_rows_pipeline_key {
    ggml_type src_type;
    int       vectorized;

    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
        return src_type == other.src_type && vectorized == other.vectorized;
    }
};

struct ggml_webgpu_get_rows_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src_type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        return seed;
    }
};

struct ggml_webgpu_get_rows_shader_lib_context {
    ggml_webgpu_get_rows_pipeline_key key;
    uint32_t                          max_wg_size;
};
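These key/hash pairs exist so each JIT-compiled pipeline variant can be cached in an std::unordered_map, with the _hash functor supplying the hash and operator== resolving collisions. A hedged sketch of the intended lookup pattern (the map type mirrors the get_rows_pipelines member declared later in this diff):

    std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
        get_rows_pipelines;

    ggml_webgpu_get_rows_pipeline_key key = { .src_type = GGML_TYPE_F32, .vectorized = 1 };
    auto it = get_rows_pipelines.find(key);
    if (it == get_rows_pipelines.end()) {
        // cache miss: preprocess and JIT-compile the shader, then
        // get_rows_pipelines.emplace(key, pipeline);
    }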
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_get_rows_shader(
    pre_wgsl::Preprocessor &                        preprocessor,
    const char *                                    shader_src,
    const ggml_webgpu_get_rows_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "get_rows";

    // Determine src type string
    const char * type_str = nullptr;

    // src type
    const struct ggml_type_traits * type_traits = ggml_get_type_traits(context.key.src_type);
    type_str = type_traits->type_name;

    switch (context.key.src_type) {
        case GGML_TYPE_F32:
            if (context.key.vectorized) {
                defines.push_back("F32_VEC");
                defines.push_back("SRC_TYPE=vec4<f32>");
                defines.push_back("DST_TYPE=vec4<f32>");
                defines.push_back("BLOCK_SIZE=4u");
            } else {
                defines.push_back("F32");
                defines.push_back("SRC_TYPE=f32");
                defines.push_back("DST_TYPE=f32");
                defines.push_back("BLOCK_SIZE=1u");
            }
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("F16");
            defines.push_back("SRC_TYPE=f16");
            defines.push_back("DST_TYPE=f32");
            defines.push_back("BLOCK_SIZE=1u");
            variant += "_f16";
            break;
        case GGML_TYPE_I32:
            defines.push_back("I32");
            defines.push_back("SRC_TYPE=i32");
            defines.push_back("DST_TYPE=i32");
            defines.push_back("BLOCK_SIZE=1u");
            variant += "_i32";
            break;
        default:
            // convert name to upper case for other defines
            std::string type_upper = type_str;
            std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

            // push back defines for quantized types
            defines.push_back("BYTE_HELPERS");
            defines.push_back(type_upper + "_T");
            defines.push_back(type_upper);

            // for q4_k and q5_k
            defines.push_back(type_upper + "_SCALE_MIN");

            // defines for i-quants
            defines.push_back(type_upper + "_TABLES");
            defines.push_back(type_upper + "_GRID");

            // add variant
            variant += "_";
            variant += type_str;

            // add define for quantized src0 type
            defines.push_back(std::string("SRC_TYPE=") + type_str);
            defines.push_back("DST_TYPE=f32");
            break;
    }

    // determine BLOCK_SIZE for quantized types
    if (context.key.src_type == GGML_TYPE_I32) {
        defines.push_back("BLOCK_SIZE=1u");
    } else if ((context.key.src_type >= GGML_TYPE_Q4_0 && context.key.src_type <= GGML_TYPE_Q8_1) ||
               context.key.src_type == GGML_TYPE_IQ4_NL) {
        // Non-K quants use 32
        defines.push_back("BLOCK_SIZE=32u");
    } else if (context.key.src_type >= GGML_TYPE_Q2_K) {
        // K-quants and IQ variants all use 256
        defines.push_back("BLOCK_SIZE=256u");
    }

    // Vectorized suffix
    if (context.key.vectorized) {
        variant += "_vec";
    }

    defines.push_back("WORKGROUP_SIZE=" + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;

    // Create decisions structure to store workgroup size
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    result.decisions   = decisions;

    return result;
}
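To make the flow concrete, a hedged sketch of one call, for a vectorized f32 source (preprocessor and wgsl_get_rows are the objects the dispatch code later in this diff passes in):

    ggml_webgpu_get_rows_shader_lib_context ctx = {
        .key         = { .src_type = GGML_TYPE_F32, .vectorized = 1 },
        .max_wg_size = 256,
    };
    ggml_webgpu_processed_shader p = ggml_webgpu_preprocess_get_rows_shader(preprocessor, wgsl_get_rows, ctx);
    // p.variant == "get_rows_f32_vec", and the defines handed to the preprocessor are
    //   F32_VEC, SRC_TYPE=vec4<f32>, DST_TYPE=vec4<f32>, BLOCK_SIZE=4u, WORKGROUP_SIZE=256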

/** Matrix Multiplication **/

struct ggml_webgpu_mul_mat_pipeline_key {
    ggml_type src0_type;
    ggml_type src1_type;
    int       vectorized;
    int       is_vec;
    int       use_subgroup_matrix;
    int       register_tile;

    bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
               is_vec == other.is_vec && use_subgroup_matrix == other.use_subgroup_matrix &&
               register_tile == other.register_tile;
    }
};

struct ggml_webgpu_mul_mat_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src0_type);
        ggml_webgpu_hash_combine(seed, key.src1_type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        ggml_webgpu_hash_combine(seed, key.is_vec);
        ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
        ggml_webgpu_hash_combine(seed, key.register_tile);
        return seed;
    }
};

struct ggml_webgpu_mul_mat_shader_lib_context {
    ggml_webgpu_mul_mat_pipeline_key key;

    // For subgroup matrix paths
    uint32_t max_subgroup_size;
    uint32_t sg_mat_m;
    uint32_t sg_mat_n;
    uint32_t sg_mat_k;
};

struct ggml_webgpu_mul_mat_shader_decisions {
    uint32_t tile_k;
    uint32_t wg_size_m;
    uint32_t wg_size_n;
    uint32_t wg_size;
    uint32_t outputs_per_wg;
    int      is_vec;
    int      use_subgroup_matrix;

    // Register tiling parameters
    uint32_t tile_m;
    uint32_t tile_n;

    // Subgroup matrix parameters
    uint32_t subgroup_m;
    uint32_t subgroup_n;
    uint32_t subgroup_matrix_m;
    uint32_t subgroup_matrix_n;

    uint32_t mul_mat_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader(
    pre_wgsl::Preprocessor &                       preprocessor,
    const char *                                   shader_src,
    const ggml_webgpu_mul_mat_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "mul_mat";

    // Determine base variant name based on kernel type
    if (context.key.is_vec) {
        variant = "mul_mat_vec";
    } else if (context.key.use_subgroup_matrix) {
        variant = "mul_mat_subgroup_matrix";
    } else if (context.key.register_tile) {
        variant = "mul_mat_reg_tile";
    }

    // Determine src0/src1 type strings
    const char * src0_type_str = nullptr;

    bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile;

    // src1 type
    switch (context.key.src1_type) {
        case GGML_TYPE_F32:
            defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
            defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
            break;
        case GGML_TYPE_F16:
            defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
            defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
            break;
        default:
            break;
    }

    // same for all types
    defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");

    // src0 type
    const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type);
    src0_type_str = src0_type_traits->type_name;

    // for f32 and f16, account for vectorized src0 types
    switch (context.key.src0_type) {
        case GGML_TYPE_F32:
            src0_type_str = context.key.vectorized ? "vec4<f32>" : "f32";
            defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");

            defines.push_back("FLOAT");
            defines.push_back("MUL_ACC_FLOAT");
            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");

            variant += "_f32";
            break;

        case GGML_TYPE_F16:
            src0_type_str = context.key.vectorized ? "vec4<f16>" : "f16";
            defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");

            defines.push_back("FLOAT");
            defines.push_back("MUL_ACC_FLOAT");
            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");

            variant += "_f16";
            break;

        default:
            // convert name to upper case for other defines
            std::string type_upper = src0_type_str;
            std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

            // push back defines for quantized types
            defines.push_back("BYTE_HELPERS");
            defines.push_back(type_upper + "_T");

            defines.push_back(type_upper);
            defines.push_back("MUL_ACC_" + type_upper);
            defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");

            // for q4_k and q5_k
            defines.push_back(type_upper + "_SCALE_MIN");

            // defines for i-quants
            defines.push_back(type_upper + "_TABLES");
            defines.push_back(type_upper + "_GRID");

            // add variant
            variant += "_";
            variant += src0_type_str;

            // add define for non-fast-path quantized src0 type -- overwritten if fast path
            defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);

            break;
    }

    // Add VEC/SCALAR defines
    if (is_fast_path) {
        // if quantized type and using fast path, need to use f16 instead of the quantized type
        if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) {
            src0_type_str = "f16";
            defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
        }

        // all fast paths need VEC vs SCALAR
        defines.push_back(context.key.vectorized ? "VEC" : "SCALAR");
        // add the vec_size define too
        defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");

        // reg_tile and subgroup_matrix need these extra defines
        if (!context.key.is_vec) {
            defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
        }
    }

    // Append src1 type
    variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? "f32" : "f16");

    // Vectorized suffix
    if (context.key.vectorized) {
        variant += "_vec";
    }

    // Add defines for TILE_M and TILE_N
    defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
    defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");

    // Add subgroup matrix defines if using subgroup_matrix
    if (context.key.use_subgroup_matrix) {
        defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
        defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
        defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
        defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
        defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
        defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
        defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
        defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
        defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
    }

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;

    auto decisions                 = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
    decisions->tile_m              = WEBGPU_MUL_MAT_TILE_M;
    decisions->tile_n              = WEBGPU_MUL_MAT_TILE_N;
    decisions->tile_k              = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K;
    decisions->wg_size_m           = WEBGPU_MUL_MAT_WG_SIZE_M;
    decisions->wg_size_n           = WEBGPU_MUL_MAT_WG_SIZE_N;
    decisions->wg_size             = WEBGPU_MUL_MAT_VEC_WG_SIZE;
    decisions->outputs_per_wg      = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
    decisions->subgroup_m          = WEBGPU_MUL_MAT_SUBGROUP_M;
    decisions->subgroup_n          = WEBGPU_MUL_MAT_SUBGROUP_N;
    decisions->subgroup_matrix_m   = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
    decisions->subgroup_matrix_n   = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
    decisions->mul_mat_wg_size     = WEBGPU_MUL_MAT_WG_SIZE;
    decisions->is_vec              = context.key.is_vec;
    decisions->use_subgroup_matrix = context.key.use_subgroup_matrix;
    result.decisions               = decisions;

    return result;
}
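Tracing the logic above for one concrete case, an f16 x f32 register-tile kernel with vectorization enabled:

    ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = GGML_TYPE_F16, .src1_type = GGML_TYPE_F32,
                                             .vectorized = 1, .is_vec = 0, .use_subgroup_matrix = 0,
                                             .register_tile = 1 };
    // variant: "mul_mat_reg_tile" + "_f16" (src0) + "_f32" (src1) + "_vec"
    //   == "mul_mat_reg_tile_f16_f32_vec"
    // defines include: SRC0_TYPE=vec4<f16>, SRC1_TYPE=vec4<f32>, DST_TYPE=vec4<f32>,
    //   SHMEM_TYPE=vec4<f16>, FLOAT, MUL_ACC_FLOAT, VEC, VEC_SIZE=4u, SHMEM_VEC,
    //   TILE_M=8u, TILE_N=8u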

#endif  // GGML_WEBGPU_SHADER_LIB_HPP
@@ -69,21 +69,24 @@

/* Constants */

// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
// implementations so this can be removed.
#define WEBGPU_MAX_WG_SIZE 288

#define WEBGPU_MUL_MAT_WG_SIZE           256
#define WEBGPU_NUM_PARAM_BUFS            16u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS       0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4

// For operations which process a row in parallel, this seems like a reasonable
// default
#define WEBGPU_ROW_SPLIT_WG_SIZE 64

// Matrix multiplication parameters
@@ -106,13 +109,15 @@

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be a multiple of 4 to work with vectorized paths, and must divide the
// mul_mat_vec workgroup size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K         256

/* End Constants */

// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
// their locations.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT

// Always returns the base offset of a tensor, regardless of views.
@@ -358,9 +363,12 @@ struct webgpu_context_struct {
    webgpu_buf_pool param_buf_pool;
    webgpu_buf_pool set_rows_error_buf_pool;

    std::unordered_map<ggml_webgpu_mul_mat_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_mul_mat_pipeline_key_hash>
        mul_mat_pipelines;  // src0_type, src1_type, vectorized
    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
        mul_mat_vec_pipelines;  // src0_type, src1_type, vectorized

    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
        flash_attn_pipelines;
@@ -372,10 +380,13 @@ struct webgpu_context_struct {
    std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines;  // key is fixed, no variants yet

    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
        set_rows_pipelines;
    std::unordered_map<ggml_webgpu_get_rows_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_get_rows_pipeline_key_hash>
        get_rows_pipelines;  // src_type, vectorized

    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;  // src_type, dst_type

    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
        binary_pipelines;
@@ -383,7 +394,11 @@ struct webgpu_context_struct {
    std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split

    std::unordered_map<ggml_webgpu_scale_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_scale_pipeline_key_hash>
        scale_pipelines;  // inplace
    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines;  // mask_type, has_sink, inplace
    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
        unary_pipelines;
@@ -495,8 +510,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
static void ggml_backend_webgpu_wait(webgpu_global_context &                  ctx,
                                     std::vector<webgpu_submission_futures> & futures,
                                     bool                                     block = true) {
    // If we have too many in-flight submissions, wait on the oldest one first. If
    // there are many threads, inflight_max may be 0, meaning that we must wait on
    // all futures.
    uint64_t timeout_ms       = block ? UINT64_MAX : 0;
    uint32_t inflight_threads = ctx->inflight_threads;
    uint32_t inflight_max     = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
@@ -681,7 +697,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
        encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
    }

    // If there are SET_ROWS operations in this submission, copy their error
    // buffers to the host.
    if (set_rows_error_bufs) {
        encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
                                   set_rows_error_bufs->host_buf.GetSize());
@@ -971,7 +988,8 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
                                                          ggml_tensor *   src,
                                                          ggml_tensor *   idx,
                                                          ggml_tensor *   dst) {
    // For set rows specifically, we need to check if src and idx are empty
    // tensors.
    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
        return std::nullopt;
    }
@@ -1054,25 +1072,68 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
        error_bufs);
}

// Workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
    std::vector<wgpu::ConstantEntry> constants(1);
    constants[0].key   = "wg_size";
    constants[0].value = wg_size;
    return constants;
}

static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                           ggml_tensor *    src,
                                           ggml_tensor *    idx,
                                           ggml_tensor *    dst) {
    uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;

    // Create pipeline key
    ggml_webgpu_get_rows_pipeline_key key = { .src_type = src->type, .vectorized = (int) vectorized };

    // Get or create pipeline
    webgpu_pipeline pipeline;
    auto            it = ctx->get_rows_pipelines.find(key);
    if (it != ctx->get_rows_pipelines.end()) {
        pipeline = it->second;
    } else {
        // JIT compile the shader
        const char * shader_src = wgsl_get_rows;

        // Build shader context
        ggml_webgpu_get_rows_shader_lib_context shader_lib_ctx = {
            .key         = key,
            .max_wg_size = WEBGPU_MAX_WG_SIZE,
        };

        ggml_webgpu_processed_shader processed =
            ggml_webgpu_preprocess_get_rows_shader(ctx->p, shader_src, shader_lib_ctx);

        std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(shader_lib_ctx.max_wg_size);
        pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
                                               processed.variant.c_str(), constants);
        pipeline.context = processed.decisions;
        ctx->get_rows_pipelines.emplace(key, pipeline);
    }

    auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
                                     // Convert byte-strides to element-strides
                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
                                     // Shape of dst
                                     (uint32_t) dst->ne[0],
                                     (uint32_t) dst->ne[1],
                                     (uint32_t) dst->ne[2],
                                     (uint32_t) dst->ne[3],
                                     // Shape of idx
                                     (uint32_t) (idx->ne[1]),
                                     (uint32_t) (idx->ne[2]) };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
@@ -1089,10 +1150,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
    };

    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);

    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
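For intuition, a quick worked example of the dispatch above, assuming CEIL_DIV is the usual round-up division ((a + b - 1) / b) and using illustrative shapes:

    // dst->ne = {64, 100, 4, 1}, decisions->wg_size = 288 (WEBGPU_MAX_WG_SIZE):
    //   wg_x = CEIL_DIV(100 * 4 * 1, 288) = CEIL_DIV(400, 288) = 2
    // i.e. two workgroups cover the 400 destination rows, presumably one row per invocation.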
@@ -1100,45 +1159,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
                                          ggml_tensor *    src0,
                                          ggml_tensor *    src1,
                                          ggml_tensor *    dst) {
    // Determine if this is a mat-vec operation
    bool is_vec = (dst->ne[1] == 1);

    // Determine if we should use fast path
    bool use_fast = false;
    switch (src1->type) {
        case GGML_TYPE_F16:
@@ -1159,43 +1183,158 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
            break;
    }

    int vectorized = 0;
    if (use_fast) {
        vectorized = (src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0);
        if (is_vec) {
            // We don't support vectorized mul_mat_vec for quantized types
            vectorized = vectorized && (src0->type < 2);
        }
    }

    // Create pipeline key
    bool supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
    ggml_webgpu_mul_mat_pipeline_key key = { .src0_type  = src0->type,
                                             .src1_type  = src1->type,
                                             .vectorized = use_fast ? vectorized : 0,
                                             .is_vec     = (use_fast && is_vec) ? 1 : 0,
                                             .use_subgroup_matrix =
                                                 (use_fast && !is_vec && supports_subgroup_matrix) ? 1 : 0,
                                             .register_tile =
                                                 (use_fast && !is_vec && !supports_subgroup_matrix) ? 1 : 0 };

    // Build shader context
    ggml_webgpu_mul_mat_shader_lib_context shader_lib_ctx = { .key = key,
                                                              .max_subgroup_size =
                                                                  ctx->global_ctx->capabilities.max_subgroup_size,
                                                              .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
                                                              .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
                                                              .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k };

    // Get or create pipeline
    webgpu_pipeline pipeline;
    const char *    shader_src = nullptr;

    auto it = ctx->mul_mat_pipelines.find(key);
    if (it != ctx->mul_mat_pipelines.end()) {
        pipeline = it->second;
    } else {
        // Select the appropriate shader source based on the key
        if (!use_fast) {
            // Use the precompiled quantized shaders (mul_mat.tmpl.wgsl).
            // These are the fallback for quantized types not supported by the fast paths.
            shader_src = wgsl_mul_mat;
        } else {
            // Use a JIT-compiled shader
            if (is_vec) {
                shader_src = wgsl_mul_mat_vec;
            } else if (key.use_subgroup_matrix) {
                shader_src = wgsl_mul_mat_subgroup_matrix;
            } else {
                shader_src = wgsl_mul_mat_reg_tile;
            }
        }

        if (shader_src) {
            ggml_webgpu_processed_shader processed =
                ggml_webgpu_preprocess_mul_mat_shader(ctx->p, shader_src, shader_lib_ctx);

            std::vector<wgpu::ConstantEntry> constants;
            if (shader_lib_ctx.key.is_vec) {
                auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
                constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast<double>(decisions->wg_size) });
                constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
                constants.push_back({ nullptr, "OUTPUTS_PER_WG", static_cast<double>(decisions->outputs_per_wg) });
            } else if (shader_lib_ctx.key.register_tile) {
                auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
                constants.push_back({ nullptr, "WORKGROUP_SIZE_M", static_cast<double>(decisions->wg_size_m) });
                constants.push_back({ nullptr, "WORKGROUP_SIZE_N", static_cast<double>(decisions->wg_size_n) });
                constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
            }

            pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
                                                   processed.variant.c_str(), constants);
            pipeline.context = processed.decisions;
            ctx->mul_mat_pipelines.emplace(key, pipeline);
        }
    }

    auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());

    // Build params
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) dst->ne[0],                                  // number of rows in result (M, transposed)
        (uint32_t) dst->ne[1],                                  // number of columns in result (N)
        (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 2
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 2
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 3
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 3
        (uint32_t) src0->ne[2],                                 // batch size in dimension 2
        (uint32_t) src0->ne[3],                                 // batch size in dimension 3
        (uint32_t) (src1->ne[2] / src0->ne[2]),                 // broadcast in dimension 2
        (uint32_t) (src1->ne[3] / src0->ne[3])                  // broadcast in dimension 3
    };

    // Build bind group entries
    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(src1),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
    };

    // Calculate workgroup dimensions
    uint32_t wg_x = 1;
    uint32_t wg_y = 1;

    if (decisions->is_vec) {
        uint32_t batches       = dst->ne[2] * dst->ne[3];
        uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
        uint32_t total_wg      = output_groups * batches;
        wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
        wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
    } else if (use_fast) {
        // Fast-path tiled/subgroup calculations
        uint32_t wg_m, wg_n;
        if (decisions->use_subgroup_matrix) {
            uint32_t wg_m_sg_tile =
                decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
            wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
            uint32_t wg_n_sg_tile =
                decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
            wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
        } else {
            uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
            uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
            wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
            wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
        }
        wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
    } else {
        // Non-fast-path quantized shaders (Q2_K, Q4_K, etc.) use the workgroup
        // size recorded in the shader decisions
        wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->mul_mat_wg_size);
        wg_y = 1;
    }

    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
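Plugging the defaults into the fast-path tile math above (sg_mat_m / sg_mat_n are device-reported; 8 is only an illustrative value):

    // subgroup-matrix path: wg_m_sg_tile = 2 * 4 * 8 = 64 rows, wg_n_sg_tile = 2 * 2 * 8 = 32 cols
    // register-tile path:   tile_m_s = 8 * 8 = 64 rows, tile_n_s = 8 * 8 = 64 cols
    // e.g. a 512x512 dst on the register-tile path with ne[2] == ne[3] == 1:
    //   wg_x = CEIL_DIV(512, 64) * CEIL_DIV(512, 64) = 8 * 8 = 64 workgroups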

@@ -1671,6 +1810,29 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    int inplace = ggml_webgpu_tensor_equal(src, dst);

    ggml_webgpu_scale_pipeline_key key = { .inplace = inplace };

    ggml_webgpu_scale_shader_lib_context shader_lib_ctx = {
        .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
    };

    webgpu_pipeline pipeline;
    // TODO: remove guard once pipeline caches are per-thread
    auto it = ctx->scale_pipelines.find(key);
    if (it != ctx->scale_pipelines.end()) {
        pipeline = it->second;
    } else {
        ggml_webgpu_processed_shader processed =
            ggml_webgpu_preprocess_scale_shader(ctx->p, wgsl_scale, shader_lib_ctx);
        pipeline =
            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
        pipeline.context = processed.decisions;
        ctx->scale_pipelines.emplace(key, pipeline);
    }

    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1688,12 +1850,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
        *(uint32_t *) &dst->op_params[1]  // bias
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
    };

    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
@@ -1702,8 +1866,7 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
    }

    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
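For reference, the dispatched kernel computes an elementwise affine transform. A hedged scalar sketch of the semantics (assuming ggml's usual scale-with-bias convention; the op_params[1] bias read above suggests op_params[0] holds the scale, and src_value/dst_value are illustrative names):

    float scale = *(float *) &dst->op_params[0];  // assumed slot, by analogy with op_params[1] above
    float bias  = *(float *) &dst->op_params[1];
    for (int64_t i = 0; i < ggml_nelements(dst); i++) {
        dst_value[i] = src_value[i] * scale + bias;
    }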

static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -2233,7 +2396,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
                                                     size_t offset,
                                                     size_t size) {
    if (size == 0) {
        WEBGPU_LOG_DEBUG(
            "ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
            "nothing to do.");
        return;
    }
@@ -2310,7 +2475,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,

    size_t final_size = size;
    if (size % 4 != 0) {
        // If size is not a multiple of 4, we need to round it up to the next
        // multiple of 4
        final_size = size + (4 - (size % 4));
    }
@@ -2364,7 +2530,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
    /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
    /* .cpy_tensor = */ NULL,  // TODO: optional, implement this
    /* .clear      = */ ggml_backend_webgpu_buffer_clear,
    /* .reset      = */ NULL,  // TODO: optional, think it coordinates with
                               // .init_tensor
};

/* End GGML Backend Buffer Interface */
@@ -2401,7 +2568,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_
    return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
}

// maxBufferSize might be larger, but you can't bind more than
// maxStorageBufferBindingSize to a single binding.
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_backend_webgpu_device_context * dev_ctx =
        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
@@ -2487,14 +2655,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
    return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}

static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
    // we use the maximum workgroup size for the memset pipeline
    size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
@ -2509,207 +2669,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
|||
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
// Q4/Q5/Q8 classic quantizations
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
|
||||
|
||||
// K-quantizations
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
|
||||
|
||||
// IQ quantizations (2-, 3-, 4-bit variants)
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
|
||||
|
||||
// 1-bit and 4-bit IQ variants
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
|
||||
|
||||
std::string proc_mul_mat_f32_f32;
|
||||
std::string proc_mul_mat_f32_f32_vec;
|
||||
std::string proc_mul_mat_f16_f32;
|
||||
std::string proc_mul_mat_f16_f32_vec;
|
||||
std::string proc_mul_mat_f16_f16;
|
||||
std::string proc_mul_mat_f16_f16_vec;
|
||||
std::string proc_mul_mat_q4_0_f32;
|
||||
std::string proc_mul_mat_q4_0_f32_vec;
|
||||
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_constants;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
|
||||
std::map<std::string, std::string> sg_matrix_repls;
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
|
||||
std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
|
||||
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
|
||||
        proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
        proc_mul_mat_f32_f32_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
        proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
        proc_mul_mat_f16_f32_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
        proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
        proc_mul_mat_f16_f16_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
        proc_mul_mat_q4_0_f32 =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
        proc_mul_mat_q4_0_f32_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
    } else {
#endif
        mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });

        std::map<std::string, std::string> reg_repls;
        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);

        proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
        proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
        proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
        proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
        proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
        proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
        proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
        proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
#ifndef __EMSCRIPTEN__
    }
#endif

    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);

    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
    mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
    mul_mat_vec_constants[1].key = "TILE_K";
    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
    mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;

    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
}

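// Illustrative note (not in the original source): the two branches above
// specialize the same shaders through two different mechanisms. Values that a
// WGSL pipeline-overridable constant can express (TILE_K, WORKGROUP_SIZE_M/N)
// go into mul_mat_constants and are resolved at pipeline creation. Values
// that must be module-scope `const` in WGSL, such as the register tile dims
// that size the per-thread accumulator array, are substituted into the shader
// text by ggml_webgpu_process_shader_repls before compilation. For example,
// with reg_repls as above, the register-tile template line
//     const TILE_M = {{WEBGPU_TILE_M}}u;
// would become
//     const TILE_M = 8u;   // assuming WEBGPU_MUL_MAT_TILE_M were 8
//
// Hedged sketch (a hypothetical helper, not code from this commit) of how a
// dispatch site could select from the tables built here; `use_vec4` stands in
// for whatever divisibility-by-4 predicate the caller applies:
template <typename webgpu_context_t>
auto & ggml_webgpu_pick_mul_mat_pipeline(webgpu_context_t & ctx,
                                         const ggml_tensor * src0,
                                         const ggml_tensor * src1,
                                         bool                use_vec4) {
    // Index layout follows the assignments above:
    // [src0 type][src1 type][0 = scalar, 1 = vec4-vectorized].
    return ctx->mul_mat_pipelines[src0->type][src1->type][use_vec4 ? 1 : 0];
}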
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);

    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);

    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);

    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
}

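// Illustrative sketch (not code from this commit): only GGML_TYPE_F32 is
// given a vectorized ([1]) get_rows pipeline above, so any selector has to
// fall back to the scalar ([0]) variant for every other type. Divisibility
// of ne0 by 4 is an assumed precondition for the vec4 path here.
template <typename webgpu_context_t>
auto & ggml_webgpu_pick_get_rows_pipeline(webgpu_context_t & ctx, const ggml_tensor * src0) {
    const bool vec_ok = src0->type == GGML_TYPE_F32 && src0->ne[0] % 4 == 0;
    return ctx->get_rows_pipelines[src0->type][vec_ok ? 1 : 0];
}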
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

@@ -2816,15 +2775,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}

static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    webgpu_ctx->scale_pipelines[0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
    webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
                                                                 "scale_f32_inplace", constants);
}
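// Hedged sketch (not code from this commit): scale_pipelines[1] is the
// in-place variant, so pipeline selection reduces to whether dst aliases its
// source, mirroring the `inplace` field of ggml_webgpu_scale_pipeline_key.
// The aliasing test below is an assumption about how an in-place scale node
// could be detected, not the backend's actual check.
static inline bool ggml_webgpu_scale_is_inplace_sketch(const ggml_tensor * dst) {
    return dst->src[0]->data == dst->data;
}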

static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);

@@ -3023,13 +2973,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
                                       wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
                                       wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);

    ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
    ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
    ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
    ggml_webgpu_init_rope_pipeline(webgpu_ctx);
    ggml_webgpu_init_glu_pipeline(webgpu_ctx);
    ggml_webgpu_init_scale_pipeline(webgpu_ctx);
    ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
    // Initialize debug buffers
@@ -3071,11 +3018,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
    /* .iface = */ {
        /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
        /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
        /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
        /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
        /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
        /* .is_host = */ NULL, // defaults to false
        /* .alloc_buffer = */
        ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
        ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
        ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
        ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
    },
    /* .device = */
    dev,
@@ -1,5 +1,4 @@
#decl(BYTE_HELPERS)

#ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 {
    return (value >> (index * 8)) & 0xFF;
}
@@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
fn get_byte_i32(value: u32, index: u32) -> i32 {
    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#endif

#enddecl(BYTE_HELPERS)

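// Illustrative worked example (not part of the original template):
// get_byte extracts an unsigned byte lane, while get_byte_i32 sign-extends it
// by shifting the byte into the top of an i32 and arithmetic-shifting back:
//   get_byte(0x0000FF00u, 1u)     == 0xFFu
//   get_byte_i32(0x0000FF00u, 1u) == -1
//     (0xFF << 24 = 0xFF000000; bitcast<i32> = -16777216; >> 24 = -1)
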
#decl(Q4_0_T)
#ifdef Q4_0_T
struct q4_0 {
    d: f16,
    qs: array<f16, 8>
};
#enddecl(Q4_0_T)
#endif

#decl(Q4_1_T)
#ifdef Q4_1_T
struct q4_1 {
    d: f16,
    m: f16,
    qs: array<u32, 4>
};
#enddecl(Q4_1_T)
#endif

#decl(Q5_0_T)
#ifdef Q5_0_T
struct q5_0 {
    d: f16,
    qh: array<f16, 2>,
    qs: array<f16, 8>
};
#enddecl(Q5_0_T)
#endif

#decl(Q5_1_T)
#ifdef Q5_1_T
struct q5_1 {
    d: f16,
    m: f16,
    qh: u32,
    qs: array<u32, 4>
};
#enddecl(Q5_1_T)
#endif

#decl(Q8_0_T)
#ifdef Q8_0_T
struct q8_0 {
    d: f16,
    qs: array<f16, 16>
};
#enddecl(Q8_0_T)
#endif

#decl(Q8_1_T)
#ifdef Q8_1_T
struct q8_1 {
    d: f16,
    m: f16,
    qs: array<u32, 8>
};
#enddecl(Q8_1_T)
#endif

#decl(Q2_K_T)
struct q2_k {
#ifdef Q2_K_T
struct q2_K {
    scales: array<u32, 4>,
    qs: array<u32, 16>,
    d: f16,
    dmin: f16
};
#enddecl(Q2_K_T)
#endif

#decl(Q3_K_T)
struct q3_k {
#ifdef Q3_K_T
struct q3_K {
    hmask: array<f16, 16>,
    qs: array<f16, 32>,
    scales: array<f16, 6>,
    d: f16
};
#enddecl(Q3_K_T)

#decl(Q45_K_SCALE_MIN)
#endif

#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
    if (is < 4) {
        let sc_byte = get_byte(scales[is / 4], is % 4);
@@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
        return vec2(f32(sc), f32(m));
    }
}

#enddecl(Q45_K_SCALE_MIN)

#decl(Q4_K_T)
struct q4_k {
#endif
#ifdef Q4_K_T
struct q4_K {
    d: f16,
    dmin: f16,
    scales: array<u32, 3>,
    qs: array<u32, 32>
};
#enddecl(Q4_K_T)
#endif

#decl(Q5_K_T)
struct q5_k {
#ifdef Q5_K_T
struct q5_K {
    d: f16,
    dmin: f16,
    scales: array<u32, 3>,
    qh: array<u32, 8>,
    qs: array<u32, 32>
};
#enddecl(Q5_K_T)
#endif

#decl(Q6_K_T)
struct q6_k {
#ifdef Q6_K_T
struct q6_K {
    ql: array<f16, 64>,
    qh: array<f16, 32>,
    scales: array<f16, 8>,
    d: f16
};
#enddecl(Q6_K_T)
#endif

#decl(IQ2_XXS_T)
#ifdef IQ2_XXS_T
struct iq2_xxs {
    d: f16,
    qs: array<f16, 32>
};
#enddecl(IQ2_XXS_T)
#endif

#decl(IQ2_XS_T)
#ifdef IQ2_XS_T
struct iq2_xs {
    d: f16,
    qs: array<f16, 32>,
    scales: array<f16, 4>
};
#enddecl(IQ2_XS_T)
#endif

#decl(IQ2_S_T)
#ifdef IQ2_S_T
struct iq2_s {
    d: f16,
    qs: array<f16, 32>,
    qh: array<f16, 4>,
    scales: array<f16, 4>
};
#enddecl(IQ2_S_T)
#endif

#decl(IQ3_XSS_T)
#ifdef IQ3_XXS_T
struct iq3_xxs {
    d: f16,
    qs: array<f16, 48>
};
#enddecl(IQ3_XSS_T)
#endif

#decl(IQ3_S_T)
#ifdef IQ3_S_T
struct iq3_s {
    d: f16,
    qs: array<f16, 32>,
@@ -161,41 +156,41 @@ struct iq3_s {
    signs: array<f16, 16>,
    scales: array<f16, 2>
};
#enddecl(IQ3_S_T)
#endif

#decl(IQ1_S_T)
#ifdef IQ1_S_T
struct iq1_s {
    d: f16,
    qs: array<f16, 16>,
    qh: array<f16, 8>
};
#enddecl(IQ1_S_T)
#endif

#decl(IQ1_M_T)
#ifdef IQ1_M_T
struct iq1_m {
    qs: array<u32, 8>,
    qh: array<u32, 4>,
    scales: array<u32, 2>
};
#enddecl(IQ1_M_T)
#endif

#decl(IQ4_NL_T)
#ifdef IQ4_NL_T
struct iq4_nl {
    d: f16,
    qs: array<f16, 8>,
};
#enddecl(IQ4_NL_T)
#endif

#decl(IQ4_XS_T)
#ifdef IQ4_XS_T
struct iq4_xs {
    d: f16,
    scales_h: f16,
    scales_l: u32,
    qs: array<u32, 32>
};
#enddecl(IQ4_XS_T)
#endif

#decl(IQ23_TABLES)
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
    0x08040201u, // 1, 2, 4, 8
    0x80402010u // 16, 32, 64, 128
@@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
    0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
    0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
);
#enddecl(IQ23_TABLES)
#endif

#decl(IQ2_XXS_GRID)
#ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>(
    0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
    0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
    0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
    0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
);
#enddecl(IQ2_XXS_GRID)
#endif

#decl(IQ2_XS_GRID)
#ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>(
    0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
    0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
    0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
    0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_XS_GRID)
#endif

#decl(IQ2_S_GRID)
#ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>(
    0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
    0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
    0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
    0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_S_GRID)

#decl(IQ3_XSS_GRID)
#endif

#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>(
    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
);
#enddecl(IQ3_XSS_GRID)

#decl(IQ3_S_GRID)
#endif

#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>(
    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
);
#enddecl(IQ3_S_GRID)
#endif

#decl(IQ1_GRID)
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)

const IQ1_DELTA: f32 = 0.125;

@@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
    0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
);

#enddecl(IQ1_GRID)
#endif

#decl(IQ4_GRID)
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)

const kvalues_iq4nl = array<i32, 16>(
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
);

#enddecl(IQ4_GRID)
#endif
@@ -56,7 +56,9 @@ def expand_includes(shader, input_dir):
    return include_pattern.sub(replacer, shader)


def write_shader(shader_name, shader_code, output_dir, outfile):
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
    shader_code = expand_includes(shader_code, input_dir)

    if output_dir:
        wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
        with open(wgsl_filename, "w", encoding="utf-8") as f_out:
@@ -74,7 +76,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
    try:
        variants = ast.literal_eval(extract_block(text, "VARIANTS"))
    except ValueError:
        write_shader(shader_base_name, text, output_dir, outfile)
        write_shader(shader_base_name, text, output_dir, outfile, input_dir)
    else:
        try:
            decls_map = parse_decls(extract_block(text, "DECLS"))
@@ -123,7 +125,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
            else:
                output_name = shader_base_name
            write_shader(output_name, final_shader, output_dir, outfile)
            write_shader(output_name, final_shader, output_dir, outfile, input_dir)


def main():
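
# Illustrative sketch (not part of this commit): the include expansion that
# write_shader now performs up front. The regex and the recursion below are
# assumptions about expand_includes' behavior, inferred from the
# `#include "common_decls.tmpl"` directives in the templates and from the
# (shader, input_dir) signature shown in the hunk header above.
import os
import re

INCLUDE_PATTERN = re.compile(r'#include\s+"([^"]+)"')

def expand_includes_sketch(shader: str, input_dir: str) -> str:
    def replacer(match):
        include_path = os.path.join(input_dir, match.group(1))
        with open(include_path, "r", encoding="utf-8") as f_inc:
            # Recurse so an included template can itself include templates.
            return expand_includes_sketch(f_inc.read(), input_dir)
    return INCLUDE_PATTERN.sub(replacer, shader)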
@@ -1,222 +1,31 @@
#define(VARIANTS)
enable f16;
#include "common_decls.tmpl"

[
  {
    "SHADER_SUFFIX": "f32_vec",
    "REPLS": {
      "TYPE" : "vec4<f32>",
      "DST_TYPE": "vec4<f32>",
      "BLOCK_SIZE": 4
    },
    "DECLS": ["F32_VEC"]
  },
  {
    "REPLS": {
      "TYPE" : "f32",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 1
    },
    "DECLS": ["F32"]
  },
  {
    "REPLS": {
      "TYPE" : "f16",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 1
    },
    "DECLS": ["F16"]
  },
  {
    "REPLS": {
      "TYPE" : "i32",
      "DST_TYPE": "i32",
      "BLOCK_SIZE": 1
    },
    "DECLS": ["I32"]
  },
  {
    "REPLS": {
      "TYPE" : "q4_0",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
  },
  {
    "REPLS": {
      "TYPE" : "q4_1",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
  },
  {
    "REPLS": {
      "TYPE" : "q5_0",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
  },
  {
    "REPLS": {
      "TYPE" : "q5_1",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
  },
  {
    "REPLS": {
      "TYPE" : "q8_0",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
  },
  {
    "REPLS": {
      "TYPE" : "q2_k",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
  },
  {
    "REPLS": {
      "TYPE" : "q3_k",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
  },
  {
    "REPLS": {
      "TYPE" : "q4_k",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
  },
  {
    "REPLS": {
      "TYPE" : "q5_k",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
  },
  {
    "REPLS": {
      "TYPE" : "q6_k",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
  },
  {
    "REPLS": {
      "TYPE" : "iq2_xxs",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
  },
  {
    "REPLS": {
      "TYPE" : "iq2_xs",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
  },
  {
    "REPLS": {
      "TYPE": "iq2_s",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
  },
  {
    "REPLS": {
      "TYPE": "iq3_xxs",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
  },
  {
    "REPLS": {
      "TYPE": "iq3_s",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
  },
  {
    "REPLS": {
      "TYPE": "iq1_s",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
  },
  {
    "REPLS": {
      "TYPE": "iq1_m",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
  },
  {
    "REPLS": {
      "TYPE": "iq4_nl",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 32,
    },
    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
  },
  {
    "REPLS": {
      "TYPE": "iq4_xs",
      "DST_TYPE": "f32",
      "BLOCK_SIZE": 256,
    },
    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
  }
]

#end(VARIANTS)

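// Illustrative note (not in the original template): each VARIANTS entry
// produces one embedded shader. The "REPLS" values fill the type and block
// placeholders, and "DECLS" names the declaration blocks spliced in where
// DECLS appears in the SHADER body. The f32_vec entry, for example, becomes
// the get_rows_f32_vec shader registered in the C++ init code above, with
// vec4<f32> src/dst arrays and ne0/4 copy_elements calls per row.
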
#define(DECLS)

#decl(F32_VEC)
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
}
#enddecl(F32_VEC)
#endif

#decl(F32)
#ifdef F32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(F32)
#endif

#decl(F16)
#ifdef F16
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    dst[dst_base + offset] = f32(src[src_base + offset]);
}
#enddecl(F16)
#endif

#decl(I32)
#ifdef I32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(I32)
#endif

#decl(Q4_0)
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block_q4_0 = src[src_base + offset];
    let d = f32(block_q4_0.d);
@@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q4_0)
#endif

#decl(Q4_1)
#ifdef Q4_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block_q4_1 = src[src_base + offset];
    let d = f32(block_q4_1.d);
@@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q4_1)
#endif

#decl(Q5_0)
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block_q5_0 = src[src_base + offset];
    let d = f32(block_q5_0.d);
@@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#endif

#enddecl(Q5_0)

#decl(Q5_1)
#ifdef Q5_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block_q5_1 = src[src_base + offset];
    let d = f32(block_q5_1.d);
@@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q5_1)
#endif

#decl(Q8_0)
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block_q8_0 = src[src_base + offset];
    let d = f32(block_q8_0.d);
@@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q8_0)
#endif

#decl(Q2_K)
#ifdef Q2_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q2_K)
#endif

#decl(Q3_K)
#ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q3_K)
#endif

#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
@@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q4_K)
#endif

#decl(Q5_K)
#ifdef Q5_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(Q5_K)
#endif

#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
@@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        sc_b_idx += 8;
    }
}
#endif

#enddecl(Q6_K)

#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(IQ2_XXS)
#endif

#decl(IQ2_XS)
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(IQ2_XS)
#endif

#decl(IQ2_S)
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#endif

#enddecl(IQ2_S)

#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(IQ3_XSS)
#endif

#decl(IQ3_S)
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#enddecl(IQ3_S)
#endif

#decl(IQ1_S)
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#endif

#enddecl(IQ1_S)

#decl(IQ1_M)
#ifdef IQ1_M
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];

@@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        }
    }
}
#endif

#enddecl(IQ1_M)

#decl(IQ4_NL)
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        dst_i++;
    }
}
#enddecl(IQ4_NL)
#endif

#decl(IQ4_XS)
#ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
    let block = src[src_base + offset];
    let d = f32(block.d);
@@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
        dst_i += 16;
    }
}
#enddecl(IQ4_XS)

#end(DECLS)

#define(SHADER)

enable f16;

DECLS
#endif

@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;

@group(0) @binding(1)
var<storage, read_write> idx: array<i32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;

struct Params {
    offset_src: u32, // in elements
@@ -866,9 +662,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
    let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;

    for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
    for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
        copy_elements(i_src_row, i_dst_row, i);
    }
}

#end(SHADER)

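// Illustrative before/after (not in the original template): this commit moves
// the shader from build-time string substitution to JIT-time preprocessing.
//   before: var<storage, read_write> src: array<{{TYPE}}>;   // text pasted by the embed script
//   after:  var<storage, read_write> src: array<SRC_TYPE>;   // symbol resolved when the backend
//                                                            // preprocesses the WGSL at pipeline build
// {{DST_TYPE}} and {{BLOCK_SIZE}} change the same way, becoming the DST_TYPE
// and BLOCK_SIZE symbols used in the bindings and the row-copy loop above.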
@@ -1,195 +1,24 @@
#define(VARIANTS)
enable f16;

[
  {
    "REPLS": {
      "SRC0_TYPE" : "f32",
      "SRC1_TYPE" : "f32",
      "BLOCK_SIZE" : 1
    },
    "DECLS" : ["FLOAT"]
  },
  {
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f16",
      "BLOCK_SIZE" : 1
    },
    "DECLS" : ["FLOAT"]
  },
  {
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f32",
      "BLOCK_SIZE" : 1
    },
    "DECLS" : ["FLOAT"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q4_0",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q4_1",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q5_0",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q5_1",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q8_0",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32
    },
    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q2_k",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q3_k",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q4_k",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q5_k",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "q6_k",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq2_xxs",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq2_xs",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq2_s",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq3_xxs",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq3_s",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq1_s",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq1_m",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256
    },
    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq4_nl",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 32,
    },
    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
  },
  {
    "REPLS": {
      "SRC0_TYPE": "iq4_xs",
      "SRC1_TYPE": "f32",
      "BLOCK_SIZE": 256,
    },
    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
  }
]
#include "common_decls.tmpl"

#end(VARIANTS)
#ifdef FLOAT
const BLOCK_SIZE = 1u;

#define(DECLS)
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;

#decl(FLOAT)
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif

#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
}
#enddecl(FLOAT)
#endif

#decl(Q4_0)
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q4_0 = src0[src0_idx_base + offset];
    let d = f32(block_q4_0.d);
@@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(Q4_0)
#endif

#decl(Q4_1)
#ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q4_1 = src0[src0_idx_base + offset];
    let d = f32(block_q4_1.d);
@@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(Q4_1)
#endif

#decl(Q5_0)
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q5_0 = src0[src0_idx_base + offset];
    let d = f32(block_q5_0.d);
@@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(Q5_0)
#endif

#decl(Q5_1)
#ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q5_1 = src0[src0_idx_base + offset];
    let d = f32(block_q5_1.d);
@@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(Q5_1)
#endif

#decl(Q8_0)
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q8_0 = src0[src0_idx_base + offset];
    let d = f32(block_q8_0.d);
@@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(Q8_0)
#endif

#decl(Q8_1)
#ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q8_1 = src0[src0_idx_base + offset];
    let d = f32(block_q8_1.d);
@@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(Q8_1)
#endif

#decl(Q2_K)
#ifdef Q2_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
@@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(Q2_K)

#decl(Q3_K)
#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
@@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(Q3_K)

#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
@@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(Q4_K)

#decl(Q5_K)
#ifdef Q5_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
@@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(Q5_K)

#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
@@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(Q6_K)

#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(IQ2_XXS)

#decl(IQ2_XS)
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(IQ2_XS)

#decl(IQ2_S)
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif


#enddecl(IQ2_S)

#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(IQ3_XSS)

#decl(IQ3_S)
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#enddecl(IQ3_S)
#endif

#decl(IQ1_S)
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(IQ1_S)

#decl(IQ1_M)
#ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];

@@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(IQ1_M)

#decl(IQ4_NL)
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}
#endif

#enddecl(IQ4_NL)

#decl(IQ4_XS)
#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
@@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    }
    return sum;
}

#enddecl(IQ4_XS)

#end(DECLS)

#define(SHADER)

enable f16;

DECLS
#endif

struct MulMatParams {
    offset_src0: u32, // in elements/blocks
@@ -864,8 +672,8 @@ struct MulMatParams {
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns

@group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;

    var sum = 0.0;
    for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
    for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
        sum += multiply_add(src0_idx_base, src1_idx_base, i);
    }
    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}

#end(SHADER)

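// Illustrative note (not in the original template): after this change the
// loop no longer pastes {{BLOCK_SIZE}} as a literal; the #ifdef chain near
// the top defines `const BLOCK_SIZE` per quantization family (1 for the
// float variants, 32 for the q4_0..q8_1/iq4_nl families, 256 for k-quants
// and the remaining iq types), and each invocation runs params.k/BLOCK_SIZE
// iterations of multiply_add. For example, with k = 4096 and a q4_0 src0,
// each output element accumulates 4096/32 = 128 block dot-products.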
@@ -1,58 +1,53 @@
#decl(SHMEM_VEC)
#ifdef SHMEM_VEC
fn store_shmem(val: vec4<f16>, idx: u32) {
    shmem[idx] = val.x;
    shmem[idx + 1] = val.y;
    shmem[idx + 2] = val.z;
    shmem[idx + 3] = val.w;
}
#enddecl(SHMEM_VEC)
#endif

#decl(SHMEM_SCALAR)
#ifdef SHMEM_SCALAR
fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)

#decl(INIT_SRC0_SHMEM_FLOAT)
#endif

#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
        let tile_m = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_m = offset_m + tile_m;
        let global_k = k_outer + tile_k;
        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
        let src0_val = select( // taking a slight performance hit to avoid oob
            {{SRC0_TYPE}}(0.0),
            src0[src0_idx/{{VEC_SIZE}}],
            SRC0_TYPE(0.0),
            src0[src0_idx/VEC_SIZE],
            global_m < params.m && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
        store_shmem(SHMEM_TYPE(src0_val), elem_idx);
    }
}
#endif

#enddecl(INIT_SRC0_SHMEM_FLOAT)

#decl(INIT_SRC1_SHMEM)

#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
        let tile_n = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_n = offset_n + tile_n;
        let global_k = k_outer + tile_k;
        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
        let src1_val = select(
            {{SRC1_TYPE}}(0.0),
            src1[src1_idx/{{VEC_SIZE}}],
            SRC1_TYPE(0.0),
            src1[src1_idx/VEC_SIZE],
            global_n < params.n && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
        store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
    }
}
#endif

#enddecl(INIT_SRC1_SHMEM)

#decl(INIT_SRC0_SHMEM_Q4_0)

#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
@@ -93,5 +88,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
        }
    }
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)
#endif
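// Illustrative note (not in the original template): both loaders stage their
// tiles into a single shared f16 array, with src1's tile placed after src0's
// at offset TILE_SRC0_SHMEM. The select() keeps the loads branch-free:
// positions past params.m/params.n/params.k store a zero value instead of a
// real element, at the small cost the inline "slight performance hit" comment
// acknowledges. The q4_0 loader instead dequantizes BLOCKS_K = TILE_K/32
// blocks per k-tile into the same shared array, which is why TILE_K must be
// a multiple of the 32-element block size.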
@@ -1,115 +1,19 @@
#define(VARIANTS)
[
  {
    "SHADER_SUFFIX": "f32_f32_vec",
    "REPLS": {
      "SRC0_TYPE" : "vec4<f32>",
      "SRC1_TYPE" : "vec4<f32>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f32_f32",
    "REPLS": {
      "SRC0_TYPE" : "f32",
      "SRC1_TYPE" : "f32",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f32_vec",
    "REPLS": {
      "SRC0_TYPE" : "vec4<f16>",
      "SRC1_TYPE" : "vec4<f32>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f32",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f32",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f16_vec",
    "REPLS": {
      "SRC0_TYPE" : "vec4<f16>",
      "SRC1_TYPE" : "vec4<f16>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f16",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f16",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "q4_0_f32_vec",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "vec4<f32>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "q4_0_f32",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f32",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
  }
]
enable f16;

#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"

#define(DECLS)

#decl(VEC)
#ifdef VEC
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
}
#enddecl(VEC)
#endif

#decl(SCALAR)
#ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
    return f32(acc[tm][tn]);
}
#enddecl(SCALAR)

#end(DECLS)

#define(SHADER)
enable f16;
#endif

struct MulMatParams {
    offset_src0: u32,
@@ -130,14 +34,13 @@ struct MulMatParams {
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

DECLS

fn get_local_n(thread_id: u32) -> u32 {
    return thread_id / WORKGROUP_SIZE_M;
}
@@ -145,10 +48,6 @@ fn get_local_m(thread_id: u32) -> u32 {
    return thread_id % WORKGROUP_SIZE_M;
}

// TILE_M must be multiple of 4 for vec4 loads
const TILE_M = {{WEBGPU_TILE_M}}u;
const TILE_N = {{WEBGPU_TILE_N}}u;

override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;
@@ -233,15 +132,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
    for (var tn = 0u; tn < TILE_N; tn++) {
        let global_col = output_col_base + tn;
        if (global_col < params.n) {
            for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
            for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
                let global_row = output_row_base + tm;
                if (global_row < params.m) {
                    let dst_idx = dst_batch_offset + global_col * params.m + global_row;
                    dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
                    dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
                }
            }
        }
    }
}

#end(SHADER)

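// Illustrative note (not in the original template): in the register-tiling
// scheme each thread keeps a TILE_M x TILE_N accumulator (`acc`) in
// registers, so a WORKGROUP_SIZE_M x WORKGROUP_SIZE_N workgroup covers a
// (WORKGROUP_SIZE_M * TILE_M) x (WORKGROUP_SIZE_N * TILE_N) output tile.
// TILE_M/TILE_N size the accumulator's array type, which WGSL requires to be
// a creation-time constant; that is why they are injected as
// {{WEBGPU_TILE_M}}/{{WEBGPU_TILE_N}} text substitutions while the workgroup
// dims stay pipeline-overridable. The vec4 variant's store_val packs
// acc[tm..tm+3][tn] into one vec4<f32> store, hence the "TILE_M must be
// multiple of 4" requirement above.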
@ -1,100 +1,12 @@
|
|||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
    {
        "SHADER_SUFFIX": "f16_f16_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f16>",
            "SRC1_TYPE" : "vec4<f16>",
            "DST_TYPE" : "vec4<f32>",
            "SHMEM_TYPE" : "vec4<f16>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
    },
    {
        "SHADER_SUFFIX": "f16_f16",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f16",
            "DST_TYPE" : "f32",
            "SHMEM_TYPE" : "f16",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
    },
    {
        "SHADER_SUFFIX": "q4_0_f32_vec",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "vec4<f32>",
            "DST_TYPE" : "vec4<f32>",
            "SHMEM_TYPE" : "vec4<f16>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
    },
    {
        "SHADER_SUFFIX": "q4_0_f32",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f32",
            "DST_TYPE" : "f32",
            "SHMEM_TYPE" : "f16",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
    }
]
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"

#define(DECLS)

#decl(VEC)
#ifdef VEC
fn store_dst(shmem_idx: u32, dst_idx: u32) {
    dst[dst_idx] = vec4<f32>(
        f32(shmem[shmem_idx]),
@@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
        f32(shmem[shmem_idx + 3])
    );
}
#enddecl(VEC)
#endif

#decl(SCALAR)
#ifdef SCALAR
fn store_dst(shmem_idx: u32, dst_idx: u32) {
    dst[dst_idx] = f32(shmem[shmem_idx]);
}
#enddecl(SCALAR)

#end(DECLS)

#define(SHADER)
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#endif

struct MulMatParams {
    offset_src0: u32,
@@ -138,36 +42,19 @@ struct MulMatParams {
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

DECLS

// Note: these values are string-interpolated at build time; they cannot be override constants
// because the current Dawn version requires constant (not override) sizes for type definitions
// and subgroup matrix loads.
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;

const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;

const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;

const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;

const TILE_K = {{WEBGPU_TILE_K}}u;

const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
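// Worked example (hypothetical values, for illustration only): with
// SUBGROUP_M = 2, SUBGROUP_N = 2, SUBGROUP_MATRIX_M = 4, SUBGROUP_MATRIX_N = 2,
// 8x8 subgroup matrices (SUBGROUP_MATRIX_M_SIZE = SUBGROUP_MATRIX_N_SIZE = 8)
// and TILE_K = 32, the constants above evaluate to:
//   WG_M_SG_TILE_SIZE = 2 * 4 * 8 = 64 output rows per workgroup tile
//   WG_N_SG_TILE_SIZE = 2 * 2 * 8 = 32 output columns per workgroup tile
//   TILE_SRC0_SHMEM   = 32 * 64  = 2048 shared-memory elements staged from src0
//   TILE_SRC1_SHMEM   = 32 * 32  = 1024 shared-memory elements staged from src1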
@@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
    for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
        let local_row = idx % WG_TILE_STRIDE;
        let local_col = idx / WG_TILE_STRIDE;

@@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

        if (global_col < params.n && global_row < params.m) {
            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
            store_dst(idx, dst_idx/{{VEC_SIZE}});
            store_dst(idx, dst_idx/VEC_SIZE);
        }
    }
}

#end(SHADER)
@@ -1,84 +1,11 @@
#define(VARIANTS)
[
    {
        "SHADER_SUFFIX": "f32_f32_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f32>",
            "SRC1_TYPE" : "vec4<f32>",
            "DST_TYPE": "vec4<f32>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f32_f32",
        "REPLS": {
            "SRC0_TYPE" : "f32",
            "SRC1_TYPE" : "f32",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f32_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f16>",
            "SRC1_TYPE" : "vec4<f32>",
            "DST_TYPE": "vec4<f32>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f32",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f32",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f16_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f16>",
            "SRC1_TYPE" : "vec4<f16>",
            "DST_TYPE": "vec4<f32>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f16",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f16",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "q4_0_f32",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f32",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
    }
]

#end(VARIANTS)
enable f16;

#define(DECLS)
#include "common_decls.tmpl"

#decl(VEC)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
#ifdef VEC
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    return f32(dot(SRC1_TYPE(src0_val), src1_val));
}

fn store_val(group_base: u32) -> vec4<f32> {
@@ -87,33 +14,31 @@ fn store_val(group_base: u32) -> vec4<f32> {
        partial_sums[group_base + THREADS_PER_OUTPUT * 2],
        partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
}
#enddecl(VEC)
#endif

#decl(SCALAR)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
#ifdef SCALAR
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    return f32(src0_val) * f32(src1_val);
}

fn store_val(group_base: u32) -> f32 {
    return partial_sums[group_base];
}
#enddecl(SCALAR)

#decl(MUL_ACC_FLOAT)
#endif

#ifdef MUL_ACC_FLOAT
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
    var local_sum = 0.0;
    for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
        let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
        let b = shared_vector[i / {{VEC_SIZE}}];
    for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
        let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
        let b = shared_vector[i / VEC_SIZE];
        local_sum += inner_dot(a, b);
    }
    return local_sum;
}
#endif

#enddecl(MUL_ACC_FLOAT)

#decl(MUL_ACC_Q4_0)
#ifdef MUL_ACC_Q4_0

const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
@@ -145,15 +70,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
    }
    return local_sum;
}

#enddecl(MUL_ACC_Q4_0)

#end(DECLS)

#define(SHADER)
enable f16;

DECLS
#endif

struct MulMatParams {
    offset_src0: u32,
@@ -174,9 +91,10 @@ struct MulMatParams {
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

@@ -186,7 +104,7 @@ override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
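// Worked example (hypothetical sizes, for illustration only): with
// WORKGROUP_SIZE = 256 and OUTPUTS_PER_WG = 64, THREADS_PER_OUTPUT = 4.
// Each group of 4 threads strides across one output row's K tile, reduces its
// partial dot products into partial_sums, and the group leader
// (thread_in_group == 0) writes the final value. In the vec4 variant the
// leader of every fourth group packs four adjacent reduced rows into a single
// vec4<f32> store, which is why store_val above gathers
// partial_sums[group_base + THREADS_PER_OUTPUT * {0, 1, 2, 3}].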

// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction

@compute @workgroup_size(WORKGROUP_SIZE)
@@ -232,8 +150,8 @@ fn main(
    let tile_size = min(TILE_K, params.k - k_tile);

    // Cooperatively load vector tile into shared memory (all threads)
    for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
        shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
    for (var i = thread_id * VEC_SIZE; i < tile_size; i += WORKGROUP_SIZE * VEC_SIZE) {
        shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
    }

    workgroupBarrier();
@@ -260,8 +178,8 @@ fn main(
    }

    // Store back to global memory
    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
    if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
        dst[dst_idx / VEC_SIZE] = store_val(group_base);
    }
}
#end(SHADER)

@@ -1,21 +1,11 @@
#define(VARIANTS)
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;

[
    {
        "SHADER_NAME": "scale_f32",
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "scale_f32_inplace",
        "DECLS": ["INPLACE"]
    }
]

#end(VARIANTS)

#define(DECLS)

#decl(NOT_INPLACE)
fn store_scale(val: f32, offset: u32) {
    src[offset] = val;
}
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

@@ -25,20 +15,7 @@ var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
    dst[offset] = val;
}
#enddecl(NOT_INPLACE)

#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;

fn store_scale(val: f32, offset: u32) {
    src[offset] = val;
}
#enddecl(INPLACE)

#end(DECLS)

#define(SHADER)
#endif

struct Params {
    offset_src: u32,
@@ -65,10 +42,7 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
@@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {

    store_scale(src[i_src] * params.scale + params.bias, i_dst);
}
#end(SHADER)
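For reference, this is roughly what the scale_f32_inplace variant expands to when INPLACE is defined (an illustrative sketch based on the branches visible above, not actual generated output):

@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<uniform> params: Params;

fn store_scale(val: f32, offset: u32) {
    src[offset] = val; // write the scaled value back into the source buffer
}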