This commit is contained in:
Reese Levine 2026-02-17 09:53:04 +09:00 committed by GitHub
commit ece93a4ce3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1647 additions and 2016 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,4 @@
#decl(BYTE_HELPERS) #ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 { fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF; return (value >> (index * 8)) & 0xFF;
} }
@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
fn get_byte_i32(value: u32, index: u32) -> i32 { fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
} }
#endif
#enddecl(BYTE_HELPERS) #ifdef Q4_0_T
#decl(Q4_0_T)
struct q4_0 { struct q4_0 {
d: f16, d: f16,
qs: array<f16, 8> qs: array<f16, 8>
}; };
#enddecl(Q4_0_T) #endif
#decl(Q4_1_T) #ifdef Q4_1_T
struct q4_1 { struct q4_1 {
d: f16, d: f16,
m: f16, m: f16,
qs: array<u32, 4> qs: array<u32, 4>
}; };
#enddecl(Q4_1_T) #endif
#decl(Q5_0_T) #ifdef Q5_0_T
struct q5_0 { struct q5_0 {
d: f16, d: f16,
qh: array<f16, 2>, qh: array<f16, 2>,
qs: array<f16, 8> qs: array<f16, 8>
}; };
#enddecl(Q5_0_T) #endif
#decl(Q5_1_T) #ifdef Q5_1_T
struct q5_1 { struct q5_1 {
d: f16, d: f16,
m: f16, m: f16,
qh: u32, qh: u32,
qs: array<u32, 4> qs: array<u32, 4>
}; };
#enddecl(Q5_1_T) #endif
#decl(Q8_0_T) #ifdef Q8_0_T
struct q8_0 { struct q8_0 {
d: f16, d: f16,
qs: array<f16, 16> qs: array<f16, 16>
}; };
#enddecl(Q8_0_T) #endif
#decl(Q8_1_T) #ifdef Q8_1_T
struct q8_1 { struct q8_1 {
d: f16, d: f16,
m: f16, m: f16,
qs: array<u32, 8> qs: array<u32, 8>
}; };
#enddecl(Q8_1_T) #endif
#decl(Q2_K_T) #ifdef Q2_K_T
struct q2_k { struct q2_K {
scales: array<u32, 4>, scales: array<u32, 4>,
qs: array<u32, 16>, qs: array<u32, 16>,
d: f16, d: f16,
dmin: f16 dmin: f16
}; };
#enddecl(Q2_K_T) #endif
#decl(Q3_K_T) #ifdef Q3_K_T
struct q3_k { struct q3_K {
hmask: array<f16, 16>, hmask: array<f16, 16>,
qs: array<f16, 32>, qs: array<f16, 32>,
scales: array<f16, 6>, scales: array<f16, 6>,
d: f16 d: f16
}; };
#enddecl(Q3_K_T) #endif
#decl(Q45_K_SCALE_MIN)
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> { fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
if (is < 4) { if (is < 4) {
let sc_byte = get_byte(scales[is / 4], is % 4); let sc_byte = get_byte(scales[is / 4], is % 4);
@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
return vec2(f32(sc), f32(m)); return vec2(f32(sc), f32(m));
} }
} }
#endif
#enddecl(Q45_K_SCALE_MIN) #ifdef Q4_K_T
struct q4_K {
#decl(Q4_K_T)
struct q4_k {
d: f16, d: f16,
dmin: f16, dmin: f16,
scales: array<u32, 3>, scales: array<u32, 3>,
qs: array<u32, 32> qs: array<u32, 32>
}; };
#enddecl(Q4_K_T) #endif
#decl(Q5_K_T) #ifdef Q5_K_T
struct q5_k { struct q5_K {
d: f16, d: f16,
dmin: f16, dmin: f16,
scales: array<u32, 3>, scales: array<u32, 3>,
qh: array<u32, 8>, qh: array<u32, 8>,
qs: array<u32, 32> qs: array<u32, 32>
}; };
#enddecl(Q5_K_T) #endif
#decl(Q6_K_T) #ifdef Q6_K_T
struct q6_k { struct q6_K {
ql: array<f16, 64>, ql: array<f16, 64>,
qh: array<f16, 32>, qh: array<f16, 32>,
scales: array<f16, 8>, scales: array<f16, 8>,
d: f16 d: f16
}; };
#enddecl(Q6_K_T) #endif
#decl(IQ2_XXS_T) #ifdef IQ2_XXS_T
struct iq2_xxs { struct iq2_xxs {
d: f16, d: f16,
qs: array<f16, 32> qs: array<f16, 32>
}; };
#enddecl(IQ2_XXS_T) #endif
#decl(IQ2_XS_T) #ifdef IQ2_XS_T
struct iq2_xs { struct iq2_xs {
d: f16, d: f16,
qs: array<f16, 32>, qs: array<f16, 32>,
scales: array<f16, 4> scales: array<f16, 4>
}; };
#enddecl(IQ2_XS_T) #endif
#decl(IQ2_S_T) #ifdef IQ2_S_T
struct iq2_s { struct iq2_s {
d: f16, d: f16,
qs: array<f16, 32>, qs: array<f16, 32>,
qh: array<f16, 4>, qh: array<f16, 4>,
scales: array<f16, 4> scales: array<f16, 4>
}; };
#enddecl(IQ2_S_T) #endif
#decl(IQ3_XSS_T) #ifdef IQ3_XXS_T
struct iq3_xxs { struct iq3_xxs {
d: f16, d: f16,
qs: array<f16, 48> qs: array<f16, 48>
}; };
#enddecl(IQ3_XSS_T) #endif
#decl(IQ3_S_T) #ifdef IQ3_S_T
struct iq3_s { struct iq3_s {
d: f16, d: f16,
qs: array<f16, 32>, qs: array<f16, 32>,
@ -161,41 +156,41 @@ struct iq3_s {
signs: array<f16, 16>, signs: array<f16, 16>,
scales: array<f16, 2> scales: array<f16, 2>
}; };
#enddecl(IQ3_S_T) #endif
#decl(IQ1_S_T) #ifdef IQ1_S_T
struct iq1_s { struct iq1_s {
d: f16, d: f16,
qs: array<f16, 16>, qs: array<f16, 16>,
qh: array<f16, 8> qh: array<f16, 8>
}; };
#enddecl(IQ1_S_T) #endif
#decl(IQ1_M_T) #ifdef IQ1_M_T
struct iq1_m { struct iq1_m {
qs: array<u32, 8>, qs: array<u32, 8>,
qh: array<u32, 4>, qh: array<u32, 4>,
scales: array<u32, 2> scales: array<u32, 2>
}; };
#enddecl(IQ1_M_T) #endif
#decl(IQ4_NL_T) #ifdef IQ4_NL_T
struct iq4_nl { struct iq4_nl {
d: f16, d: f16,
qs: array<f16, 8>, qs: array<f16, 8>,
}; };
#enddecl(IQ4_NL_T) #endif
#decl(IQ4_XS_T) #ifdef IQ4_XS_T
struct iq4_xs { struct iq4_xs {
d: f16, d: f16,
scales_h: f16, scales_h: f16,
scales_l: u32, scales_l: u32,
qs: array<u32, 32> qs: array<u32, 32>
}; };
#enddecl(IQ4_XS_T) #endif
#decl(IQ23_TABLES) #if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>( const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
0x08040201u, // 1, 2, 4, 8 0x08040201u, // 1, 2, 4, 8
0x80402010u // 16, 32, 64, 128 0x80402010u // 16, 32, 64, 128
@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
); );
#enddecl(IQ23_TABLES) #endif
#decl(IQ2_XXS_GRID) #ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>( const iq2xxs_grid = array<u32, 512>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
); );
#enddecl(IQ2_XXS_GRID) #endif
#decl(IQ2_XS_GRID) #ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>( const iq2xs_grid = array<u32, 1024>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
); );
#enddecl(IQ2_XS_GRID) #endif
#decl(IQ2_S_GRID) #ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>( const iq2s_grid = array<u32, 2048>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
); );
#enddecl(IQ2_S_GRID) #endif
#decl(IQ3_XSS_GRID)
#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>( const iq3xxs_grid = array<u32, 256>(
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
); );
#enddecl(IQ3_XSS_GRID) #endif
#decl(IQ3_S_GRID)
#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>( const iq3s_grid = array<u32, 512>(
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
); );
#enddecl(IQ3_S_GRID) #endif
#decl(IQ1_GRID) #if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
const IQ1_DELTA: f32 = 0.125; const IQ1_DELTA: f32 = 0.125;
@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
); );
#enddecl(IQ1_GRID) #endif
#decl(IQ4_GRID) #if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
const kvalues_iq4nl = array<i32, 16>( const kvalues_iq4nl = array<i32, 16>(
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
); );
#enddecl(IQ4_GRID) #endif

