vulkan: faster q6_k matmul (#17813)
* q6_k faster mul mat * 8 values * fix comment * switch to two at a time * start ci for .glsl files
This commit is contained in:
parent
77ad8542bd
commit
d15d177f43
|
|
@ -20,7 +20,8 @@ on:
|
||||||
'**/*.swift',
|
'**/*.swift',
|
||||||
'**/*.m',
|
'**/*.m',
|
||||||
'**/*.metal',
|
'**/*.metal',
|
||||||
'**/*.comp'
|
'**/*.comp',
|
||||||
|
'**/*.glsl'
|
||||||
]
|
]
|
||||||
|
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
@ -40,7 +41,8 @@ on:
|
||||||
'**/*.swift',
|
'**/*.swift',
|
||||||
'**/*.m',
|
'**/*.m',
|
||||||
'**/*.metal',
|
'**/*.metal',
|
||||||
'**/*.comp'
|
'**/*.comp',
|
||||||
|
'**/*.glsl'
|
||||||
]
|
]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@
|
||||||
/out/
|
/out/
|
||||||
/tmp/
|
/tmp/
|
||||||
/autogen-*.md
|
/autogen-*.md
|
||||||
|
/common/build-info.cpp
|
||||||
|
|
||||||
# Deprecated
|
# Deprecated
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -244,17 +244,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||||
const uint iqs = idx % 128; // 0..127
|
const uint iqs = idx % 128; // 0..127
|
||||||
|
|
||||||
const uint n = iqs / 64; // 0,1
|
const uint n = iqs / 64; // 0,1
|
||||||
const uint b = (iqs % 64) / 32; // 0,1
|
const uint b = ((iqs % 64) / 32) * 4; // 0,4
|
||||||
const uint is_b = (iqs % 16) / 8; // 0,1
|
const uint is_b = (iqs % 16) / 8; // 0,1
|
||||||
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
||||||
const uint is = 8 * n + qhshift + is_b; // 0..15
|
const uint is = 8 * n + qhshift + is_b; // 0..15
|
||||||
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
|
const uint qsi = n * 32 + (iqs % 32); // 0..63
|
||||||
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
|
const uint qhi = n * 16 + (iqs % 16); // 0..31
|
||||||
|
|
||||||
const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
|
const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
|
||||||
|
|
||||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
|
const uint ql = (uint(data_a_packed16[ib].ql[qsi]) >> b) & 0x0F0F;
|
||||||
dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303;
|
||||||
|
const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale;
|
||||||
|
|
||||||
|
buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y);
|
||||||
#elif defined(DATA_A_IQ1_S)
|
#elif defined(DATA_A_IQ1_S)
|
||||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue