#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "ggml-backend-impl.h"

#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "ggml-cpu-impl.h"
#include "simd-mappings.h"
#include "traits.h"

#include <cmath>
#include <cstring>
#include <cassert>
#include <cstdlib> // for qsort
#include <cstdio>  // for GGML_ASSERT

#define GGML_CPU_CLANG_WORKAROUND
#include "../../repack.h"

#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#endif

#define UNUSED GGML_UNUSED

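// Quantizes 4 rows of k floats into block_q8_0x4, interleaving the quantized
// rows 8 bytes at a time (the 4x8 layout consumed by the 8x8 GEMM/GEMV
// kernels below). The RVV path does the per-row absmax reduction and the
// scale/convert with LMUL=8 and writes the interleaved layout with a 4-field
// segmented store of 64-bit chunks; other builds fall back to the generic
// implementation. Scalar sketch of what each row of each 32-element block
// computes (illustrative only, not part of the API):
//
//   float amax = 0.0f;
//   for (int j = 0; j < QK8_0; j++) amax = fmaxf(amax, fabsf(x[j]));
//   const float d  = amax / 127.0f;
//   const float id = d ? 1.0f / d : 0.0f;
//   y->d = GGML_CPU_FP32_TO_FP16(d);
//   for (int j = 0; j < QK8_0; j++) y->qs[j] = (int8_t) roundf(x[j] * id);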
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
|
assert(QK8_0 == 32);
|
|
assert(k % QK8_0 == 0);
|
|
const int nb = k / QK8_0;
|
|
|
|
#if defined(__riscv_v_intrinsic)
|
|
block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
|
|
const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0);
|
|
const size_t vl_save = __riscv_vsetvl_e64m2(4);
|
|
vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1));
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
const float *x_block_base = x + i * QK8_0;
|
|
vint8m2_t q_r0, q_r1, q_r2, q_r3;
|
|
{
|
|
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc);
|
|
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
|
|
float d = amax / 127.0f;
|
|
y[i].d[0] = GGML_CPU_FP32_TO_FP16(d);
|
|
|
|
float id = d ? 1.0f / d : 0.0f;
|
|
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
}
|
|
asm volatile ("" ::: "memory");
|
|
|
|
{
|
|
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc);
|
|
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
|
|
float d = amax / 127.0f;
|
|
y[i].d[1] = GGML_CPU_FP32_TO_FP16(d);
|
|
float id = d ? 1.0f / d : 0.0f;
|
|
|
|
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
}
|
|
asm volatile ("" ::: "memory");
|
|
{
|
|
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc);
|
|
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
|
|
float d = amax / 127.0f;
|
|
y[i].d[2] = GGML_CPU_FP32_TO_FP16(d);
|
|
float id = d ? 1.0f / d : 0.0f;
|
|
|
|
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
}
|
|
asm volatile ("" ::: "memory");
|
|
{
|
|
vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc);
|
|
vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
|
|
vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
|
|
float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
|
|
|
|
float d = amax / 127.0f;
|
|
y[i].d[3] = GGML_CPU_FP32_TO_FP16(d);
|
|
float id = d ? 1.0f / d : 0.0f;
|
|
|
|
vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
|
|
vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
|
|
q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
|
|
}
|
|
vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0);
|
|
vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1);
|
|
vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2);
|
|
vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3);
|
|
vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3);
|
|
__riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save);
|
|
}
|
|
#else
|
|
UNUSED(nb);
|
|
UNUSED(y);
|
|
    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
|
|
#endif
|
|
}
|
|
|
|
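// GEMV: one q8_0 activation row times q4_0 weights repacked 8 columns wide
// with blocklen 8. Each block iteration widens the 4-bit nibbles to signed
// i8, does widening multiply-accumulates against broadcast 8-byte chunks of
// the activation row, folds each column's 8 adjacent i16 partial sums down to
// a single i32 with a tree of narrowing shifts and adds, and finally applies
// the fp16 row and column scales. The vector path assumes VLEN >= 256
// (vlenb >= QK4_0); otherwise it falls back to the generic kernel.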
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK8_0;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 8;
|
|
const int blocklen = 8;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v
|
|
if (__riscv_vlenb() >= QK4_0) {
|
|
const size_t vl = QK4_0;
|
|
|
|
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
|
|
|
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
for (int l = 0; l < nb; l++) {
|
|
const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
|
|
const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
|
|
const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
|
|
const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
|
|
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment constraints
|
|
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
|
|
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
|
|
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
|
|
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
|
|
|
|
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
|
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
|
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
|
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
|
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
|
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
|
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
|
|
|
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
|
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
|
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
|
|
// vector version needs Zvfhmin extension
|
|
const float a_scale = GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
|
|
const float b_scales[8] = {
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7])
|
|
};
|
|
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
|
|
sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
|
|
}
|
|
__riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
|
|
}
|
|
return;
|
|
}
|
|
|
|
#endif
|
|
ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
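// GEMV: one q8_0 activation row times q4_0 weights repacked 16 columns wide
// with blocklen 1. As used here, the 16 bytes at qs[i * 16] hold, for each
// column, element i in the low nibble and element i + 16 in the high nibble.
// The inner loop sign-extends the nibbles and widening-maccs them against the
// matching scalar activation byte, keeping separate i16 accumulators for the
// two halves before combining in i32. Note that the fp16 vector load and
// vfwmul assume vector FP16 support on the target.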
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK8_0;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 16;
|
|
const int blocklen = 1;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v_intrinsic
|
|
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
|
|
|
|
// 1x16 Accumulator
|
|
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
|
|
for (int l = 0; l < nb; l++) {
|
|
// 1x16 Integer Accumulator
|
|
            vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
            vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
|
|
|
|
// Accumulation loop.
|
|
for (int i = 0; i < QK4_0 / 2; i++) {
|
|
// Load `b_ptr`.
|
|
const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);
|
|
const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);
|
|
|
|
sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16);
|
|
sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16);
|
|
}
|
|
|
|
const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);
|
|
|
|
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
|
|
|
|
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
}
|
|
|
|
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
}
|
|
return;
|
|
#endif
|
|
ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
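// GEMV: one q8_K activation row times q4_K weights repacked 16 columns wide
// (blocklen 1). Per 128-value group it unpacks the 6-bit sub-block scales and
// mins (low 4 bits from the first 128 scale bytes, high 2 bits from the
// trailing 64), subtracts the min contribution using the activation bsums and
// the per-column dmin, and accumulates nibble x activation products in i16
// before promoting to i32 under the per-sub-block scales.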
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK_K;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 16;
|
|
const int blocklen = 1;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v_intrinsic
|
|
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
|
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
|
|
|
|
// 1x16 Accumulator
|
|
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
|
|
for (int l = 0; l < nb; l++) {
|
|
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
|
|
|
|
// We process 4 sub-blocks at once.
|
|
for (int j = 0; j < QK_K / 128; j++) {
|
|
// Extract the scales and the mins.
|
|
//
|
|
// Low bits.
|
|
vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);
|
|
vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);
|
|
vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);
|
|
|
|
// High bits.
|
|
vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);
|
|
vuint8m2_t scales_hi;
|
|
vuint8m2_t mins_hi;
|
|
if (!j) {
|
|
scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);
|
|
mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);
|
|
} else {
|
|
scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);
|
|
mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);
|
|
}
|
|
vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);
|
|
vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));
|
|
|
|
// Reduce the mins and multiply with `dmin`.
|
|
//
|
|
// Correct in `sumf`.
|
|
vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16);
|
|
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
|
|
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
|
|
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
|
|
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
|
|
|
|
const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
|
|
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16);
|
|
sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16);
|
|
|
|
// Accumulation for 2 sub-blocks.
|
|
//
|
|
// This might overflow, so we accumulate in two steps.
|
|
//
|
|
// Recheck.
|
|
for (int k = 0; k < 2; k++) {
|
|
                vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
|
|
|
|
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
|
|
// Load `b_ptr`.
|
|
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);
|
|
const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
|
|
const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
|
|
|
|
sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16);
|
|
sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16);
|
|
}
|
|
|
|
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
|
|
sumi_s_0_16, 16);
|
|
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
|
|
sumi_s_1_16, 16);
|
|
}
|
|
// Accumulation for 2 sub-blocks.
|
|
//
|
|
// This might overflow, so we accumulate in two steps.
|
|
//
|
|
// Recheck.
|
|
for (int k = 0; k < 2; k++) {
|
|
                vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
|
|
|
|
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
|
|
// Load `b_ptr`.
|
|
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);
|
|
const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
|
|
const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
|
|
|
|
sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16);
|
|
sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16);
|
|
}
|
|
|
|
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
|
|
sumi_s_0_16, 16);
|
|
sumi = __riscv_vwmacc_vv_i32m2(sumi,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
|
|
sumi_s_1_16, 16);
|
|
}
|
|
}
|
|
|
|
const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16);
|
|
const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16);
|
|
|
|
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
}
|
|
|
|
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
}
|
|
return;
|
|
#endif
|
|
ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
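// GEMV: one q8_0 activation row times iq4_nl weights repacked 16 columns wide
// (blocklen 1). Same access pattern as the q4_0 16x1 kernel, except the 4-bit
// indices are mapped to their non-linear int8 values through the
// kvalues_iq4nl table with a vrgather lookup before the widening
// multiply-accumulates.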
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK8_0;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 16;
|
|
const int blocklen = 1;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v_intrinsic
|
|
const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
|
|
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
|
|
|
|
        // 1x16 Accumulator
|
|
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
|
|
for (int l = 0; l < nb; l++) {
|
|
// 1x16 integer accumulator
|
|
            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
|
|
|
|
// Accumulation loop.
|
|
for (int i = 0; i < QK4_NL / 2; i++) {
|
|
// Load `b_ptr`.
|
|
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
|
|
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
|
|
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
|
|
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
|
|
|
|
const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);
|
|
const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);
|
|
sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);
|
|
}
|
|
|
|
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
|
|
|
|
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
}
|
|
|
|
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
}
|
|
return;
|
|
#endif
|
|
ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
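// GEMV: one q8_0 activation row times q8_0 weights repacked 16 columns wide
// (blocklen 1). Each of the 32 positions contributes one int8 per column,
// which is widening-multiplied by the scalar activation byte and folded
// straight into the i32 accumulator, so no separate i16 staging is needed.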
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK8_0;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 16;
|
|
const int blocklen = 1;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v_intrinsic
|
|
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
|
|
|
|
// 1x16 Accumulator
|
|
vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
|
|
for (int l = 0; l < nb; l++) {
|
|
// 1x16 Integer Accumulator
|
|
            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
|
|
|
|
// Accumulation loop.
|
|
for (int i = 0; i < QK8_0; i++) {
|
|
// Load `b_ptr`.
|
|
const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
// const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);
|
|
|
|
sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16);
|
|
}
|
|
|
|
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
|
|
|
|
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
|
|
}
|
|
|
|
__riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
|
|
}
|
|
return;
|
|
#endif
|
|
ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
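// GEMV: one q8_K activation row times q2_K weights repacked 16 columns wide.
// Each repacked weight byte packs four 2-bit values feeding four different
// 16-element sub-blocks; the 4-bit sub-block scale is folded into the weight
// in i16 before the widening macc with the activation byte, and the 4-bit
// mins are corrected afterwards via the activation bsums and the per-column
// dmin. Unlike the other kernels in this file, this one has no generic
// fallback path.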
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
assert(n % QK_K == 0);
|
|
assert(nr == 1);
|
|
assert(nc % 16 == 0);
|
|
|
|
UNUSED(bs);
|
|
|
|
const int N_COLS_TILE = 16;
|
|
const int num_k_blocks = n / QK_K;
|
|
|
|
const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);
|
|
for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {
|
|
|
|
const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy;
|
|
const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;
|
|
|
|
vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl);
|
|
|
|
for (int k_block = 0; k_block < num_k_blocks; ++k_block) {
|
|
const block_q8_K* lhs_current = &lhs_base_ptr[k_block];
|
|
const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];
|
|
|
|
// 1. Prepare Global Min Scales
|
|
vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
|
|
vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);
|
|
|
|
vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl);
|
|
|
|
vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl);
|
|
|
|
const uint8_t* rhs_qs_ptr = rhs_current->qs;
|
|
const uint8_t* rhs_sc_ptr = rhs_current->scales;
|
|
const int8_t* lhs_qs_ptr = lhs_current->qs;
|
|
|
|
// --- Phase Loop (4 phases x 64 elements) ---
|
|
for (int phase = 0; phase < 4; ++phase) {
|
|
|
|
// A. Load Scales/Mins
|
|
vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;
|
|
vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;
|
|
|
|
{
|
|
vuint8mf2_t v_raw;
|
|
// Sub-block 0
|
|
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
|
|
v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
|
|
// Sub-block 1
|
|
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
|
|
v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
|
|
// Sub-block 2
|
|
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
|
|
v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
|
|
// Sub-block 3
|
|
v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
|
|
v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
|
|
v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
|
|
|
|
rhs_sc_ptr += 64;
|
|
}
|
|
|
|
int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
|
|
int k_offsets[4] = {0, 32, 64, 96};
|
|
|
|
// B. Inner Dot Product Loop
|
|
for (int l = 0; l < 16; ++l) {
|
|
vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);
|
|
rhs_qs_ptr += 16;
|
|
|
|
// Sub-block 0
|
|
{
|
|
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);
|
|
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);
|
|
|
|
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l];
|
|
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
}
|
|
// Sub-block 1
|
|
{
|
|
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);
|
|
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);
|
|
|
|
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l];
|
|
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
}
|
|
// Sub-block 2
|
|
{
|
|
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);
|
|
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);
|
|
|
|
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l];
|
|
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
}
|
|
// Sub-block 3
|
|
{
|
|
vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);
|
|
vint16m1_t v_w = __riscv_vmul_vv_i16m1(
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
|
|
__riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);
|
|
|
|
int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l];
|
|
v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
|
|
}
|
|
}
|
|
|
|
            // Min correction: subtract bsum * min * dmin * activation scale for each sub-block.
|
|
int sb_base_abs = base_k_phase / 16;
|
|
|
|
// Sub-block 0
|
|
{
|
|
int sb_idx = sb_base_abs + (k_offsets[0] / 16);
|
|
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);
|
|
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
}
|
|
// Sub-block 1
|
|
{
|
|
int sb_idx = sb_base_abs + (k_offsets[1] / 16);
|
|
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);
|
|
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
}
|
|
// Sub-block 2
|
|
{
|
|
int sb_idx = sb_base_abs + (k_offsets[2] / 16);
|
|
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);
|
|
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
}
|
|
// Sub-block 3
|
|
{
|
|
int sb_idx = sb_base_abs + (k_offsets[3] / 16);
|
|
int16_t bsum = lhs_current->bsums[sb_idx];
|
|
vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);
|
|
vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
|
|
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
|
|
v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
|
|
}
|
|
|
|
} // End Phase Loop
|
|
|
|
// Apply global Scales
|
|
vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);
|
|
vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);
|
|
|
|
vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl);
|
|
vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl);
|
|
v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl);
|
|
v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl);
|
|
|
|
} // End K-Block
|
|
__riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl);
|
|
|
|
}
|
|
}
|
|
|
|
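// GEMM counterpart of ggml_gemv_q4_0_8x8_q8_0: processes 4 interleaved q8_0
// activation rows (block_q8_0x4) against 8-column q4_0 tiles. The nibble
// extraction and the narrowing-shift reduction are the same as in the GEMV
// path, repeated once per activation row using that row's 8-byte chunks from
// the interleaved qs buffer. Requires VLEN >= 256, with a generic fallback
// otherwise.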
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK8_0;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 8;
|
|
const int blocklen = 8;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nr % 4 == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v
|
|
if (__riscv_vlenb() >= QK4_0) {
|
|
const size_t vl = QK4_0;
|
|
|
|
for (int y = 0; y < nr / 4; y++) {
|
|
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
|
vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
for (int l = 0; l < nb; l++) {
|
|
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
|
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
|
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
|
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
|
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
|
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
|
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
|
|
|
// vector version needs Zvfhmin extension
|
|
const float a_scales[4] = {
|
|
GGML_CPU_FP16_TO_FP32(a_ptr[l].d[0]),
|
|
GGML_CPU_FP16_TO_FP32(a_ptr[l].d[1]),
|
|
GGML_CPU_FP16_TO_FP32(a_ptr[l].d[2]),
|
|
GGML_CPU_FP16_TO_FP32(a_ptr[l].d[3])
|
|
};
|
|
const float b_scales[8] = {
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]),
|
|
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7])
|
|
};
|
|
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
|
|
|
const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
|
|
const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
|
|
const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
|
|
const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
|
|
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
vint16m4_t sumi_l0;
|
|
{
|
|
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
|
|
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
|
|
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
|
|
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
|
|
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
|
|
sumi_l0 = sumi_hi_m;
|
|
}
|
|
|
|
{
|
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
|
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
|
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
|
|
sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
|
|
}
|
|
|
|
const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
|
|
const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
|
|
const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
|
|
const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
|
|
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
vint16m4_t sumi_l1;
|
|
{
|
|
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
|
|
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
|
|
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
|
|
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
|
|
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
|
|
sumi_l1 = sumi_hi_m;
|
|
}
|
|
|
|
{
|
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
|
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
|
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
|
|
sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
|
|
}
|
|
|
|
const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
|
|
const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
|
|
const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
|
|
const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
|
|
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
vint16m4_t sumi_l2;
|
|
{
|
|
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
|
|
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
|
|
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
|
|
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
|
|
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
|
|
sumi_l2 = sumi_hi_m;
|
|
}
|
|
|
|
{
|
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
|
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
|
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
|
|
sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
|
|
}
|
|
|
|
const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
|
|
const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
|
|
const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
|
|
const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
|
|
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
vint16m4_t sumi_l3;
|
|
{
|
|
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
|
|
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
|
|
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
|
|
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
|
|
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
|
|
sumi_l3 = sumi_hi_m;
|
|
}
|
|
|
|
{
|
|
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
|
|
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
|
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
|
|
sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
|
|
}
|
|
}
|
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
|
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
|
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
|
|
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
|
|
}
|
|
}
|
|
|
|
return;
|
|
}
|
|
|
|
#endif
|
|
ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
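// GEMM counterpart of ggml_gemv_q4_0_16x1_q8_0: 4 activation rows against
// 16-column q4_0 tiles. The packed activation block is indexed with a 4-way
// byte interleave (qs[i * 4 + row]), and each row keeps its own pair of i16
// low/high accumulators before the i32 combine and fp16 scaling.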
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK8_0;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 16;
|
|
const int blocklen = 1;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nr % 4 == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v_intrinsic
|
|
for (int y = 0; y < nr / 4; y++) {
|
|
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
|
|
|
|
// 4x16 Accumulators
|
|
vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
|
|
for (int l = 0; l < nb; l++) {
|
|
// 4x16 integer accumulators
|
|
                vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
                vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
|
|
|
|
// Accumulation loop.
|
|
for (int i = 0; i < QK4_0 / 2; i++) {
|
|
// Load `b_ptr`.
|
|
const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
|
|
const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);
|
|
const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);
|
|
|
|
sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16);
|
|
sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16);
|
|
sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16);
|
|
sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16);
|
|
|
|
sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16);
|
|
sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16);
|
|
sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16);
|
|
sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16);
|
|
}
|
|
|
|
// Do the final accumulation in i32 to prevent overflow.
|
|
const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);
|
|
const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16);
|
|
const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16);
|
|
const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16);
|
|
|
|
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
|
|
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
|
|
const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
|
|
const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
|
|
const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
|
|
|
|
sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
|
|
sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
|
|
sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
|
|
sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
|
|
}
|
|
|
|
__riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
|
|
__riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
|
|
__riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
|
|
__riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
|
|
}
|
|
}
|
|
return;
|
|
#endif
|
|
ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
}
|
|
|
|
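// GEMM counterpart of ggml_gemv_q4_K_16x1_q8_K: 4 q8_K activation rows
// (block_q8_Kx4) against 16-column q4_K tiles. The scale/min unpacking
// matches the GEMV kernel; the bsums correction and the scaled nibble
// accumulation are replicated per activation row, with the row index folded
// into the interleaved qs and bsums addressing.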
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
const int qk = QK_K;
|
|
const int nb = n / qk;
|
|
const int ncols_interleaved = 16;
|
|
const int blocklen = 1;
|
|
|
|
assert (n % qk == 0);
|
|
assert (nr % 4 == 0);
|
|
assert (nc % ncols_interleaved == 0);
|
|
|
|
UNUSED(s);
|
|
UNUSED(bs);
|
|
UNUSED(vx);
|
|
UNUSED(vy);
|
|
UNUSED(nr);
|
|
UNUSED(nc);
|
|
UNUSED(nb);
|
|
UNUSED(ncols_interleaved);
|
|
UNUSED(blocklen);
|
|
|
|
#if defined __riscv_v_intrinsic
|
|
for (int y = 0; y < nr / 4; y++) {
|
|
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
|
|
|
|
// 4x16 Accumulators
|
|
vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
|
|
|
|
for (int l = 0; l < nb; l++) {
|
|
vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
|
|
// We process 4 sub-blocks at once.
|
|
for (int j = 0; j < QK_K / 128; j++) {
|
|
// Extract the scales and the mins.
|
|
//
|
|
// Low bits.
|
|
vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);
|
|
vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);
|
|
vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);
|
|
|
|
// High bits.
|
|
vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);
|
|
vuint8m2_t scales_hi;
|
|
vuint8m2_t mins_hi;
|
|
if (!j) {
|
|
scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);
|
|
mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);
|
|
} else {
|
|
scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);
|
|
mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);
|
|
}
|
|
vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);
|
|
vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));
|
|
|
|
// Reduce the mins and multiply with `dmin`.
|
|
//
|
|
// Correct in `sumf`.
|
|
vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16);
|
|
|
|
bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
|
|
a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 0), 16);
|
|
bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
|
|
a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 0), 16);
|
|
bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
|
|
a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 0), 16);
|
|
bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
|
|
a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 0), 16);
|
|
bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
|
|
a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 1), 16);
|
|
bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
|
|
a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 1), 16);
|
|
bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
|
|
a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 1), 16);
|
|
bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
|
|
a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 1), 16);
|
|
bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
|
|
a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 2), 16);
|
|
bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
|
|
a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 2), 16);
|
|
bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
|
|
a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 2), 16);
|
|
bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
|
|
a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 2), 16);
|
|
bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
|
|
a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 3), 16);
|
|
bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
|
|
a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 3), 16);
|
|
bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
|
|
a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 3), 16);
|
|
bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
|
|
a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7],
|
|
__riscv_vget_v_i16m4_i16m1(mins, 3), 16);
|
|
|
|
const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(
|
|
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[0], 16);
|
|
const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(
|
|
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[1], 16);
|
|
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(
|
|
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[2], 16);
|
|
const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(
|
|
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[3], 16);
|
|
|
|
sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16);
|
|
sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16);
|
|
sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16);
|
|
sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16);
|
|
|
|
|
|
// Accumulation for 2 sub-blocks.
|
|
//
|
|
// This might overflow, so we accumulate in two steps.
|
|
//
|
|
// Recheck.
|
|
for (int k = 0; k < 2; k++) {
|
|
// 4x16 integer accumulators
|
|
                    vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
|
|
|
|
for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
|
|
// Load `b_ptr`.
|
|
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);
|
|
const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
|
|
const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
|
|
|
|
sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16);
|
|
sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16);
|
|
sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16);
|
|
sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16);
|
|
|
|
sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16);
|
|
sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16);
|
|
sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16);
|
|
sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16);
|
|
}
|
|
|
|
sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
|
|
sumi_0_s_0_16, 16);
|
|
sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
|
|
sumi_0_s_1_16, 16);
|
|
sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
|
|
sumi_1_s_0_16, 16);
|
|
sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
|
|
sumi_1_s_1_16, 16);
|
|
sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
|
|
sumi_2_s_0_16, 16);
|
|
sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
|
|
sumi_2_s_1_16, 16);
|
|
sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
|
|
sumi_3_s_0_16, 16);
|
|
sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
|
|
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
|
|
sumi_3_s_1_16, 16);
|
|
}
|
|
// Accumulation for 2 sub-blocks.
|
|
//
|
|
// This might overflow, so we accumulate in two steps.
|
|
//
|
|
// Recheck.
|
|
for (int k = 0; k < 2; k++) {
|
|
// 4x16 integer accumulators
|
|
                    vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
                    vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);

                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
                            // Load `b_ptr`.
                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);
                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));

                            sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16);
                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16);
                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16);
                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16);

                            sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16);
                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16);
                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16);
                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16);
                        }

                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), sumi_0_s_0_16, 16);
                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), sumi_0_s_1_16, 16);
                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), sumi_1_s_0_16, 16);
                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), sumi_1_s_1_16, 16);
                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), sumi_2_s_0_16, 16);
                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), sumi_2_s_1_16, 16);
                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), sumi_3_s_0_16, 16);
                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), sumi_3_s_1_16, 16);
                    }
                }
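
                // Dequantize and accumulate: convert the integer dot products to float and scale
                // each lane by d_b (one fp16 scale per interleaved column) * d_a (one scale per row).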
                const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16);
                const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16);

                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
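
// 4x16 iq4_nl x q8_0 micro-kernel. For each (row r, column c) of the 4x16 output tile, the
// loop below computes per q8_0 block l (scalar sketch of the vector code, assuming the
// interleaved layouts produced by the matching quantize/repack routines):
//
//   int sumi = 0;
//   for (int i = 0; i < QK4_NL / 2; i++) {
//       const uint8_t p = B[l].qs[i * 16 + c];                      // packed nibbles, 16 columns interleaved
//       sumi += kvalues_iq4nl[p & 0xF] * A[l].qs[i * 4 + r];        // first half of the block
//       sumi += kvalues_iq4nl[p >> 4]  * A[l].qs[64 + i * 4 + r];   // second half of the block
//   }
//   C[r][c] += sumi * fp16_to_fp32(B[l].d[c]) * fp16_to_fp32(A[l].d[r]);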
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 16;
    const int blocklen = 1;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined __riscv_v_intrinsic
    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
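    // Each 4-bit index selects one of the 16 signed iq4_nl codebook values; below, vrgather
    // uses the low/high nibbles of the packed weights as indices into this table.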
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);

            // 4x16 Accumulators
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);

            for (int l = 0; l < nb; l++) {
                // 4x16 integer accumulators
                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);

                // Accumulation loop.
                for (int i = 0; i < QK4_NL / 2; i++) {
                    // Load `b_ptr`.
                    const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
                    const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
                    const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
                    // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
                    // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);

                    const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);
                    const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);
                    const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);
                    const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16);

                    const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16);
                    const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16);
                    const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16);
                    const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16);

                    sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16);
                    sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16);
                    sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16);
                    sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);
                }

                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);

                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
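
// 4x16 q8_0 x q8_0 micro-kernel: plain int8 dot products accumulated in int32 and then
// scaled by the product of the per-block scales d_a[row] * d_b[col] (same tile structure as
// the iq4_nl kernel above, just without the nibble unpacking and lookup table).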
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 16;
    const int blocklen = 1;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined __riscv_v_intrinsic
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);

            // 4x16 Accumulators
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);

            for (int l = 0; l < nb; l++) {
                // 4x16 Integer Accumulators
                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);

                // Accumulation loop.
                for (int i = 0; i < QK8_0; i++) {
                    // Load `b_ptr`.
                    const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
                    // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);

                    sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16);
                    sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16);
                    sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16);
                    sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16);
                }

                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);

                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
            }

            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
        }
    }
    return;
#endif
    ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
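
// 4x16 q2_K x q8_K micro-kernel. Per super-block, for each (row, column) of the tile it
// accumulates sum_sb sc[sb] * dot(q2[sb], q8[sb]) in integers, scales that by
// d[col] * d_a[row], and subtracts the correction sum_sb dmin[col] * d_a[row] * m[sb] * bsum[row][sb],
// where sc/m are the packed 4-bit sub-block scales/mins and bsum are the per-sub-block q8 sums.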
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    assert(n % QK_K == 0);
    const int num_k_blocks = n / QK_K;
    const int N_ROWS_TILE = 4;
    const int N_COLS_TILE = 16;
    assert(nr % N_ROWS_TILE == 0);
    assert(nc % N_COLS_TILE == 0);

    const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);
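    // Note: as with the other 16-wide kernels in this file, the e32m2/e8mf2 operations below
    // assume vl == 16 (i.e. VLEN >= 256); kernel selection is presumably gated on that elsewhere.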
    // --- Tiling Loops ---
    #pragma GCC unroll 1
    for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) {
        #pragma GCC unroll 1
        for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {
            // Base Pointers
            const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks;
            const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;

            // Persistent Float Accumulators
            vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
            vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
            vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
            vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl);

            // --- Super-Block Loop (K=0..255) ---
            #pragma GCC unroll 1
            for (int k_block = 0; k_block < num_k_blocks; ++k_block) {
                const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block];
                const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];

                // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers)
                vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
                vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);

                // 2. Initialize Integer Accumulators
                vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl);
                vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl);
                vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl);
                vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl);

                const uint8_t* rhs_qs_ptr = rhs_current->qs;
                const uint8_t* rhs_sc_ptr = rhs_current->scales;
                const int8_t* lhs_qs_ptr = lhs_current->qs;
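
                // Range note: each 16-bit weight formed below is a 2-bit quant (0..3) times a
                // 4-bit scale (0..15), so |w| <= 45; with |q8| <= 128 and 256 multiply-accumulates
                // per super-block, the int32 accumulators stay below ~1.5e6 and cannot overflow.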

                // --- Phase Loop (4 phases x 64 elements) ---
                #pragma GCC unroll 1
                for (int phase = 0; phase < 4; ++phase) {

                    // A. Load Scales/Mins for the 4 interleaved sub-blocks
                    vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;
                    vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;

                    // Unrolled Load Logic
                    {
                        vuint8mf2_t v_raw;
                        // Sub-block 0
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
                        v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        // Sub-block 1
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
                        v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        // Sub-block 2
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
                        v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        // Sub-block 3
                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
                        v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
                        v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);

                        rhs_sc_ptr += 64;
                    }

                    int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
                    int k_offsets[4] = {0, 32, 64, 96};
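
                    // Sub-block mapping: this phase covers the four 16-element sub-blocks
                    // {b, b+2, b+4, b+6} with b = base_k_phase / 16 (i.e. 0, 1, 8, 9 for phases
                    // 0..3); each byte loaded below packs their quants in its four 2-bit fields
                    // (shifts 0/2/4/6).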

                    // B. Inner Dot Product Loop
                    #pragma GCC unroll 1
                    for (int l = 0; l < 16; ++l) {
                        vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);
                        rhs_qs_ptr += 16;

                        // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase)

                        // --- Sub-block 0 ---
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                        // --- Sub-block 1 ---
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                        // --- Sub-block 2 ---
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                        // --- Sub-block 3 ---
                        {
                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);
                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);

                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4];
                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
                        }
                    }

                    // C. Min-scale correction: subtract dmin * d_a * m[sb] * bsum[row][sb]
                    // for each of the 4 sub-blocks handled in this phase.
                    int sb_base_abs = base_k_phase / 16;

                    // --- Correction Sub-block 0 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[0] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);

                        // Row 0
                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        // Row 1
                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        // Row 2
                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        // Row 3
                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                    // --- Correction Sub-block 1 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[1] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);

                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                    // --- Correction Sub-block 2 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[2] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);

                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                    // --- Correction Sub-block 3 ---
                    {
                        int sb_abs = sb_base_abs + (k_offsets[3] / 16);
                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);

                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);

                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
                    }

                } // End Phase Loop

                // --- Apply Main Scales ---
                vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);
                vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);

                // Row 0
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl);
                }
                // Row 1
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl);
                }
                // Row 2
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl);
                }
                // Row 3
                {
                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl);
                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl);
                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
                    v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl);
                }

            } // End K-Block

            __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl);
            __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl);
            __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl);
            __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl);
        }
    }
}