View File

@ -56,7 +56,9 @@ def expand_includes(shader, input_dir):
return include_pattern.sub(replacer, shader) return include_pattern.sub(replacer, shader)
def write_shader(shader_name, shader_code, output_dir, outfile): def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
shader_code = expand_includes(shader_code, input_dir)
if output_dir: if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out: with open(wgsl_filename, "w", encoding="utf-8") as f_out:
@ -74,7 +76,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
try: try:
variants = ast.literal_eval(extract_block(text, "VARIANTS")) variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError: except ValueError:
write_shader(shader_base_name, text, output_dir, outfile) write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else: else:
try: try:
decls_map = parse_decls(extract_block(text, "DECLS")) decls_map = parse_decls(extract_block(text, "DECLS"))
@ -123,7 +125,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else: else:
output_name = shader_base_name output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile) write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main(): def main():

View File

@ -1,222 +1,31 @@
#define(VARIANTS) enable f16;
#include "common_decls.tmpl"
[ #ifdef F32_VEC
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"BLOCK_SIZE": 4
},
"DECLS": ["F32_VEC"]
},
{
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F32"]
},
{
"REPLS": {
"TYPE" : "f16",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F16"]
},
{
"REPLS": {
"TYPE" : "i32",
"DST_TYPE": "i32",
"BLOCK_SIZE": 1
},
"DECLS": ["I32"]
},
{
"REPLS": {
"TYPE" : "q4_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"TYPE" : "q4_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"TYPE" : "q5_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"TYPE" : "q5_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"TYPE" : "q8_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"TYPE" : "q2_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"TYPE" : "q3_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"TYPE" : "q4_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"TYPE" : "q5_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"TYPE" : "q6_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"TYPE" : "iq2_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"TYPE" : "iq2_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"TYPE": "iq2_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"TYPE": "iq3_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"TYPE": "iq3_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"TYPE": "iq1_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"TYPE": "iq1_m",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"TYPE": "iq4_nl",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"TYPE": "iq4_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(F32_VEC)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
} }
#enddecl(F32_VEC) #endif
#decl(F32) #ifdef F32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset]; dst[dst_base + offset] = src[src_base + offset];
} }
#enddecl(F32) #endif
#decl(F16) #ifdef F16
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = f32(src[src_base + offset]); dst[dst_base + offset] = f32(src[src_base + offset]);
} }
#enddecl(F16) #endif
#decl(I32) #ifdef I32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset]; dst[dst_base + offset] = src[src_base + offset];
} }
#enddecl(I32) #endif
#decl(Q4_0) #ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_0 = src[src_base + offset]; let block_q4_0 = src[src_base + offset];
let d = f32(block_q4_0.d); let d = f32(block_q4_0.d);
@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q4_0) #endif
#decl(Q4_1) #ifdef Q4_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_1 = src[src_base + offset]; let block_q4_1 = src[src_base + offset];
let d = f32(block_q4_1.d); let d = f32(block_q4_1.d);
@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q4_1) #endif
#decl(Q5_0) #ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_0 = src[src_base + offset]; let block_q5_0 = src[src_base + offset];
let d = f32(block_q5_0.d); let d = f32(block_q5_0.d);
@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#endif
#enddecl(Q5_0) #ifdef Q5_1
#decl(Q5_1)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_1 = src[src_base + offset]; let block_q5_1 = src[src_base + offset];
let d = f32(block_q5_1.d); let d = f32(block_q5_1.d);
@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q5_1) #endif
#decl(Q8_0) #ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q8_0 = src[src_base + offset]; let block_q8_0 = src[src_base + offset];
let d = f32(block_q8_0.d); let d = f32(block_q8_0.d);
@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q8_0) #endif
#decl(Q2_K) #ifdef Q2_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q2_K) #endif
#decl(Q3_K) #ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q3_K) #endif
#decl(Q4_K) #ifdef Q4_K
// 8 blocks of 32 elements each // 8 blocks of 32 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q4_K) #endif
#decl(Q5_K) #ifdef Q5_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(Q5_K) #endif
#decl(Q6_K) #ifdef Q6_K
// 16 blocks of 16 elements each // 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
sc_b_idx += 8; sc_b_idx += 8;
} }
} }
#endif
#enddecl(Q6_K) #ifdef IQ2_XXS
#decl(IQ2_XXS)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(IQ2_XXS) #endif
#decl(IQ2_XS) #ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(IQ2_XS) #endif
#decl(IQ2_S) #ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#endif
#enddecl(IQ2_S) #ifdef IQ3_XXS
#decl(IQ3_XSS)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(IQ3_XSS) #endif
#decl(IQ3_S) #ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#enddecl(IQ3_S) #endif
#decl(IQ1_S) #ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#endif
#enddecl(IQ1_S) #ifdef IQ1_M
#decl(IQ1_M)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
} }
} }
} }
#endif
#enddecl(IQ1_M) #ifdef IQ4_NL
#decl(IQ4_NL)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i++; dst_i++;
} }
} }
#enddecl(IQ4_NL) #endif
#decl(IQ4_XS) #ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset]; let block = src[src_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i += 16; dst_i += 16;
} }
} }
#enddecl(IQ4_XS) #endif
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
@group(0) @binding(0) @group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>; var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1) @group(0) @binding(1)
var<storage, read_write> idx: array<i32>; var<storage, read_write> idx: array<i32>;
@group(0) @binding(2) @group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>; var<storage, read_write> dst: array<DST_TYPE>;
struct Params { struct Params {
offset_src: u32, // in elements offset_src: u32, // in elements
@ -842,8 +638,7 @@ struct Params {
@group(0) @binding(3) @group(0) @binding(3)
var<uniform> params: Params; var<uniform> params: Params;
override wg_size: u32; @compute @workgroup_size(WG_SIZE)
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) { fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) { if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return; return;
@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i); copy_elements(i_src_row, i_dst_row, i);
} }
} }
#end(SHADER)

View File

@ -1,195 +1,24 @@
#define(VARIANTS) enable f16;
[ #include "common_decls.tmpl"
{
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q8_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q2_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q3_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q6_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_m",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_nl",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#end(VARIANTS) #ifdef FLOAT
const BLOCK_SIZE = 1u;
#define(DECLS) #elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;
#decl(FLOAT) #elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif
#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
} }
#enddecl(FLOAT) #endif
#decl(Q4_0) #ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_0 = src0[src0_idx_base + offset]; let block_q4_0 = src0[src0_idx_base + offset];
let d = f32(block_q4_0.d); let d = f32(block_q4_0.d);
@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(Q4_0) #endif
#decl(Q4_1) #ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_1 = src0[src0_idx_base + offset]; let block_q4_1 = src0[src0_idx_base + offset];
let d = f32(block_q4_1.d); let d = f32(block_q4_1.d);
@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(Q4_1) #endif
#decl(Q5_0) #ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_0 = src0[src0_idx_base + offset]; let block_q5_0 = src0[src0_idx_base + offset];
let d = f32(block_q5_0.d); let d = f32(block_q5_0.d);
@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(Q5_0) #endif
#decl(Q5_1) #ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_1 = src0[src0_idx_base + offset]; let block_q5_1 = src0[src0_idx_base + offset];
let d = f32(block_q5_1.d); let d = f32(block_q5_1.d);
@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(Q5_1) #endif
#decl(Q8_0) #ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_0 = src0[src0_idx_base + offset]; let block_q8_0 = src0[src0_idx_base + offset];
let d = f32(block_q8_0.d); let d = f32(block_q8_0.d);
@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(Q8_0) #endif
#decl(Q8_1) #ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_1 = src0[src0_idx_base + offset]; let block_q8_1 = src0[src0_idx_base + offset];
let d = f32(block_q8_1.d); let d = f32(block_q8_1.d);
@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(Q8_1) #endif
#decl(Q2_K) #ifdef Q2_K
// 16 blocks of 16 elements each // 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(Q2_K) #ifdef Q3_K
#decl(Q3_K)
// 16 blocks of 16 elements each // 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(Q3_K) #ifdef Q4_K
#decl(Q4_K)
// 8 blocks of 32 elements each // 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(Q4_K) #ifdef Q5_K
#decl(Q5_K)
// 8 blocks of 32 elements each // 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(Q5_K) #ifdef Q6_K
#decl(Q6_K)
// 16 blocks of 16 elements each // 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(Q6_K) #ifdef IQ2_XXS
#decl(IQ2_XXS)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ2_XXS) #ifdef IQ2_XS
#decl(IQ2_XS)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ2_XS) #ifdef IQ2_S
#decl(IQ2_S)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#ifdef IQ3_XXS
#enddecl(IQ2_S)
#decl(IQ3_XSS)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ3_XSS) #ifdef IQ3_S
#decl(IQ3_S)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#enddecl(IQ3_S) #endif
#decl(IQ1_S) #ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ1_S)
#decl(IQ1_M) #ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ1_M) #ifdef IQ4_NL
#decl(IQ4_NL)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ4_NL) #ifdef IQ4_XS
#decl(IQ4_XS)
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset]; let block = src0[src0_idx_base + offset];
let d = f32(block.d); let d = f32(block.d);
@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
} }
return sum; return sum;
} }
#endif
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
struct MulMatParams { struct MulMatParams {
offset_src0: u32, // in elements/blocks offset_src0: u32, // in elements/blocks
@ -864,8 +672,8 @@ struct MulMatParams {
broadcast3: u32 broadcast3: u32
}; };
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns @group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams; @group(0) @binding(3) var<uniform> params: MulMatParams;
@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
var sum = 0.0; var sum = 0.0;
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i); sum += multiply_add(src0_idx_base, src1_idx_base, i);
} }
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
} }
#end(SHADER)

