ggml : block interleaving support for Q4_K quantization for AArch64

* new quant type: block_q4_Kx4 with offline repack impl
* new quantize path: NEON impl for ggml_quantize_mat_q8_K_4x8
* new gemv kernel: ggml_gemv_q4_K_4x8_q8_K based on dotprod
* new gemm kernel: ggml_gemm_q4_K_4x8_q8_K based on i8mm
* performance boost for both S_PP and S_TG

---------

Co-authored-by: yuanjia111 <yuan.jia@sanechips.com.cn>
parent 0a0bba05e8
commit 2f3dfe2e74
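For orientation before the diff: block_q4_Kx4 packs the Q4_K data of four consecutive rows of the weight matrix (the four output columns a kernel produces at once) into one interleaved super-block. The sketch below is inferred from how the kernels further down index q4->d, q4->dmin, q4->scales and q4->qs; the field names, sizes and ordering are assumptions for illustration, not the committed definition.

// Hypothetical sketch of the repacked super-block, inferred from the access
// pattern of ggml_gemv_q4_K_4x8_q8_K / ggml_gemm_q4_K_4x8_q8_K below.
#include <stdint.h>

#define QK_K 256                 // super-block size, as elsewhere in ggml
typedef uint16_t ggml_half;      // IEEE f16 storage

typedef struct {
    ggml_half d[4];              // super-block scale of interleaved row 0..3
    ggml_half dmin[4];           // super-block min   of interleaved row 0..3
    uint8_t   scales[8 * 8];     // per 32-value sub-block: scales0..3 mins0..3 (read as int8 by the kernels)
    uint8_t   qs[QK_K * 4 / 2];  // 4-bit quants of the 4 rows, 64 interleaved bytes per sub-block
} block_q4_Kx4;

With this layout one 16-byte load of qs carries eight 4-bit values from each of two interleaved rows in its low nibbles and the next eight in its high nibbles, which is the access pattern the B0/B1/B2/B3 comments in the kernels below refer to.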
@@ -38,6 +38,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K

@@ -48,6 +49,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K

@@ -58,7 +60,6 @@
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0

@@ -69,12 +70,14 @@
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0

@@ -94,6 +97,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K

@@ -104,6 +108,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K

@@ -126,6 +131,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K

@@ -136,6 +142,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K

@@ -165,6 +172,7 @@
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K

@@ -174,6 +182,7 @@
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K

@@ -202,6 +211,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K

@@ -212,6 +222,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K

@@ -242,6 +253,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K

@@ -252,6 +264,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K

@@ -429,6 +429,121 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);
    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
        const block_q4_Kx4 * GGML_RESTRICT q4 = (const block_q4_Kx4 *) vx;
        const uint8x16_t m4b = vdupq_n_u8(0xf);
        for (int c = 0; c < nc; c += ncols_interleaved) {
            const block_q8_K * GGML_RESTRICT q8 = (const block_q8_K *) vy;
            float32x4_t res = vdupq_n_f32(0);
            for (int i = 0; i < nb; i++) {
                float32x4_t q4_d    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->d));    // d0 d1 d2 d3
                float32x4_t q8_d    = vdupq_n_f32(q8->d);
                float32x4_t g_d     = vmulq_f32(q4_d, q8_d);
                float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->dmin)); // dmin0 dmin1 dmin2 dmin3
                float32x4_t g_dmin  = vmulq_f32(q4_dmin, q8_d);
                const uint8_t * GGML_RESTRICT q4_ptr = q4->qs;
                const int8_t  * GGML_RESTRICT q8_ptr = q8->qs;
                int32x4_t prod = vdupq_n_s32(0);
                const int16x8_t q8_sums = vpaddq_s16(vld1q_s16(q8->bsums), vld1q_s16(q8->bsums + 8));
                // when using vgetq_lane_s16, its index must be a constant, which cannot be used in a loop, so use vst1q_s16 instead
                int16_t tmp_arry[8];
                vst1q_s16(tmp_arry, q8_sums);
                for (int j = 0; j < QK_K / 32; ++j) {
                    int32x4_t sum0 = vdupq_n_s32(0);
                    int32x4_t sum1 = vdupq_n_s32(0);
                    // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3
                    int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4->scales + 8 * j));
                    prod = vmlal_s16(prod, vdup_n_s16(tmp_arry[j]), vget_high_s16(scales_mins));
                    uint8x16_t q4_0 = vld1q_u8((const uint8_t *) q4_ptr);
                    uint8x16_t q4_1 = vld1q_u8((const uint8_t *) q4_ptr + 16);
                    uint8x16_t q4_2 = vld1q_u8((const uint8_t *) q4_ptr + 32);
                    uint8x16_t q4_3 = vld1q_u8((const uint8_t *) q4_ptr + 48);
                    q4_ptr += 64;
                    int8x16_t q8_0 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr);     // 8 x 8-bit values
                    int8x16_t q8_1 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 1);
                    int8x16_t q8_2 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 2);
                    int8x16_t q8_3 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 3);
                    q8_ptr += 32;

                    /* low bits
                    (1) sum0
                        b0_000 b0_001 b0_002 b0_003 b0_004 b0_005 b0_006 b0_007 | b1_000 b1_001 b1_002 b1_003 b1_004 b1_005 b1_006 b1_007
                      *  a0     a1     a2     a3     a4     a5     a6     a7    |  a0     a1     a2     a3     a4     a5     a6     a7
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    (2) sum1
                        b2_000 b2_001 b2_002 b2_003 b2_004 b2_005 b2_006 b2_007 | b3_000 b3_001 b3_002 b3_003 b3_004 b3_005 b3_006 b3_007
                      *  a0     a1     a2     a3     a4     a5     a6     a7    |  a0     a1     a2     a3     a4     a5     a6     a7
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    (3) sum0
                        b0_008 b0_009 b0_010 b0_011 b0_012 b0_013 b0_014 b0_015 | b1_008 b1_009 b1_010 b1_011 b1_012 b1_013 b1_014 b1_015
                      *  a8     a9     a10    a11    a12    a13    a14    a15   |  a8     a9     a10    a11    a12    a13    a14    a15
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    (4) sum1
                        b2_008 b2_009 b2_010 b2_011 b2_012 b2_013 b2_014 b2_015 | b3_008 b3_009 b3_010 b3_011 b3_012 b3_013 b3_014 b3_015
                      *  a8     a9     a10    a11    a12    a13    a14    a15   |  a8     a9     a10    a11    a12    a13    a14    a15
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    */
                    sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_0, m4b)), q8_0);
                    sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_1, m4b)), q8_0);
                    sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_2, m4b)), q8_1);
                    sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_3, m4b)), q8_1);

                    /* high bits
                    (1) sum0
                        b0_016 b0_017 b0_018 b0_019 b0_020 b0_021 b0_022 b0_023 | b1_016 b1_017 b1_018 b1_019 b1_020 b1_021 b1_022 b1_023
                      *  a16    a17    a18    a19    a20    a21    a22    a23   |  a16    a17    a18    a19    a20    a21    a22    a23
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    (2) sum1
                        b2_016 b2_017 b2_018 b2_019 b2_020 b2_021 b2_022 b2_023 | b3_016 b3_017 b3_018 b3_019 b3_020 b3_021 b3_022 b3_023
                      *  a16    a17    a18    a19    a20    a21    a22    a23   |  a16    a17    a18    a19    a20    a21    a22    a23
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    (3) sum0
                        b0_024 b0_025 b0_026 b0_027 b0_028 b0_029 b0_030 b0_031 | b1_024 b1_025 b1_026 b1_027 b1_028 b1_029 b1_030 b1_031
                      *  a24    a25    a26    a27    a28    a29    a30    a31   |  a24    a25    a26    a27    a28    a29    a30    a31
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    (4) sum1
                        b2_024 b2_025 b2_026 b2_027 b2_028 b2_029 b2_030 b2_031 | b3_024 b3_025 b3_026 b3_027 b3_028 b3_029 b3_030 b3_031
                      *  a24    a25    a26    a27    a28    a29    a30    a31   |  a24    a25    a26    a27    a28    a29    a30    a31
                        |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
                    */
                    sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_0, 4)), q8_2);
                    sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_1, 4)), q8_2);
                    sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_2, 4)), q8_3);
                    sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_3, 4)), q8_3);
                    float32x4_t sumf = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), vpaddq_s32(sum0, sum1)));
                    res = vfmaq_f32(res, g_d, sumf);
                }
                res -= vmulq_f32(g_dmin, vcvtq_f32_s32(prod));
                q4++;
                q8++;
            }
            vst1q_f32(s, res);
            s += ncols_interleaved;
        }
        return;
    }
#else
    // todo, c implementation
#endif
}

void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;

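The NEON kernel above vectorizes the standard Q4_K x Q8_K dot product, carried out per super-block as d * (sum_j scale_j * dot(q4_j, q8_j)) - dmin * (sum_j min_j * bsum_j), for all four interleaved columns at once. Below is a scalar sketch of that per-super-block math for a single column; it assumes the unpacked int8 scales/mins of the block_q4_Kx4 sketch earlier and plain arrays for already de-interleaved quants, so it illustrates the arithmetic only, not the memory layout.

#include <stdint.h>

// Scalar sketch of the per-super-block accumulation that
// ggml_gemv_q4_K_4x8_q8_K performs with vdotq_s32, for one column.
// Inputs are assumed already de-interleaved:
//   q4[256]            - 4-bit weight quants, one per byte (0..15)
//   q8[256]            - int8 activation quants
//   bsums[8]           - per-32-value sums of q8 (block_q8_K stores per-16 sums; the kernel pairwise-adds them)
//   scales[8], mins[8] - unpacked 6-bit sub-block scales/mins
//   d, dmin            - super-block scale/min of the weights; q8_d - activation scale
static float q4_K_dot_superblock(const uint8_t q4[256], const int8_t q8[256],
                                 const int16_t bsums[8],
                                 const uint8_t scales[8], const uint8_t mins[8],
                                 float d, float dmin, float q8_d) {
    int32_t sum_qs   = 0;   // sum over sub-blocks of scale_j * dot(q4_j, q8_j)
    int32_t sum_mins = 0;   // sum over sub-blocks of min_j * bsum_j
    for (int j = 0; j < 8; j++) {
        int32_t dot = 0;
        for (int l = 0; l < 32; l++) {
            dot += (int32_t) q4[32*j + l] * q8[32*j + l];
        }
        sum_qs   += (int32_t) scales[j] * dot;
        sum_mins += (int32_t) mins[j]   * bsums[j];
    }
    // same combination as res = vfmaq_f32(res, g_d, sumf) followed by the g_dmin * prod subtraction
    return d * q8_d * (float) sum_qs - dmin * q8_d * (float) sum_mins;
}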
@@ -2253,6 +2368,298 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 8; // used by the C implementation

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr); // rows
    UNUSED(nc); // columns
    UNUSED(nb); // blocks
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__ARM_FEATURE_MATMUL_INT8)
    const block_q8_Kx4 * GGML_RESTRICT q8_ptr_start = (const block_q8_Kx4 *) vy;
    const block_q4_Kx4 * GGML_RESTRICT q4_ptr_start = (const block_q4_Kx4 *) vx;

    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    float32x4_t zeros = vdupq_n_f32(0.0f);
    int anr = nr - nr % 16;
    int row = 0;
    // Row loop
    for (; row < anr / 4; row += 4) {
        const block_q8_Kx4 * GGML_RESTRICT q8_ptrs[4];
        q8_ptrs[0] = q8_ptr_start + (row * nb);
        for (int i = 0; i < 3; ++i) {
            q8_ptrs[i + 1] = q8_ptrs[i] + nb;
        }
        // Column loop
        for (int col = 0; col < nc / ncols_interleaved; col++) {
            const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb);
            // init output
            float32x4_t res[16]; // final result
            for (int i = 0; i < 16; i++) {
                res[i] = zeros;
            }
            // Block loop
            for (int64_t b = 0; b < nb; b++) {
                float32x4_t g_d[16];
                float32x4_t g_dmin[16];
                int16x8_t q8_bsums[16];
                int32x4_t prod[16]; // store bsums*mins
                for (int i = 0; i < 16; i++) {
                    g_d[i] = zeros;
                    g_dmin[i] = zeros;
                    q8_bsums[i] = vdupq_n_s16(0);
                    prod[i] = vdupq_n_s32(0);
                }
                // Get global d and dmin
                float32x4_t q4_d    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));    // col0 col1 col2 col3
                float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3
                int16_t tmp_q8_bsums_array[16][8];
                for (int iter = 0; iter < 4; iter++) {
                    // Process four rows as one group
                    for (int in = 0; in < 4; in++) {
                        float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptrs[iter][b].d[in]);
                        g_d[in + 4 * iter]    = vmulq_f32(q4_d, scalar_q8_d);
                        g_dmin[in + 4 * iter] = vmulq_f32(q4_dmin, scalar_q8_d);
                        // The 16 bsums of each row are merged into 8. No loop unrolling is performed here
                        q8_bsums[in + 4 * iter] = vpaddq_s16(vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in), vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in + 8));
                        vst1q_s16(tmp_q8_bsums_array[in + 4 * iter], q8_bsums[in + 4 * iter]);
                    }
                }
                // The 256 elements in the super-block are processed in 8 steps
                for (int sb = 0; sb < QK_K / 32; sb++) {
                    int32x4_t acc_rows[16]; // accumulated qs products
                    int32x4_t sum[16];      // qs products after rearranging
                    for (int i = 0; i < 16; i++) {
                        acc_rows[i] = vdupq_n_s32(0);
                        sum[i] = vdupq_n_s32(0);
                    }
                    // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3
                    int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb));
                    uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64);
                    uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64);
                    uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64);
                    uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64);

                    int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7)
                    int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7)
                    int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15)
                    int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15)

                    int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23)
                    int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23)
                    int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31)
                    int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31)

                    // The 16 rows of the left matrix are processed in four groups of four
                    for (int iter = 0; iter < 4; iter++) {
                        // Direct loop unrolling
                        prod[0 + 4 * iter] = vmlal_s16(prod[0 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[0 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter):   bsums*mins(0-3)
                        prod[1 + 4 * iter] = vmlal_s16(prod[1 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[1 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+1): bsums*mins(0-3)
                        prod[2 + 4 * iter] = vmlal_s16(prod[2 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[2 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+2): bsums*mins(0-3)
                        prod[3 + 4 * iter] = vmlal_s16(prod[3 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[3 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+3): bsums*mins(0-3)

                        int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 128 * sb);      // A0(0-7) A1(0-7)
                        int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7)

                        acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7)
                        acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7)
                        acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7)
                        acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7)

                        int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15)
                        int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15)

                        acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15)
                        acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15)
                        acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15)
                        acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15)

                        int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23)
                        int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23)

                        acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23)
                        acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23)
                        acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23)
                        acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23)

                        int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 96 + 128 * sb);  // A0(24-31) A1(24-31)
                        int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31)

                        acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31)
                        acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31)
                        acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31)
                        acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31)

                        // rearranging vectors
                        sum[0 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[0 + 4 * iter]), vget_low_s32(acc_rows[1 + 4 * iter]));
                        sum[1 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[0 + 4 * iter]), vget_high_s32(acc_rows[1 + 4 * iter]));
                        sum[2 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[2 + 4 * iter]), vget_low_s32(acc_rows[3 + 4 * iter]));
                        sum[3 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[2 + 4 * iter]), vget_high_s32(acc_rows[3 + 4 * iter]));

                        float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0 + 4 * iter])); // scales * qs
                        float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1 + 4 * iter]));
                        float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2 + 4 * iter]));
                        float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3 + 4 * iter]));

                        res[0 + 4 * iter] = vfmaq_f32(res[0 + 4 * iter], g_d[0 + 4 * iter], sumf_0);
                        res[1 + 4 * iter] = vfmaq_f32(res[1 + 4 * iter], g_d[1 + 4 * iter], sumf_1);
                        res[2 + 4 * iter] = vfmaq_f32(res[2 + 4 * iter], g_d[2 + 4 * iter], sumf_2);
                        res[3 + 4 * iter] = vfmaq_f32(res[3 + 4 * iter], g_d[3 + 4 * iter], sumf_3);
                    }
                }
                for (int iter = 0; iter < 4; iter++) {
                    res[0 + 4 * iter] -= vmulq_f32(g_dmin[0 + 4 * iter], vcvtq_f32_s32(prod[0 + 4 * iter]));
                    res[1 + 4 * iter] -= vmulq_f32(g_dmin[1 + 4 * iter], vcvtq_f32_s32(prod[1 + 4 * iter]));
                    res[2 + 4 * iter] -= vmulq_f32(g_dmin[2 + 4 * iter], vcvtq_f32_s32(prod[2 + 4 * iter]));
                    res[3 + 4 * iter] -= vmulq_f32(g_dmin[3 + 4 * iter], vcvtq_f32_s32(prod[3 + 4 * iter]));
                }
            }
            // store result
            for (int i = 0; i < 16; i++) {
                vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]);
            }
        }
    }
    // Handle the tail of fewer than 16 rows
    for (; row < nr / 4; row++) {
        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = q8_ptr_start + (row * nb);
        // Column loop
        for (int col = 0; col < nc / ncols_interleaved; col++) {
            const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb);
            // init output
            float32x4_t res[4];
            for (int i = 0; i < 4; i++) {
                res[i] = zeros;
            }
            // Block loop
            for (int64_t b = 0; b < nb; b++) {
                float32x4_t g_d[4];
                float32x4_t g_dmin[4];
                int16x8_t q8_bsums[4];
                int32x4_t prod[4];
                for (int i = 0; i < 4; i++) {
                    g_d[i] = zeros;
                    g_dmin[i] = zeros;
                    q8_bsums[i] = vdupq_n_s16(0);
                    prod[i] = vdupq_n_s32(0);
                }
                float32x4_t q4_d    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));    // col0 col1 col2 col3
                float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3
                int16_t tmp_q8_bsums_array[4][8];
                for (int in = 0; in < 4; in++) {
                    float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptr[b].d[in]);
                    g_d[in]    = vmulq_f32(q4_d, scalar_q8_d);
                    g_dmin[in] = vmulq_f32(q4_dmin, scalar_q8_d);
                    q8_bsums[in] = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * in), vld1q_s16(q8_ptr[b].bsums + 16 * in + 8));
                    vst1q_s16(tmp_q8_bsums_array[in], q8_bsums[in]);
                }
                for (int sb = 0; sb < QK_K / 32; sb++) {
                    int32x4_t acc_rows[4];
                    int32x4_t sum[4];
                    for (int i = 0; i < 4; i++) {
                        acc_rows[i] = vdupq_n_s32(0);
                        sum[i] = vdupq_n_s32(0);
                    }
                    // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3
                    int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb));
                    uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64);
                    uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64);
                    uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64);
                    uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64);

                    int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7)
                    int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7)
                    int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15)
                    int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15)

                    prod[0] = vmlal_s16(prod[0], vdup_n_s16(tmp_q8_bsums_array[0][sb]), vget_high_s16(scales_mins)); // row 0: bsums*mins(0-3)
                    prod[1] = vmlal_s16(prod[1], vdup_n_s16(tmp_q8_bsums_array[1][sb]), vget_high_s16(scales_mins)); // row 1: bsums*mins(0-3)
                    prod[2] = vmlal_s16(prod[2], vdup_n_s16(tmp_q8_bsums_array[2][sb]), vget_high_s16(scales_mins)); // row 2: bsums*mins(0-3)
                    prod[3] = vmlal_s16(prod[3], vdup_n_s16(tmp_q8_bsums_array[3][sb]), vget_high_s16(scales_mins)); // row 3: bsums*mins(0-3)

                    int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 128 * sb);      // A0(0-7) A1(0-7)
                    int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7)

                    acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7)
                    acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7)
                    acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7)
                    acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7)

                    int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15)
                    int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15)

                    acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15)
                    acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15)
                    acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15)
                    acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15)

                    int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23)
                    int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23)
                    int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31)
                    int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31)

                    int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23)
                    int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23)

                    acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23)
                    acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23)
                    acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23)
                    acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23)

                    int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 96 + 128 * sb);  // A0(24-31) A1(24-31)
                    int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31)

                    acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31)
                    acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31)
                    acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31)
                    acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31)

                    // rearranging vectors
                    sum[0] = vcombine_s32(vget_low_s32(acc_rows[0]), vget_low_s32(acc_rows[1]));
                    sum[1] = vcombine_s32(vget_high_s32(acc_rows[0]), vget_high_s32(acc_rows[1]));
                    sum[2] = vcombine_s32(vget_low_s32(acc_rows[2]), vget_low_s32(acc_rows[3]));
                    sum[3] = vcombine_s32(vget_high_s32(acc_rows[2]), vget_high_s32(acc_rows[3]));

                    float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0])); // scales * qs
                    float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1]));
                    float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2]));
                    float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3]));

                    res[0] = vfmaq_f32(res[0], g_d[0], sumf_0);
                    res[1] = vfmaq_f32(res[1], g_d[1], sumf_1);
                    res[2] = vfmaq_f32(res[2], g_d[2], sumf_2);
                    res[3] = vfmaq_f32(res[3], g_d[3], sumf_3);
                }
                res[0] -= vmulq_f32(g_dmin[0], vcvtq_f32_s32(prod[0]));
                res[1] -= vmulq_f32(g_dmin[1], vcvtq_f32_s32(prod[1]));
                res[2] -= vmulq_f32(g_dmin[2], vcvtq_f32_s32(prod[2]));
                res[3] -= vmulq_f32(g_dmin[3], vcvtq_f32_s32(prod[3]));
            }
            // store result
            for (int i = 0; i < 4; i++) {
                vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]);
            }
        }
    }
    return;
#else
    // todo, c implementation
#endif
}

void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;

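The A*B comments in the GEMM above rely on the 2x2 output tile produced by each vmmlaq_s32 (SMMLA): the two 16-byte operands are treated as two rows of eight int8 values each, and the instruction accumulates the dot product of every A-row with every B-row into a 2x2 int32 tile, which is why one call yields A0*B0, A0*B1, A1*B0, A1*B1 for one 8-element slice. A scalar reference of that tile shape, written as a plain function rather than the intrinsic, is sketched below for illustration.

#include <stdint.h>

// Scalar model of the 2x2 int32 tile that vmmlaq_s32 (SMMLA) accumulates:
// a and b each hold two rows of eight int8 values; the result adds
// dot(a_row_i, b_row_j) into acc[i][j], mirroring the
// "A0*B0 A0*B1 A1*B0 A1*B1" comments in the kernel above.
static void smmla_2x2_ref(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8]) {
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            int32_t dot = 0;
            for (int l = 0; l < 8; l++) {
                dot += (int32_t) a[i][l] * (int32_t) b[j][l];
            }
            acc[i][j] += dot;
        }
    }
}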
@@ -287,230 +287,6 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
#endif
}

void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK_K == 256);
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;

#if defined(__AVX2__)
    float iscale[4];
    __m256 srcv[4][32];
    __m256 iscale_vec[4];

    for (int i = 0; i < nb; i++) {
        for (int row_iter = 0; row_iter < 4; row_iter++) {
            // Load elements into 4 AVX vectors
            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );

            // Compute max(abs(e)) for the block
            const __m256 signBit = _mm256_set1_ps( -0.0f );
            __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
            __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
            __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
            __m256 abs3 = _mm256_andnot_ps( signBit, v3 );

            __m256 maxAbs = _mm256_max_ps( abs0, abs1 );
            maxAbs = _mm256_max_ps( maxAbs, abs2 );
            maxAbs = _mm256_max_ps( maxAbs, abs3 );

            __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
            __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
            __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
            __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );

            __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));

            srcv[row_iter][0] = v0;
            srcv[row_iter][1] = v1;
            srcv[row_iter][2] = v2;
            srcv[row_iter][3] = v3;

            for (int sb = 1; sb < 8; sb++) {
                // Temporarily stores absolute quant values
                __m256 tempAbs = maxAbs;

                // Load elements into 4 AVX vectors
                __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
                __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
                __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
                __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );

                // Compute max(abs(e)) for the block
                __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
                __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
                __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
                __m256 abs3 = _mm256_andnot_ps( signBit, v3 );

                maxAbs = _mm256_max_ps( maxAbs, abs0 );
                maxAbs = _mm256_max_ps( maxAbs, abs1 );
                maxAbs = _mm256_max_ps( maxAbs, abs2 );
                maxAbs = _mm256_max_ps( maxAbs, abs3 );

                __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
                maskAbs = _mm256_and_ps( maskAbs, mask_prev );

                mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
                mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
                mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
                mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );

                __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
                maskAbs = _mm256_or_ps(maskAbs, mask_curr);

                srcv[row_iter][sb * 4] = v0;
                srcv[row_iter][sb * 4 + 1] = v1;
                srcv[row_iter][sb * 4 + 2] = v2;
                srcv[row_iter][sb * 4 + 3] = v3;
            }

            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
            const float maxScalar = _mm_cvtss_f32( max4 );

            __m256 maxScalarVec = _mm256_set1_ps(maxScalar);

            __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
            __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);

            const int mask = _mm256_movemask_ps(finalMask);
            iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;

            if (mask) {
                iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar : 0.0f;
            }

            y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
            iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
        }

        __m256i quants_interleaved[32];
        for (int j = 0; j < 32; j++) {
            // Apply the multiplier
            __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
            __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
            __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
            __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);

            // Round to nearest integer
            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );

            // Convert floats to integers
            __m256i i0 = _mm256_cvtps_epi32( v0 );
            __m256i i1 = _mm256_cvtps_epi32( v1 );
            __m256i i2 = _mm256_cvtps_epi32( v2 );
            __m256i i3 = _mm256_cvtps_epi32( v3 );

            // Convert int32 to int16
            i0 = _mm256_packs_epi32( i0, i1 );
            i2 = _mm256_packs_epi32( i2, i3 );
            // Convert int16 to int8
            i0 = _mm256_packs_epi16( i0, i2 );

            // Permute and store the quantized weights in the required order after the pack instruction
            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
            i0 = _mm256_permutevar8x32_epi32( i0, perm );

            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
            quants_interleaved[j] = i0;
        }

        // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation
        __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
        shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
        __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
        shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
        __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
        shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);

        for (int k = 0; k < 4; k++) {
            // Quants from four different sub blocks are taken
            __m256i q0 = quants_interleaved[k * 8 + 0];
            __m256i q1 = quants_interleaved[k * 8 + 1];
            __m256i q2 = quants_interleaved[k * 8 + 2];
            __m256i q3 = quants_interleaved[k * 8 + 3];
            __m256i q4 = quants_interleaved[k * 8 + 4];
            __m256i q5 = quants_interleaved[k * 8 + 5];
            __m256i q6 = quants_interleaved[k * 8 + 6];
            __m256i q7 = quants_interleaved[k * 8 + 7];

            // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
            __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
            __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
            __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
            __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);

            __m256i one = _mm256_set1_epi8(1);
            __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);

            for (int l = 0; l < 3; l++) {
                // Quants value shifted to process next two values from each sub block
                q0 = _mm256_srli_epi64(q0, 16);
                q2 = _mm256_srli_epi64(q2, 16);
                q4 = _mm256_srli_epi64(q4, 16);
                q6 = _mm256_srli_epi64(q6, 16);

                sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
                sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
                sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
                sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);

                bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
            }

            // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
            __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
            __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
            __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
            __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);

            __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);

            for (int l = 0; l < 3; l++) {
                // Quants value shifted to process next two values from each sub block
                q1 = _mm256_srli_epi64(q1, 16);
                q3 = _mm256_srli_epi64(q3, 16);
                q5 = _mm256_srli_epi64(q5, 16);
                q7 = _mm256_srli_epi64(q7, 16);

                sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
                sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
                sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
                sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);

                bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
            }

            // Overall bsums in interleaved fashion computed by adding results of both halves
            __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
            _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
        }
    }

#else
    UNUSED(nb);
    UNUSED(y);
    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
#endif
}

//
// GEMV/GEMM templates
//

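The next hunk re-adds ggml_quantize_mat_q8_K_4x8 with a NEON path alongside the AVX2 path removed above. Both produce the block_q8_Kx4 layout the GEMM consumes: per 256-value super-block each row gets a scale of roughly 127/amax, the rounded int8 quants are interleaved across the four rows in 8-byte groups, and one bsum is kept per 16 consecutive values of each row. A scalar sketch of that packing is given below; the field sizes (float d[4]; int8_t qs[1024]; int16_t bsums[64]) follow the access pattern of the kernels above and are assumptions, and the sign trick applied to the maximum element is omitted for brevity.

#include <math.h>
#include <stdint.h>

// Scalar sketch of the 4x8 interleaved Q8_K packing produced by
// ggml_quantize_mat_q8_K_4x8 for ONE super-block of 4 rows x 256 values.
typedef struct {
    float   d[4];           // per-row dequantization scale (1/iscale)
    int8_t  qs[4 * 256];    // quants, interleaved in 8-byte groups: r0, r1, r2, r3, r0, ...
    int16_t bsums[4 * 16];  // per-row sums of 16 consecutive quants (row r uses bsums[16*r .. 16*r+15])
} block_q8_Kx4_sketch;

static void quantize_superblock_4x8_ref(const float src[4][256], block_q8_Kx4_sketch * y) {
    for (int r = 0; r < 4; r++) {
        float amax = 0.0f;
        for (int l = 0; l < 256; l++) amax = fmaxf(amax, fabsf(src[r][l]));
        const float iscale = amax != 0.0f ? 127.0f / amax : 0.0f;   // sign handling of the max omitted
        y->d[r] = amax != 0.0f ? 1.0f / iscale : 0.0f;
        for (int g = 0; g < 16; g++) {                              // one bsum per 16 values
            int16_t bsum = 0;
            for (int l = 0; l < 16; l++) {
                const int idx = 16*g + l;
                const int v   = (int) lrintf(src[r][idx] * iscale);
                // 8-byte interleave across the 4 rows: group of 8, then row offset, then position in group
                y->qs[(idx / 8) * 32 + r * 8 + (idx % 8)] = (int8_t) v;
                bsum += (int16_t) v;
            }
            y->bsums[16*r + g] = bsum;
        }
    }
}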
@ -227,6 +227,346 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
const int blck_size_interleave = 8;
|
||||
float32x4_t srcv[4][64]; // 64 = QK_K/4
|
||||
float iscale[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t asrcv[64];
|
||||
float32x4_t amaxv[64];
|
||||
|
||||
// d:
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
|
||||
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
|
||||
|
||||
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
|
||||
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
|
||||
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
|
||||
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
|
||||
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
|
||||
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
// Check if exists: orig == amax
|
||||
float32x4_t amax_vec = vdupq_n_f32(amax);
|
||||
uint32x4_t mask_all = vdupq_n_u32(0);
|
||||
for (int j = 0; j < 64; j++) {
|
||||
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
|
||||
mask_all = vorrq_u32(mask_all, mask_curr);
|
||||
}
|
||||
|
||||
// Assume that none == amax, then check mask_all to reverse
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
|
||||
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
|
||||
if (vmaxvq_u32(cmp) != 0) {
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
||||
}
|
||||
|
||||
// qs: 8 byte interleave over 4 rows, loop = QK_K/8
|
||||
// bsums: simply generated one by one, row_i is calculated before row_i+1
|
||||
// loops = 16
|
||||
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
|
||||
// Process row0 and row1
|
||||
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
|
||||
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
|
||||
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
|
||||
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
|
||||
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
|
||||
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
|
||||
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
|
||||
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
|
||||
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
|
||||
|
||||
// Process row2 and row3
|
||||
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
|
||||
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
|
||||
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
|
||||
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
|
||||
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
|
||||
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
|
||||
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
|
||||
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
|
||||
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
|
||||
}
|
||||
}
|
||||
#elif defined(__AVX2__)
|
||||
float iscale[4];
|
||||
__m256 srcv[4][32];
|
||||
__m256 iscale_vec[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
|
||||
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
|
||||
|
||||
// Compute max(abs(e)) for the block
|
||||
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
||||
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
||||
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
||||
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
||||
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
||||
|
||||
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs3 );

__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );

__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));

srcv[row_iter][0] = v0;
srcv[row_iter][1] = v1;
srcv[row_iter][2] = v2;
srcv[row_iter][3] = v3;

for (int sb = 1; sb < 8; sb++) {
// Temporarily stores absolute quant values
__m256 tempAbs = maxAbs;

// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );

// Compute max(abs(e)) for the block
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );

maxAbs = _mm256_max_ps( maxAbs, abs0 );
maxAbs = _mm256_max_ps( maxAbs, abs1 );
maxAbs = _mm256_max_ps( maxAbs, abs2 );
maxAbs = _mm256_max_ps( maxAbs, abs3 );

__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
maskAbs = _mm256_and_ps( maskAbs, mask_prev );

mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );

__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
maskAbs = _mm256_or_ps(maskAbs, mask_curr);

srcv[row_iter][sb * 4] = v0;
srcv[row_iter][sb * 4 + 1] = v1;
srcv[row_iter][sb * 4 + 2] = v2;
srcv[row_iter][sb * 4 + 3] = v3;
}

__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );

__m256 maxScalarVec = _mm256_set1_ps(maxScalar);

__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);

const int mask = _mm256_movemask_ps(finalMask);
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;

if(mask) {
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
}

y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
}

__m256i quants_interleaved[32];
for (int j = 0; j < 32; j++) {
// Apply the multiplier
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);

// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );

// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );

// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 );

// Permute and store the quantized weights in the required order after the pack instruction
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );

_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
quants_interleaved[j] = i0;
}

// Masks to shuffle the quants of corresponding sub blocks, rearranging them for vectorized bsums computation
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);

for (int k = 0; k < 4; k++) {
// Quants from four different sub blocks are taken
__m256i q0 = quants_interleaved[k * 8 + 0];
__m256i q1 = quants_interleaved[k * 8 + 1];
__m256i q2 = quants_interleaved[k * 8 + 2];
__m256i q3 = quants_interleaved[k * 8 + 3];
__m256i q4 = quants_interleaved[k * 8 + 4];
__m256i q5 = quants_interleaved[k * 8 + 5];
__m256i q6 = quants_interleaved[k * 8 + 6];
__m256i q7 = quants_interleaved[k * 8 + 7];


// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);

__m256i one = _mm256_set1_epi8(1);
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);

for (int l = 0; l < 3; l++) {
// Quants value shifted to process next two values from each sub block
q0 = _mm256_srli_epi64(q0, 16);
q2 = _mm256_srli_epi64(q2, 16);
q4 = _mm256_srli_epi64(q4, 16);
q6 = _mm256_srli_epi64(q6, 16);

sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);

bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
}

// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);

__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);

for (int l = 0; l < 3; l++) {
// Quants value shifted to process next two values from each sub block
q1 = _mm256_srli_epi64(q1, 16);
q3 = _mm256_srli_epi64(q3, 16);
q5 = _mm256_srli_epi64(q5, 16);
q7 = _mm256_srli_epi64(q7, 16);

sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);

bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
}

// Overall bsums in interleaved fashion computed by adding results of both halves
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
}
}

#else
// NOTE: this default C implementation is aligned with the AVX2 implementation, but differs from the Arm implementation,
// especially in the bsums arrangement.
UNUSED(nb);
UNUSED(y);
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
#endif
}

} // extern "C"
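For readers tracing the AVX2 path above: all of the `maskAbs`/`finalMask` bookkeeping exists only to recover the sign of the element with the largest magnitude in each row, so the scale keeps the usual Q8_K sign convention. A minimal scalar sketch of what one row's scale selection computes (illustrative only; the helper name and the QK_K == 256 assumption are mine, not part of the patch):

```cpp
#include <math.h>

// Scalar equivalent of the mask logic above: iscale ends up as -127/max_signed,
// i.e. negative when the max-magnitude element is positive (the `if (mask)` branch)
// and positive when it is negative; y[i].d[row] then stores 1/iscale.
static void q8_K_row_scale_sketch(const float * v, float * iscale_out, float * d_out) {
    float amax = 0.0f; // largest |v[j]|
    float vmax = 0.0f; // the signed value that attains it
    for (int j = 0; j < 256; j++) {
        const float ax = fabsf(v[j]);
        if (ax > amax) { amax = ax; vmax = v[j]; }
    }
    *iscale_out = amax != 0.0f ? -127.f / vmax : 0.0f;
    *d_out      = amax != 0.0f ? 1.0f / *iscale_out : 0.0f;
}
```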

template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
@ -1505,6 +1845,90 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
return out;
}

static void make_block_q4_Kx4(block_q4_K * in, unsigned int blck_size_interleave, block_q4_Kx4 * out) {
int nrow = 4;
int nloop = 4;

// d and dmin values of the 4 Q4_K are copied directly.
for (int i = 0; i < nrow; i++) {
out->d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
}

for (int i = 0; i < nrow; i++) {
out->dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
}

// For qs, 2 things need to be done:
// 1. Recover from Q4_K storage style to Q4_0 style
// 2. Interleave quants by taking 8 bytes at a time

// 1.
const uint64_t lo_mask = 0x0f0f0f0f0f0f0f0fULL;
const uint64_t hi_mask = 0xf0f0f0f0f0f0f0f0ULL;
for (int i = 0; i < nrow; i++) {
uint64_t *q = (uint64_t *)(in[i].qs);
for (int j = 0; j < nloop; j++) {
uint64_t q0, q1, q2, q3;
q0 = q[0];
q1 = q[1];
q2 = q[2];
q3 = q[3];

uint64_t hi1, hi2, lo3, lo4;
hi1 = q0 & hi_mask;
hi2 = q1 & hi_mask;
lo3 = q2 & lo_mask;
lo4 = q3 & lo_mask;
q[0] = (q0 & lo_mask) | (lo3 << 4);
q[1] = (q1 & lo_mask) | (lo4 << 4);
q[2] = (q2 & hi_mask) | (hi1 >> 4);
q[3] = (q3 & hi_mask) | (hi2 >> 4);

q += 4;
}
}

// 2.
// Calculate total number of interleaved subblocks
const int end = QK_K * 2 / blck_size_interleave;
uint64_t *src, *dst;
for (int i = 0; i < end; ++i) {
int src_id = i % 4;
int src_offset = (i / 4) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;

src = (uint64_t *)(&in[src_id].qs[src_offset]);
dst = (uint64_t *)(&out->qs[dst_offset]);
*dst = *src;
}

// For scales & mins of each subblock. (8 subblocks in one Q4_K, 32 in total)
// A special requirement to meet: expand to 8-bit from 6-bit.
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
for (int i = 0; i < nrow; i++) {
// rearrange as d|d|...|d|min|min|...|min
// expand to 8-bit from 6-bit
memset(utmp, 0, 16);
memcpy(utmp, in[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;

// move the expanded scales/mins into the interleaved block_q4_Kx4 output
const uint8_t * d_ptr = (const uint8_t*)&utmp[0];
const uint8_t * m_ptr = (const uint8_t*)&utmp[2];
for (int j = 0; j < 8; j++) {
out->scales[j * 8 + i] = *d_ptr++;
out->scales[j * 8 + i + nrow] = *m_ptr++;
}
}
}
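The `kmask1/2/3` shuffle above reproduces, four bytes at a time, the standard Q4_K unpacking of the 12-byte `scales` field into 8 six-bit scales followed by 8 six-bit mins. As a scalar reference for what each expanded byte holds, this is roughly the per-sub-block helper ggml uses elsewhere (shown as a standalone sketch; the exact name here is illustrative):

```cpp
#include <stdint.h>

// Unpack the j-th 6-bit scale (d) and min (m) from a 12-byte Q4_K scales field.
// Sub-blocks 0..3 store their 6 bits directly in bytes 0..7; sub-blocks 4..7 split
// them across the nibbles of bytes 8..11 and the top two bits of bytes 0..7.
static inline void get_scale_min_sketch(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >>  4) | ((q[j - 0] >> 6) << 4);
    }
}
```

After the `utmp` rewrite, bytes 0..7 of `utmp` hold the 8 scales and bytes 8..15 the 8 mins, which the final loop interleaves row-wise as `scales[b * 8 + row]` and `scales[b * 8 + row + 4]`.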
static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
block_q4_Kx8 out;
//Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
@ -1656,6 +2080,46 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}

static int repack_q4_K_to_q4_K_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
GGML_ASSERT(interleave_block == 8);
constexpr int nrows_interleaved = 4;

block_q4_Kx4 * dst = (block_q4_Kx4 *)t->data;
const block_q4_K * src = (const block_q4_K *) data;
block_q4_K dst_tmp[4];
int nrow = ggml_nrows(t);
int nblocks = t->ne[0] / QK_K;

GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));

if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}

for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
dst_tmp[i] = src[x + i * nblocks];
}
make_block_q4_Kx4(dst_tmp, interleave_block, dst++);
}
src += nrows_interleaved * nblocks;
}

// adjust the tensor strides, since block_q4_Kx4 changes the per-row storage size
//t->nb[0] = ggml_type_size(type);
t->nb[0] = sizeof(block_q4_Kx4) / 4;
t->nb[1] = t->nb[0] * (t->ne[0] / ggml_blck_size(t->type));
for (int i = 2; i < GGML_MAX_DIMS; i++) {
t->nb[i] = t->nb[i - 1] * t->ne[i - 1];
}

return 0;

GGML_UNUSED(data_size);
}
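To make the traversal explicit: the loop above processes rows in groups of four and, within each group, walks the QK_K super-blocks left to right, emitting one `block_q4_Kx4` per (row group, super-block) pair. A tiny index helper (hypothetical, for illustration only) that reproduces the `dst++` ordering:

```cpp
#include <cstdint>

// Index of the block_q4_Kx4 holding super-block x of the 4-row group that starts at
// `row` (assumed to be a multiple of 4), for a tensor with nblocks super-blocks per row.
static inline int64_t q4_kx4_block_index(int64_t row, int64_t x, int64_t nblocks) {
    return (row / 4) * nblocks + x;
}
```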

static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
@ -1924,6 +2388,10 @@ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_K, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_K_to_q4_K_4_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}
@ -1973,6 +2441,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_K, 8, 4, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -2021,6 +2493,10 @@ template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 8, 4, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -2292,7 +2768,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
//GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));

// dst cannot be transposed or permuted
@ -2429,6 +2905,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;

// instance for Q4_K
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 4, GGML_TYPE_Q8_K> q4_K_4x8_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;

@ -2466,6 +2943,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q4_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { // new for ARM N2
if (cur->ne[1] % 4 == 0) {
return &q4_K_4x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 8 == 0) {
return &q4_K_8x8_q8_K;
@ -2555,6 +3037,26 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
GGML_UNUSED(buft);
}

// size calculation after q4_Kx4 repacking; it differs from the traditional per-type size
size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
size_t nbytes;
const size_t blck_size = 256;
const size_t type_size = sizeof(block_q4_Kx4) / 4;
nbytes = ((tensor->ne[0] * type_size) / blck_size) * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
return nbytes;
}
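The custom size is needed because the repacked layout is slightly larger than plain Q4_K: a `block_q4_Kx4` spends 8 + 8 + 64 + 512 = 592 bytes on 4 rows of 256 weights, i.e. 148 bytes per row per super-block, versus 144 bytes (2 + 2 + 12 + 128) for a `block_q4_K`; the 4 extra bytes per row come from expanding the 6-bit scales/mins to 8 bits. A hedged sanity check of that arithmetic (it assumes both structs pack without padding, which their 1- and 2-byte members suggest but which the patch itself does not guarantee):

```cpp
static_assert(sizeof(block_q4_Kx4) == 4 * 148, "4 interleaved rows, 148 bytes each per 256 weights");
static_assert(sizeof(block_q4_K)   == 144,     "2 + 2 + 12 + 128 bytes per 256 weights");

// Example: a 4096 x 4096 Q4_K weight has 4096/256 = 16 super-blocks per row, so
//   ggml_nbytes()        -> 16 * 144 * 4096 bytes
//   ggml_nbytes_q4_kx4() -> 16 * 148 * 4096 bytes
```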

static size_t ggml_backend_cpu_aarch64_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
if (tensor->type == GGML_TYPE_Q4_K) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
return ggml_nbytes_q4_kx4(tensor);
}
}
return ggml_nbytes(tensor);

GGML_UNUSED(buft);
}

namespace ggml::cpu::repack {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
@ -2611,7 +3113,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
/* .alloc_buffer = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
/* .get_alloc_size = */ ggml_backend_cpu_aarch64_buffer_type_get_alloc_size, // defaults to ggml_nbytes except for ARM N2
/* .is_host = */ nullptr,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
@ -36,6 +36,40 @@ using block_q4_0x8 = block<4, 8>;
using block_q8_0x4 = block<8, 4>;
using block_q8_0x8 = block<8, 8>;

struct block_q4_Kx4 {
ggml_half d[4]; // super-block scale for quantized scales
ggml_half dmin[4]; // super-block scale for quantized mins
int8_t scales[64]; // scales and mins, quantized with 8 bits (recovered from 6-bit during repack); TODO: consider uint8_t
uint8_t qs[512]; // 4-bit quants

/********************************************************layout***************************************************************/
// low <-------------------------------------------------------------------------------------> high
//
// d: |s0|s1|s2|s3|
//
// dmin: |s0|s1|s2|s3|
//
// scales: |-------- d --------|-------- m --------|
// |s0b0|s1b0|s2b0|s3b0|s0b0|s1b0|s2b0|s3b0|
// |s0b1|s1b1|s2b1|s3b1|s0b1|s1b1|s2b1|s3b1|
// ......
// |s0b7|s1b7|s2b7|s3b7|s0b7|s1b7|s2b7|s3b7|
//
// qs: <from block0 of all>
// |s0w0 |s0w16|s0w1 |s0w17|s0w2 |s0w18|s0w3 |s0w19|s0w4 |s0w20|s0w5 |s0w21|s0w6 |s0w22|s0w7 |s0w23| --- 8B from s0
// |s1w0 |s1w16|s1w1 |s1w17|s1w2 |s1w18|s1w3 |s1w19|s1w4 |s1w20|s1w5 |s1w21|s1w6 |s1w22|s1w7 |s1w23| --- 8B from s1
// |s2w0 |s2w16|s2w1 |s2w17|s2w2 |s2w18|s2w3 |s2w19|s2w4 |s2w20|s2w5 |s2w21|s2w6 |s2w22|s2w7 |s2w23| --- 8B from s2
// |s3w0 |s3w16|s3w1 |s3w17|s3w2 |s3w18|s3w3 |s3w19|s3w4 |s3w20|s3w5 |s3w21|s3w6 |s3w22|s3w7 |s3w23| --- 8B from s3
// |s0w8 |s0w24|s0w9 |s0w25|s0w10|s0w26|s0w11|s0w27|s0w12|s0w28|s0w13|s0w29|s0w14|s0w30|s0w15|s0w31| --- 8B from s0
// |s1w8 |s1w24|s1w9 |s1w25|s1w10|s1w26|s1w11|s1w27|s1w12|s1w28|s1w13|s1w29|s1w14|s1w30|s1w15|s1w31| --- 8B from s1
// |s2w8 |s2w24|s2w9 |s2w25|s2w10|s2w26|s2w11|s2w27|s2w12|s2w28|s2w13|s2w29|s2w14|s2w30|s2w15|s2w31| --- 8B from s2
// |s3w8 |s3w24|s3w9 |s3w25|s3w10|s3w26|s3w11|s3w27|s3w12|s3w28|s3w13|s3w29|s3w14|s3w30|s3w15|s3w31| --- 8B from s3
// <from block1 of all>
// ......
// <from block7 of all>
/*****************************************************************************************************************************/
};
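A small illustration of how a kernel addresses this layout (hypothetical helpers, not part of the patch): the per-row super-block scales live in `d[r]`/`dmin[r]`, while the 8-bit sub-block scale and min for row `r` (0..3) and sub-block `b` (0..7) sit at `scales[b * 8 + r]` and `scales[b * 8 + r + 4]`, mirroring the writer loop in `make_block_q4_Kx4`:

```cpp
// Sketch only: index helpers for the interleaved scales/mins described above.
static inline int8_t q4_kx4_sub_scale(const block_q4_Kx4 & blk, int r, int b) {
    return blk.scales[b * 8 + r];     // "d" half of sub-block b's 8-byte group
}
static inline int8_t q4_kx4_sub_min(const block_q4_Kx4 & blk, int r, int b) {
    return blk.scales[b * 8 + r + 4]; // "m" half of the same group
}
```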

struct block_q4_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
ggml_half dmin[8]; // super-block scale for quantized mins
@ -85,6 +119,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -93,6 +128,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -108,17 +144,21 @@ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GG
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
// generic gemv implementations
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// generic gemm implementations
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);