View File

@ -1,58 +1,65 @@
#decl(SHMEM_VEC) #ifdef VEC
#define VEC_SIZE 4
#define SHMEM_TYPE vec4<f16>
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
fn store_shmem(val: vec4<f16>, idx: u32) { fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x; shmem[idx] = val.x;
shmem[idx + 1] = val.y; shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z; shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w; shmem[idx + 3] = val.w;
} }
#enddecl(SHMEM_VEC) #endif
#ifdef SCALAR
#define VEC_SIZE 1
#define SHMEM_TYPE f16
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
#decl(SHMEM_SCALAR)
fn store_shmem(val: f16, idx: u32) { fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val; shmem[idx] = val;
} }
#enddecl(SHMEM_SCALAR) #endif
#decl(INIT_SRC0_SHMEM_FLOAT)
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_m = elem_idx / TILE_K; let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K; let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m; let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k; let global_k = k_outer + tile_k;
let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let src0_val = select( // taking a slight performance hit to avoid oob let src0_val = select( // taking a slight performance hit to avoid oob
{{SRC0_TYPE}}(0.0), SRC0_TYPE(0.0),
src0[src0_idx/{{VEC_SIZE}}], src0[src0_idx/VEC_SIZE],
global_m < params.m && global_k < params.k); global_m < params.m && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); store_shmem(SHMEM_TYPE(src0_val), elem_idx);
} }
} }
#endif
#enddecl(INIT_SRC0_SHMEM_FLOAT) #ifdef INIT_SRC1_SHMEM_FLOAT
#decl(INIT_SRC1_SHMEM)
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_n = elem_idx / TILE_K; let tile_n = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K; let tile_k = elem_idx % TILE_K;
let global_n = offset_n + tile_n; let global_n = offset_n + tile_n;
let global_k = k_outer + tile_k; let global_k = k_outer + tile_k;
let src1_idx = batch_offset + global_n * params.stride_11 + global_k; let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
let src1_val = select( let src1_val = select(
{{SRC1_TYPE}}(0.0), SRC1_TYPE(0.0),
src1[src1_idx/{{VEC_SIZE}}], src1[src1_idx/VEC_SIZE],
global_n < params.n && global_k < params.k); global_n < params.n && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
} }
} }
#endif
#enddecl(INIT_SRC1_SHMEM) #ifdef INIT_SRC0_SHMEM_Q4_0
#decl(INIT_SRC0_SHMEM_Q4_0)
const BLOCK_SIZE = 32u; const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE; override BLOCKS_K = TILE_K/BLOCK_SIZE;
@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
} }
} }
} }
#endif
#enddecl(INIT_SRC0_SHMEM_Q4_0)

View File

@ -1,115 +1,19 @@
#define(VARIANTS) enable f16;
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
#end(VARIANTS) #include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS) #ifdef VEC
#decl(VEC)
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> { fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
} }
#enddecl(VEC) #endif
#decl(SCALAR) #ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 { fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return f32(acc[tm][tn]); return f32(acc[tm][tn]);
} }
#enddecl(SCALAR) #endif
#end(DECLS)
#define(SHADER)
enable f16;
struct MulMatParams { struct MulMatParams {
offset_src0: u32, offset_src0: u32,
@ -130,14 +34,12 @@ struct MulMatParams {
broadcast3: u32 broadcast3: u32
}; };
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) @group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams; @group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
fn get_local_n(thread_id: u32) -> u32 { fn get_local_n(thread_id: u32) -> u32 {
return thread_id / WORKGROUP_SIZE_M; return thread_id / WORKGROUP_SIZE_M;
} }
@ -145,18 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
return thread_id % WORKGROUP_SIZE_M; return thread_id % WORKGROUP_SIZE_M;
} }
// TILE_M must be multiple of 4 for vec4 loads const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_M = {{WEBGPU_TILE_M}}u; const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_N = {{WEBGPU_TILE_N}}u; const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>; var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@ -233,15 +126,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
for (var tn = 0u; tn < TILE_N; tn++) { for (var tn = 0u; tn < TILE_N; tn++) {
let global_col = output_col_base + tn; let global_col = output_col_base + tn;
if (global_col < params.n) { if (global_col < params.n) {
for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
let global_row = output_row_base + tm; let global_row = output_row_base + tm;
if (global_row < params.m) { if (global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row; let dst_idx = dst_batch_offset + global_col * params.m + global_row;
dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
} }
} }
} }
} }
} }
#end(SHADER)

View File

@ -1,100 +1,12 @@
#define(VARIANTS) diagnostic(off, chromium.subgroup_matrix_uniformity);
[ enable f16;
{ enable subgroups;
"SHADER_SUFFIX": "f32_f32_vec", enable chromium_experimental_subgroup_matrix;
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
#end(VARIANTS) #include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS) #ifdef VEC
#decl(VEC)
fn store_dst(shmem_idx: u32, dst_idx: u32) { fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = vec4<f32>( dst[dst_idx] = vec4<f32>(
f32(shmem[shmem_idx]), f32(shmem[shmem_idx]),
@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
f32(shmem[shmem_idx + 3]) f32(shmem[shmem_idx + 3])
); );
} }
#enddecl(VEC) #endif
#decl(SCALAR) #ifdef SCALAR
fn store_dst(shmem_idx: u32, dst_idx: u32) { fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = f32(shmem[shmem_idx]); dst[dst_idx] = f32(shmem[shmem_idx]);
} }
#enddecl(SCALAR) #endif
#end(DECLS)
#define(SHADER)
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
struct MulMatParams { struct MulMatParams {
offset_src0: u32, offset_src0: u32,
@ -138,36 +42,19 @@ struct MulMatParams {
broadcast3: u32 broadcast3: u32
}; };
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns // SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams; @group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
// Note: These are string interpolated at build time, cannot use override constants due to limitations in
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
const TILE_K = {{WEBGPU_TILE_K}}u;
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let local_row = idx % WG_TILE_STRIDE; let local_row = idx % WG_TILE_STRIDE;
let local_col = idx / WG_TILE_STRIDE; let local_col = idx / WG_TILE_STRIDE;
@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_col < params.n && global_row < params.m) { if (global_col < params.n && global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row; let dst_idx = dst_batch_offset + global_col * params.m + global_row;
store_dst(idx, dst_idx/{{VEC_SIZE}}); store_dst(idx, dst_idx/VEC_SIZE);
} }
} }
} }
#end(SHADER)

View File

@ -1,84 +1,17 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE": "vec4<f32>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE": "f32",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
}
]
#end(VARIANTS) enable f16;
#define(DECLS) #include "common_decls.tmpl"
#decl(VEC) #ifdef VEC
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); #define VEC_SIZE 4
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(dot(SRC1_TYPE(src0_val), src1_val));
} }
fn store_val(group_base: u32) -> vec4<f32> { fn store_val(group_base: u32) -> vec4<f32> {
@ -87,33 +20,37 @@ fn store_val(group_base: u32) -> vec4<f32> {
partial_sums[group_base + THREADS_PER_OUTPUT * 2], partial_sums[group_base + THREADS_PER_OUTPUT * 2],
partial_sums[group_base + THREADS_PER_OUTPUT * 3]); partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
} }
#enddecl(VEC) #endif
#decl(SCALAR) #ifdef SCALAR
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
#define VEC_SIZE 1
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(src0_val) * f32(src1_val); return f32(src0_val) * f32(src1_val);
} }
fn store_val(group_base: u32) -> f32 { fn store_val(group_base: u32) -> f32 {
return partial_sums[group_base]; return partial_sums[group_base];
} }
#enddecl(SCALAR) #endif
#decl(MUL_ACC_FLOAT)
#ifdef MUL_ACC_FLOAT
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0; var local_sum = 0.0;
for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
let b = shared_vector[i / {{VEC_SIZE}}]; let b = shared_vector[i / VEC_SIZE];
local_sum += inner_dot(a, b); local_sum += inner_dot(a, b);
} }
return local_sum; return local_sum;
} }
#endif
#enddecl(MUL_ACC_FLOAT) #ifdef MUL_ACC_Q4_0
#decl(MUL_ACC_Q4_0)
const BLOCK_SIZE = 32; const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread const NQ = 16u; // number of weights per thread
@ -145,15 +82,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
} }
return local_sum; return local_sum;
} }
#endif
#enddecl(MUL_ACC_Q4_0)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
struct MulMatParams { struct MulMatParams {
offset_src0: u32, offset_src0: u32,
@ -174,22 +103,20 @@ struct MulMatParams {
broadcast3: u32 broadcast3: u32
}; };
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) // SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed) @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams; @group(0) @binding(3) var<uniform> params: MulMatParams;
override WORKGROUP_SIZE: u32; const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
override TILE_K: u32;
override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
// Shared memory for collaborative loading and reduction // Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction var<workgroup> partial_sums: array<f32, WG_SIZE>; // For reduction
@compute @workgroup_size(WORKGROUP_SIZE) @compute @workgroup_size(WG_SIZE)
fn main( fn main(
@builtin(local_invocation_id) local_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(workgroup_id) wg_id: vec3<u32>,
@ -232,8 +159,8 @@ fn main(
let tile_size = min(TILE_K, params.k - k_tile); let tile_size = min(TILE_K, params.k - k_tile);
// Cooperatively load vector tile into shared memory (all threads) // Cooperatively load vector tile into shared memory (all threads)
for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
} }
workgroupBarrier(); workgroupBarrier();
@ -250,7 +177,7 @@ fn main(
workgroupBarrier(); workgroupBarrier();
let group_base = thread_group * THREADS_PER_OUTPUT; let group_base = thread_group * THREADS_PER_OUTPUT;
let thread_base = group_base + thread_in_group; let thread_base = group_base + thread_in_group;
var offset = THREADS_PER_OUTPUT / 2; var offset: u32 = THREADS_PER_OUTPUT / 2;
while (offset > 0) { while (offset > 0) {
if (thread_in_group < offset) { if (thread_in_group < offset) {
partial_sums[thread_base] += partial_sums[thread_base + offset]; partial_sums[thread_base] += partial_sums[thread_base + offset];
@ -260,8 +187,8 @@ fn main(
} }
// Store back to global memory // Store back to global memory
if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); dst[dst_idx / VEC_SIZE] = store_val(group_base);
} }
} }
#end(SHADER)

View File

@ -1,21 +1,11 @@
#define(VARIANTS) #ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
[ fn store_scale(val: f32, offset: u32) {
{ src[offset] = val;
"SHADER_NAME": "scale_f32", }
"DECLS": ["NOT_INPLACE"] #else
},
{
"SHADER_NAME": "scale_f32_inplace",
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
@group(0) @binding(1) @group(0) @binding(1)
var<storage, read_write> dst: array<f32>; var<storage, read_write> dst: array<f32>;
@ -25,20 +15,7 @@ var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) { fn store_scale(val: f32, offset: u32) {
dst[offset] = val; dst[offset] = val;
} }
#enddecl(NOT_INPLACE) #endif
#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
fn store_scale(val: f32, offset: u32) {
src[offset] = val;
}
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
struct Params { struct Params {
offset_src: u32, offset_src: u32,
@ -65,10 +42,7 @@ struct Params {
@group(0) @binding(0) @group(0) @binding(0)
var<storage, read_write> src: array<f32>; var<storage, read_write> src: array<f32>;
DECLS @compute @workgroup_size(WG_SIZE)
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) { fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) { if (gid.x >= params.ne) {
return; return;
@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
store_scale(src[i_src] * params.scale + params.bias, i_dst); store_scale(src[i_src] * params.scale + params.bias, i_dst);
} }
#end(SHADER)