// llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
#include "ggml.h"
#include "ime_kernels.h"
#include <algorithm>
#include <cmath>
// clang-format off
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Woverlength-strings"
#pragma GCC diagnostic ignored "-Wcast-qual"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
// clang-format on
namespace sqnbitgemm_spacemit_ime {
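// Quantize one block held in v0-v7: reduce to the absolute maximum, store the block scale
// (absmax / 127) through a1, multiply by the reciprocal, round to int32, then narrow
// 32 -> 16 -> 8 bits into v24-v31.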
#define QUANTIZEM4ROW_KERNEL \
"vmv.s.x v16, zero \n\t" \
"vfabs.v v8, v0 \n\t" \
"vfredmax.vs v16, v8, v16 \n\t" \
"vfmv.f.s f10, v16 \n\t" \
"fmul.s f10, f10, %[RMAXREC] \n\t" \
"fsw f10, (a1) \n\t" \
"fdiv.s f11, %[FONE], f10 \n\t" \
"vfmul.vf v16, v0, f11 \n\t" \
"vfcvt.x.f.v v16, v16 \n\t" \
"vsetvli t0, zero, e16, mf2 \n\t" \
"vnclip.wx v16, v16, zero \n\t" \
"vnclip.wx v17, v17, zero \n\t" \
"vnclip.wx v18, v18, zero \n\t" \
"vnclip.wx v19, v19, zero \n\t" \
"vnclip.wx v20, v20, zero \n\t" \
"vnclip.wx v21, v21, zero \n\t" \
"vnclip.wx v22, v22, zero \n\t" \
"vnclip.wx v23, v23, zero \n\t" \
"vsetvli t0, zero, e8, mf4 \n\t" \
"vnclip.wx v24, v16, zero \n\t" \
"vnclip.wx v25, v17, zero \n\t" \
"vnclip.wx v26, v18, zero \n\t" \
"vnclip.wx v27, v19, zero \n\t" \
"vnclip.wx v28, v20, zero \n\t" \
"vnclip.wx v29, v21, zero \n\t" \
"vnclip.wx v30, v22, zero \n\t" \
"vnclip.wx v31, v23, zero \n\t"
#define QUANTIZEM4ROW_STORE \
"addi t1, %[BlkLen], 0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v24, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v25, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v26, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v27, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v28, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v29, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v30, (s1) \n\t" \
"addi s1, s1, 32 \n\t" \
"sub t1, t1, t0 \n\t" \
"vsetvli t0, t1, e8, mf4 \n\t" \
"vse8.v v31, (s1) \n\t"
namespace ime1 {
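// Quantize 4 rows of A (CountK floats each) to blockwise int8. Each output block group holds the
// 4 row scales (16 bytes) followed by the rows' data interleaved in 8-byte chunks, hence the
// per-row offset and the 4 * (sizeof(float) + BlkLen) stride.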
void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
const float fone = 1.0f;
if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
for (size_t row_index = 0; row_index < 4; ++row_index) {
const float * SRC = A + row_index * CountK;
std::byte * DST = QuantA + row_index * sizeof(float);
const size_t offset = (4 - row_index) * 4 + row_index * 8;
const size_t stride = 4 * (sizeof(float) + BlkLen);
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"addi t2, %[CountK], 0 \n\t"
"addi a1, %[DST], 0 \n\t"
"blt t2, %[BlkLen], TAIL%= \n\t"
"LOOP%=: \n\t"
"vsetvli t0, %[BlkLen], e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"sub t2, t2, t0 \n\t"
"slli t1, t0, 2 \n\t"
"add %[SRC], %[SRC], t1 \n\t"
"add s1, a1, %[OFFSET] \n\t"
QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
"add a1, a1, %[STRIDE] \n\t"
"bge t2, %[BlkLen], LOOP%= \n\t"
"TAIL%=: \n\t"
"blez t2, QUIT%= \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"vsetvli t0, t2, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"add s1, a1, %[OFFSET] \n\t"
QUANTIZEM4ROW_KERNEL
"addi t3, %[BlkLen], 0 \n\t"
"addi s2, s1, 0 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vxor.vv v8, v8, v8 \n\t"
"SET_ZERO%=: \n\t"
"vse8.v v8, (s2) \n\t"
"addi s2, s2, 32 \n\t"
"addi t3, t3, -8 \n\t"
"bnez t3, SET_ZERO%= \n\t"
QUANTIZEM4ROW_STORE
"QUIT%=: \n\t"
: [SRC] "+r"(SRC)
: [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
[CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
: "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
}
} else if (BlkLen == 128) {
for (size_t row_index = 0; row_index < 4; ++row_index) {
const float * SRC = A + row_index * CountK;
std::byte * DST = QuantA + row_index * sizeof(float);
const size_t offset = (4 - row_index) * 4 + row_index * 8;
const size_t stride = 4 * (sizeof(float) + BlkLen);
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"li t6, 32 \n\t"
"addi t2, %[CountK], 0 \n\t"
"addi a1, %[DST], 0 \n\t"
"add s1, a1, %[OFFSET] \n\t"
"blt t2, %[BlkLen], TAIL%= \n\t"
"LOOP%=: \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"addi t2, t2, -128 \n\t"
"QUANTIZE%=: \n\t"
"add s1, a1, %[OFFSET] \n\t"
"vfabs.v v16, v0 \n\t"
"vfabs.v v24, v8 \n\t"
"vfmax.vv v16, v24, v16 \n\t"
"vfredmax.vs v24, v16, v24 \n\t"
"vfmv.f.s f10, v24 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (a1) \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vfmul.vf v16, v0, f11 \n\t"
"vfmul.vf v24, v8, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v20, v24, zero \n\t"
"vsetvli t0, zero, e8, m4 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vsetvli t0, zero, e64, m4 \n\t"
"vsse64.v v16, (s1), t6 \n\t"
"add a1, a1, %[STRIDE] \n\t"
"bge t2, %[BlkLen], LOOP%= \n\t"
"TAIL%=: \n\t"
"blez t2, QUIT%= \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vxor.vv v8, v8, v8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"vsetvli t0, t2, e32, m8 \n\t"
"sub t2, t2, t0 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vsetvli t0, t2, e32, m8 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"sub t2, t2, t2 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"jal x0, QUANTIZE%= \n\t"
"QUIT%=: \n\t"
: [SRC] "+r"(SRC)
: [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
[CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
: "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
}
} else if (BlkLen == 256) {
for (size_t row_index = 0; row_index < 4; ++row_index) {
const float * SRC = A + row_index * CountK;
std::byte * DST = QuantA + row_index * sizeof(float);
const size_t offset = (4 - row_index) * 4 + row_index * 8;
const size_t stride = 4 * (sizeof(float) + BlkLen);
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"li t6, 32 \n\t"
"addi t2, %[CountK], 0 \n\t"
"addi a1, %[DST], 0 \n\t"
"add s1, a1, %[OFFSET] \n\t"
"blt t2, %[BlkLen], TAIL%= \n\t"
"LOOP%=: \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v16, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v24, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], -768 \n\t"
"addi t2, t2, -256 \n\t"
"vfabs.v v0, v0 \n\t"
"vfabs.v v8, v8 \n\t"
"vfabs.v v16, v16 \n\t"
"vfabs.v v24, v24 \n\t"
"vfmax.vv v8, v0, v8 \n\t"
"vfmax.vv v24, v24, v16 \n\t"
"vfmax.vv v8, v8, v24 \n\t"
"vfredmax.vs v24, v8, v24 \n\t"
"vfmv.f.s f10, v24 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v16, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v24, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"QUANTIZE%=: \n\t"
"add s1, a1, %[OFFSET] \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (a1) \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vfmul.vf v0, v0, f11 \n\t"
"vfmul.vf v8, v8, f11 \n\t"
"vfmul.vf v16, v16, f11 \n\t"
"vfmul.vf v24, v24, f11 \n\t"
"vfcvt.x.f.v v0, v0 \n\t"
"vfcvt.x.f.v v8, v8 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v0, v0, zero \n\t"
"vnclip.wx v4, v8, zero \n\t"
"vnclip.wx v8, v16, zero \n\t"
"vnclip.wx v12, v24, zero \n\t"
"vsetvli t0, zero, e8, m4 \n\t"
"vnclip.wx v0, v0, zero \n\t"
"vnclip.wx v4, v8, zero \n\t"
"vsetvli t0, zero, e64, m8 \n\t"
"vsse64.v v0, (s1), t6 \n\t"
"add a1, a1, %[STRIDE] \n\t"
"bge t2, %[BlkLen], LOOP%= \n\t"
"TAIL%=: \n\t"
"blez t2, QUIT%= \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vxor.vv v8, v8, v8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t1, t2, 0 \n\t"
"vsetvli t0, t1, e32, m8 \n\t"
"sub t1, t1, t0 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vsetvli t0, t1, e32, m8 \n\t"
"sub t1, t1, t0 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vsetvli t0, t1, e32, m8 \n\t"
"sub t1, t1, t0 \n\t"
"vle32.v v16, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vsetvli t0, t1, e32, m8 \n\t"
"vle32.v v24, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], -768 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vfabs.v v0, v0 \n\t"
"vfabs.v v8, v8 \n\t"
"vfabs.v v16, v16 \n\t"
"vfabs.v v24, v24 \n\t"
"vfmax.vv v8, v0, v8 \n\t"
"vfmax.vv v24, v16, v24 \n\t"
"vfmax.vv v8, v8, v24 \n\t"
"vfredmax.vs v24, v8, v24 \n\t"
"vfmv.f.s f10, v24 \n\t"
"add s1, a1, %[OFFSET] \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (a1) \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vsetvli t0, zero, e64, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vsse64.v v0, (s1), t6 \n\t"
"TAIL_LOOP%=: \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vsetvli t0, t2, e32, m1 \n\t"
"sub t2, t2, t0 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 32 \n\t"
"vfmul.vf v1, v0, f11 \n\t"
"vfcvt.x.f.v v2, v1 \n\t"
"vsetvli t0, zero, e16, mf2 \n\t"
"vnclip.wx v3, v2, zero \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vnclip.wx v3, v3, zero \n\t"
"vse8.v v3, (s1) \n\t"
"addi s1, s1, 32 \n\t"
"bnez t2, TAIL_LOOP%= \n\t"
"QUIT%=: \n\t"
: [SRC] "+r"(SRC)
: [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
[CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
: "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
}
}
}
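// Quantize a single row of A to blockwise int8: each block is a 4-byte fp32 scale followed by
// BlkLen int8 values; the trailing partial block is zero-padded.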
void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
const float * SRC = A;
std::byte * DST = QuantA;
constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
const float fone = 1.0f;
std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
if (CountK <= BlkLen) {
float max_abs_A = 0.0f;
for (size_t k = 0; k < CountK; k++) {
max_abs_A = std::max(max_abs_A, fabsf(A[k]));
}
float scale_A = max_abs_A * range_max_reciprocal;
((float *) QuantA)[0] = scale_A;
auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
for (size_t k = 0; k < CountK; k++) {
QuantAData_offset[k] =
(int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
(float) std::numeric_limits<int8_t>::max());
}
for (size_t k = CountK; k < BlkLen; k++) {
QuantAData_offset[k] = 0;
}
return;
}
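// Pre-zero the padding area at the end of the quantized row so bytes past CountK in the final
// block read as zero.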
if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) {
__asm__ volatile(
"vsetvli t0, zero, e8, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"LOOP%=: \n\t"
"vsetvli t0, %[CNT], e8, m8 \n\t"
"vse8.v v24, (%[DST]) \n\t"
"addi %[DST], %[DST], 128 \n\t"
"sub %[CNT], %[CNT], t0 \n\t"
"bnez %[CNT], LOOP%= \n\t"
: [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
:
: "cc", "t0");
}
if (BlkLen == 16) {
float buffer[64] = { 0.0f };
__asm__ volatile(
"addi t3, zero, 16*8 \n\t"
"addi t2, zero, 16 \n\t"
"blt %[K], t3, LOOP_K%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_MAIN%=: \n\t"
"vsetvli t1, zero, e32, m2 \n\t"
"addi %[K], %[K], -128 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v2, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v4, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v6, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v10, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v12, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"vle32.v v14, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"addi a1, %[BUFFER], 0 \n\t"
"vfabs.v v16, v0 \n\t"
"vfabs.v v18, v2 \n\t"
"vfabs.v v20, v4 \n\t"
"vfabs.v v22, v6 \n\t"
"vfabs.v v24, v8 \n\t"
"vfabs.v v26, v10 \n\t"
"vfabs.v v28, v12 \n\t"
"vfabs.v v30, v14 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfmax.vv v18, v18, v19 \n\t"
"vfmax.vv v20, v20, v21 \n\t"
"vfmax.vv v22, v22, v23 \n\t"
"vfmax.vv v24, v24, v25 \n\t"
"vfmax.vv v26, v26, v27 \n\t"
"vfmax.vv v28, v28, v29 \n\t"
"vfmax.vv v30, v30, v31 \n\t"
"vse32.v v16, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v18, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v20, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v22, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v24, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v26, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v28, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vse32.v v30, (a1) \n\t"
"addi a1, %[BUFFER], 0 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f10, f3, f7 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f10, %[FONE], f10 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f11, f3, f7 \n\t"
"fmul.s f11, f11, %[RMAXREC] \n\t"
"fsw f11, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f11, %[FONE], f11 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f12, f3, f7 \n\t"
"fmul.s f12, f12, %[RMAXREC] \n\t"
"fsw f12, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f12, %[FONE], f12 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f13, f3, f7 \n\t"
"fmul.s f13, f13, %[RMAXREC] \n\t"
"fsw f13, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f13, %[FONE], f13 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f14, f3, f7 \n\t"
"fmul.s f14, f14, %[RMAXREC] \n\t"
"fsw f14, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f14, %[FONE], f14 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f15, f3, f7 \n\t"
"fmul.s f15, f15, %[RMAXREC] \n\t"
"fsw f15, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f15, %[FONE], f15 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f16, f3, f7 \n\t"
"fmul.s f16, f16, %[RMAXREC] \n\t"
"fsw f16, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"fdiv.s f16, %[FONE], f16 \n\t"
"flw f0, (a1) \n\t"
"flw f1, 4(a1) \n\t"
"flw f2, 8(a1) \n\t"
"flw f3, 12(a1) \n\t"
"flw f4, 16(a1) \n\t"
"flw f5, 20(a1) \n\t"
"flw f6, 24(a1) \n\t"
"flw f7, 28(a1) \n\t"
"addi a1, a1, 32 \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f17, f3, f7 \n\t"
"fmul.s f17, f17, %[RMAXREC] \n\t"
"fsw f17, (%[DST]) \n\t"
"addi %[DST], %[DST], -136 \n\t"
"fdiv.s f17, %[FONE], f17 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmul.vf v16, v0, f10 \n\t"
"vfmul.vf v18, v2, f11 \n\t"
"vfmul.vf v20, v4, f12 \n\t"
"vfmul.vf v22, v6, f13 \n\t"
"vfmul.vf v24, v8, f14 \n\t"
"vfmul.vf v26, v10, f15 \n\t"
"vfmul.vf v28, v12, f16 \n\t"
"vfmul.vf v30, v14, f17 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v18, v18 \n\t"
"vfcvt.x.f.v v20, v20 \n\t"
"vfcvt.x.f.v v22, v22 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vfcvt.x.f.v v26, v26 \n\t"
"vfcvt.x.f.v v28, v28 \n\t"
"vfcvt.x.f.v v30, v30 \n\t"
"vsetvli t0, zero, e16, m1 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v18, v18, zero \n\t"
"vnclip.wx v20, v20, zero \n\t"
"vnclip.wx v22, v22, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vnclip.wx v26, v26, zero \n\t"
"vnclip.wx v28, v28, zero \n\t"
"vnclip.wx v30, v30, zero \n\t"
"vsetvli t0, t1, e8, mf2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v18, v18, zero \n\t"
"vnclip.wx v20, v20, zero \n\t"
"vnclip.wx v22, v22, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vnclip.wx v26, v26, zero \n\t"
"vnclip.wx v28, v28, zero \n\t"
"vnclip.wx v30, v30, zero \n\t"
"vse8.v v16, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v18, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v20, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v22, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v24, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v26, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v28, (%[DST]) \n\t"
"addi %[DST], %[DST], 20 \n\t"
"vse8.v v30, (%[DST]) \n\t"
"addi %[DST], %[DST], 16 \n\t"
"bge %[K], t3, LOOP_MAIN%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_K%=: \n\t"
"vsetvli t1, %[K], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 64 \n\t"
"sub %[K], %[K], t1 \n\t"
"vfabs.v v16, v0 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vse32.v v16, (%[BUFFER]) \n\t"
"flw f0, (%[BUFFER]) \n\t"
"flw f1, 4(%[BUFFER]) \n\t"
"flw f2, 8(%[BUFFER]) \n\t"
"flw f3, 12(%[BUFFER]) \n\t"
"flw f4, 16(%[BUFFER]) \n\t"
"flw f5, 20(%[BUFFER]) \n\t"
"flw f6, 24(%[BUFFER]) \n\t"
"flw f7, 28(%[BUFFER]) \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f10, f3, f7 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (%[DST]) \n\t"
"addi %[DST], %[DST], 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmul.vf v16, v0, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vsetvli t0, zero, e16, m1 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vsetvli t0, t1, e8, mf2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vse8.v v16, (%[DST]) \n\t"
"addi %[DST], %[DST], 16 \n\t"
"bge %[K], t2, LOOP_K%= \n\t"
"TAIL%=: \n\t"
"blez %[K], END%= \n\t"
"vsetvli t0, t3, e32, m2 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"jal x0, LOOP_K%= \n\t"
"END%=: \n\t"
: [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
: [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
: "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
"f13", "f14", "f15", "f16", "f17");
} else if (BlkLen == 32) {
__asm__ volatile(
"addi t3, zero, 32*4 \n\t"
"addi t2, zero, 32 \n\t"
"addi a1, %[SRC], 0 \n\t"
"addi a2, %[SRC], 128 \n\t"
"addi a3, %[SRC], 256 \n\t"
"addi a4, %[SRC], 384 \n\t"
"addi s1, %[DST], 0 \n\t"
"addi s2, %[DST], 36 \n\t"
"addi s3, %[DST], 72 \n\t"
"addi s4, %[DST], 108 \n\t"
"blt %[K], t3, LOOP_K%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_MAIN%=: \n\t"
"vsetvli t1, zero, e32, m4 \n\t"
"addi %[K], %[K], -128 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 512 \n\t"
"vle32.v v4, (a2) \n\t"
"addi a2, a2, 512 \n\t"
"vle32.v v8, (a3) \n\t"
"addi a3, a3, 512 \n\t"
"vle32.v v12, (a4) \n\t"
"addi a4, a4, 512 \n\t"
"vfabs.v v16, v0 \n\t"
"vfabs.v v20, v4 \n\t"
"vfabs.v v24, v8 \n\t"
"vfabs.v v28, v12 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v16, v16, v18 \n\t"
"vfmax.vv v20, v20, v22 \n\t"
"vfmax.vv v24, v24, v26 \n\t"
"vfmax.vv v28, v28, v30 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfmax.vv v20, v20, v21 \n\t"
"vfmax.vv v24, v24, v25 \n\t"
"vfmax.vv v28, v28, v29 \n\t"
"vfredmax.vs v17, v16, v17 \n\t"
"vfredmax.vs v21, v20, v21 \n\t"
"vfredmax.vs v25, v24, v25 \n\t"
"vfredmax.vs v29, v28, v29 \n\t"
"vfmv.f.s f10, v17 \n\t"
"vfmv.f.s f11, v21 \n\t"
"vfmv.f.s f12, v25 \n\t"
"vfmv.f.s f13, v29 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fmul.s f11, f11, %[RMAXREC] \n\t"
"fmul.s f12, f12, %[RMAXREC] \n\t"
"fmul.s f13, f13, %[RMAXREC] \n\t"
"fsw f10, (s1) \n\t"
"addi s1, s1, 4 \n\t"
"fsw f11, (s2) \n\t"
"addi s2, s2, 4 \n\t"
"fsw f12, (s3) \n\t"
"addi s3, s3, 4 \n\t"
"fsw f13, (s4) \n\t"
"addi s4, s4, 4 \n\t"
"fdiv.s f10, %[FONE], f10 \n\t"
"fdiv.s f11, %[FONE], f11 \n\t"
"fdiv.s f12, %[FONE], f12 \n\t"
"fdiv.s f13, %[FONE], f13 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmul.vf v16, v0, f10 \n\t"
"vfmul.vf v20, v4, f11 \n\t"
"vfmul.vf v24, v8, f12 \n\t"
"vfmul.vf v28, v12, f13 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v20, v20 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vfcvt.x.f.v v28, v28 \n\t"
"vsetvli t0, zero, e16, m2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v20, v20, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vnclip.wx v28, v28, zero \n\t"
"vsetvli t0, t1, e8, m1 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v20, v20, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vnclip.wx v28, v28, zero \n\t"
"vse8.v v16, (s1) \n\t"
"addi s1, s1, 140 \n\t"
"vse8.v v20, (s2) \n\t"
"addi s2, s2, 140 \n\t"
"vse8.v v24, (s3) \n\t"
"addi s3, s3, 140 \n\t"
"vse8.v v28, (s4) \n\t"
"addi s4, s4, 140 \n\t"
"bge %[K], t3, LOOP_MAIN%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_K%=: \n\t"
"vsetvli t1, %[K], e32, m4 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 128 \n\t"
"sub %[K], %[K], t1 \n\t"
"vfabs.v v16, v0 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v16, v16, v18 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfredmax.vs v17, v16, v17 \n\t"
"vfmv.f.s f10, v17 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (s1) \n\t"
"addi s1, s1, 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmul.vf v16, v0, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vsetvli t0, zero, e16, m2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vse8.v v16, (s1) \n\t"
"addi s1, s1, 32 \n\t"
"bge %[K], t2, LOOP_K%= \n\t"
"TAIL%=: \n\t"
"blez %[K], END%= \n\t"
"vsetvli t0, t3, e32, m4 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"jal x0, LOOP_K%= \n\t"
"END%=: \n\t"
: [K] "+r"(CountK)
: [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
: "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
} else if (BlkLen == 64) {
__asm__ volatile(
"addi t3, zero, 64*2 \n\t"
"addi t2, zero, 64 \n\t"
"addi a1, %[SRC], 0 \n\t"
"addi a2, %[SRC], 256 \n\t"
"addi s1, %[DST], 0 \n\t"
"addi s2, %[DST], 68 \n\t"
"blt %[K], t3, LOOP_K%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_MAIN%=: \n\t"
"vsetvli t1, zero, e32, m8 \n\t"
"addi %[K], %[K], -128 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 512 \n\t"
"vle32.v v8, (a2) \n\t"
"addi a2, a2, 512 \n\t"
"vfabs.v v16, v0 \n\t"
"vfabs.v v24, v8 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmax.vv v16, v16, v20 \n\t"
"vfmax.vv v24, v24, v28 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v16, v16, v18 \n\t"
"vfmax.vv v24, v24, v26 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfmax.vv v24, v24, v25 \n\t"
"vfredmax.vs v17, v16, v17 \n\t"
"vfredmax.vs v25, v24, v25 \n\t"
"vfmv.f.s f10, v17 \n\t"
"vfmv.f.s f11, v25 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fmul.s f11, f11, %[RMAXREC] \n\t"
"fsw f10, (s1) \n\t"
"addi s1, s1, 4 \n\t"
"fsw f11, (s2) \n\t"
"addi s2, s2, 4 \n\t"
"fdiv.s f10, %[FONE], f10 \n\t"
"fdiv.s f11, %[FONE], f11 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vfmul.vf v16, v0, f10 \n\t"
"vfmul.vf v24, v8, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vsetvli t0, t1, e8, m2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v24, v24, zero \n\t"
"vse8.v v16, (s1) \n\t"
"addi s1, s1, 132 \n\t"
"vse8.v v24, (s2) \n\t"
"addi s2, s2, 132 \n\t"
"bge %[K], t3, LOOP_MAIN%= \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_K%=: \n\t"
"vsetvli t1, %[K], e32, m8 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 256 \n\t"
"sub %[K], %[K], t1 \n\t"
"vfabs.v v16, v0 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmax.vv v16, v16, v20 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v16, v16, v18 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v16, v16, v17 \n\t"
"vfredmax.vs v17, v16, v17 \n\t"
"vfmv.f.s f10, v17 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (s1) \n\t"
"addi s1, s1, 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vfmul.vf v16, v0, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vsetvli t0, zero, e8, m2 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vse8.v v16, (s1) \n\t"
"addi s1, s1, 64 \n\t"
"bge %[K], t2, LOOP_K%= \n\t"
"TAIL%=: \n\t"
"blez %[K], END%= \n\t"
"vsetvli t0, t3, e32, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"jal x0, LOOP_K%= \n\t"
"END%=: \n\t"
: [K] "+r"(CountK)
: [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
: "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
} else if (BlkLen == 128) {
__asm__ volatile(
"addi t2, zero, 128 \n\t"
"addi a1, %[SRC], 0 \n\t"
"addi a2, %[SRC], 256 \n\t"
"blt %[K], t2, TAIL%= \n\t"
"LOOP_K%=: \n\t"
"vsetvli t1, zero, e32, m8 \n\t"
"vle32.v v0, (a1) \n\t"
"addi a1, a1, 512 \n\t"
"vle32.v v8, (a2) \n\t"
"addi a2, a2, 512 \n\t"
"sub %[K], %[K], t2 \n\t"
"QUANT%=: \n\t"
"vfabs.v v16, v0 \n\t"
"vfabs.v v24, v8 \n\t"
"vfmax.vv v24, v16, v24 \n\t"
"vsetvli t1, zero, e32, m4 \n\t"
"vfmax.vv v28, v24, v28 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v30, v28, v30 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v30, v30, v31 \n\t"
"vfredmax.vs v31, v30, v31 \n\t"
"vfmv.f.s f10, v31 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (%[DST]) \n\t"
"addi %[DST], %[DST], 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vfmul.vf v16, v0, f11 \n\t"
"vfmul.vf v24, v8, f11 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vnclip.wx v20, v24, zero \n\t"
"vsetvli t0, zero, e8, m4 \n\t"
"vnclip.wx v16, v16, zero \n\t"
"vse8.v v16, (%[DST]) \n\t"
"addi %[DST], %[DST], 128 \n\t"
"bge %[K], t2, LOOP_K%= \n\t"
"TAIL%=: \n\t"
"blez %[K], END%= \n\t"
"vsetvli t1, zero, e32, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vxor.vv v8, v8, v8 \n\t"
"vsetvli t0, %[K], e32, m8 \n\t"
"vle32.v v0, (a1) \n\t"
"sub %[K], %[K], t0 \n\t"
"vsetvli t0, %[K], e32, m8 \n\t"
"vle32.v v8, (a2) \n\t"
"sub %[K], %[K], t0 \n\t"
"vsetvli t1, zero, e32, m8 \n\t"
"jal x0, QUANT%= \n\t"
"END%=: \n\t"
: [DST] "+r"(DST), [K] "+r"(CountK)
: [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
: "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
} else {
float buffer[8] = { 0.0f };
size_t cnt = BlkLen / 256;
__asm__ volatile(
"slli t3, %[BLK], 2 \n\t"
"blt %[K], %[BLK], LOOP_TAIL%= \n\t"
"LOOP_MAIN%=: \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vxor.vv v31, v31, v31 \n\t"
"vse32.v v31, (%[BUFFER]) \n\t"
"addi t6, %[CNT], 0 \n\t"
"LOOP_CMP%=: \n\t"
"addi t6, t6, -1 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v16, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v24, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vfabs.v v0, v0 \n\t"
"vfabs.v v8, v8 \n\t"
"vfabs.v v16, v16 \n\t"
"vfabs.v v24, v24 \n\t"
"vfmax.vv v8, v0, v8 \n\t"
"vfmax.vv v16, v16, v24 \n\t"
"vfmax.vv v0, v0, v16 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmax.vv v0, v0, v4 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v0, v0, v2 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v0, v0, v1 \n\t"
"vle32.v v30, (%[BUFFER]) \n\t"
"vfmax.vv v31, v30, v0 \n\t"
"vse32.v v31, (%[BUFFER]) \n\t"
"bnez t6, LOOP_CMP%= \n\t"
"sub %[SRC], %[SRC], t3 \n\t"
"addi t6, %[CNT], 0 \n\t"
"flw f0, (%[BUFFER]) \n\t"
"flw f1, 4(%[BUFFER]) \n\t"
"flw f2, 8(%[BUFFER]) \n\t"
"flw f3, 12(%[BUFFER]) \n\t"
"flw f4, 16(%[BUFFER]) \n\t"
"flw f5, 20(%[BUFFER]) \n\t"
"flw f6, 24(%[BUFFER]) \n\t"
"flw f7, 28(%[BUFFER]) \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f10, f3, f7 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (%[DST]) \n\t"
"addi %[DST], %[DST], 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"addi t6, %[CNT], 0 \n\t"
"LOOP_QUANT%=: \n\t"
"addi t6, t6, -1 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v8, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v16, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vle32.v v24, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vfmul.vf v0, v0, f11 \n\t"
"vfmul.vf v8, v8, f11 \n\t"
"vfmul.vf v16, v16, f11 \n\t"
"vfmul.vf v24, v24, f11 \n\t"
"vfcvt.x.f.v v0, v0 \n\t"
"vfcvt.x.f.v v8, v8 \n\t"
"vfcvt.x.f.v v16, v16 \n\t"
"vfcvt.x.f.v v24, v24 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v0, v0, zero \n\t"
"vnclip.wx v4, v8, zero \n\t"
"vnclip.wx v8, v16, zero \n\t"
"vnclip.wx v12, v24, zero \n\t"
"vsetvli t0, zero, e8, m4 \n\t"
"vnclip.wx v0, v0, zero \n\t"
"vnclip.wx v4, v8, zero \n\t"
"vse8.v v0, (%[DST]) \n\t"
"addi %[DST], %[DST], 128 \n\t"
"vse8.v v4, (%[DST]) \n\t"
"addi %[DST], %[DST], 128 \n\t"
"bnez t6, LOOP_QUANT%= \n\t"
"sub %[K], %[K], %[BLK] \n\t"
"bge %[K], %[BLK], LOOP_MAIN%= \n\t"
"blez %[K], END%= \n\t"
"LOOP_TAIL%=: \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vxor.vv v31, v31, v31 \n\t"
"vse32.v v31, (%[BUFFER]) \n\t"
"addi t6, %[K], 0 \n\t"
"addi s1, %[SRC], 0 \n\t"
"TAIL_CMP%=: \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vsetvli t0, t6, e32, m8 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi %[SRC], %[SRC], 256 \n\t"
"sub t6, t6, t0 \n\t"
"vfabs.v v0, v0 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vfmax.vv v0, v0, v4 \n\t"
"vsetvli t0, zero, e32, m2 \n\t"
"vfmax.vv v0, v0, v2 \n\t"
"vsetvli t0, zero, e32, m1 \n\t"
"vfmax.vv v0, v0, v1 \n\t"
"vle32.v v30, (%[BUFFER]) \n\t"
"vfmax.vv v31, v30, v0 \n\t"
"vse32.v v31, (%[BUFFER]) \n\t"
"bnez t6, TAIL_CMP%= \n\t"
"addi t6, %[K], 0 \n\t"
"flw f0, (%[BUFFER]) \n\t"
"flw f1, 4(%[BUFFER]) \n\t"
"flw f2, 8(%[BUFFER]) \n\t"
"flw f3, 12(%[BUFFER]) \n\t"
"flw f4, 16(%[BUFFER]) \n\t"
"flw f5, 20(%[BUFFER]) \n\t"
"flw f6, 24(%[BUFFER]) \n\t"
"flw f7, 28(%[BUFFER]) \n\t"
"fmax.s f1, f0, f1 \n\t"
"fmax.s f3, f2, f3 \n\t"
"fmax.s f5, f4, f5 \n\t"
"fmax.s f7, f6, f7 \n\t"
"fmax.s f3, f1, f3 \n\t"
"fmax.s f7, f5, f7 \n\t"
"fmax.s f10, f3, f7 \n\t"
"fmul.s f10, f10, %[RMAXREC] \n\t"
"fsw f10, (%[DST]) \n\t"
"addi %[DST], %[DST], 4 \n\t"
"fdiv.s f11, %[FONE], f10 \n\t"
"addi t6, %[K], 0 \n\t"
"TAIL_QUANT%=: \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v0, v0, v0 \n\t"
"vsetvli t1, t6, e32, m8 \n\t"
"vle32.v v0, (s1) \n\t"
"addi s1, s1, 256 \n\t"
"sub t6, t6, t1 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vfmul.vf v0, v0, f11 \n\t"
"vfcvt.x.f.v v0, v0 \n\t"
"vsetvli t0, zero, e16, m4 \n\t"
"vnclip.wx v0, v0, zero \n\t"
"vsetvli t0, t1, e8, m2 \n\t"
"vnclip.wx v0, v0, zero \n\t"
"vse8.v v0, (%[DST]) \n\t"
"addi %[DST], %[DST], 64 \n\t"
"bnez t6, TAIL_QUANT%= \n\t"
"END%=: \n\t"
: [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
: [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
[CNT] "r"(cnt)
: "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
}
}
} // namespace ime1
namespace {
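// The helpers below build the int8 x int4 GEMM micro-kernels around vmadot, the SpacemiT IME
// matrix multiply-accumulate instruction: each vmadot accumulates an int8 dot-product tile into
// 32-bit accumulators.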
#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
"vmadot v16, v14, v0 \n\t" \
"vmadot v18, v14, v1 \n\t" \
"vmadot v20, v14, v2 \n\t" \
"vmadot v22, v14, v3 \n\t" \
"vmadot v16, v15, v4 \n\t" \
"vmadot v18, v15, v5 \n\t" \
"vmadot v20, v15, v6 \n\t" \
"vmadot v22, v15, v7 \n\t"
#define SQ4BIT_KERNEL_ACC_1X4X4 \
"vfcvt.f.x.v v16, v16 \n\t" \
"vfcvt.f.x.v v18, v18 \n\t" \
"vfcvt.f.x.v v20, v20 \n\t" \
"vfcvt.f.x.v v22, v22 \n\t" \
"addi s2, s1, 16 \n\t" \
"addi s3, s1, 32 \n\t" \
"addi s4, s1, 48 \n\t" \
"addi s6, s5, 12 \n\t" \
"vfmacc.vv v28, v16, v24 \n\t" \
"vfmacc.vv v29, v18, v25 \n\t" \
"vfmacc.vv v30, v20, v26 \n\t" \
"vfmacc.vv v31, v22, v27 \n\t"
#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
"vfcvt.f.x.v v16, v16 \n\t" \
"vfcvt.f.x.v v18, v18 \n\t" \
"vfcvt.f.x.v v20, v20 \n\t" \
"vfcvt.f.x.v v22, v22 \n\t" \
"addi s2, s1, 8 \n\t" \
"addi s3, s1, 16 \n\t" \
"addi s4, s1, 24 \n\t" \
"addi s6, s5, 12 \n\t" \
"vfmacc.vv v28, v16, v24 \n\t" \
"vfmacc.vv v29, v18, v25 \n\t" \
"vfmacc.vv v30, v20, v26 \n\t" \
"vfmacc.vv v31, v22, v27 \n\t"
#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
"vle8.v v4, (s1) \n\t" \
"addi s1, s1, 128 \n\t" \
"vle8.v v5, (s2) \n\t" \
"addi s2, s2, 128 \n\t" \
"vle8.v v6, (s3) \n\t" \
"addi s3, s3, 128 \n\t" \
"vle8.v v7, (s4) \n\t" \
"addi s4, s4, 128 \n\t" \
"vsetvli t0, zero, e8, mf4 \n\t" \
"vle8.v v14, (s5) \n\t" \
"addi s5, s5, 16 \n\t" \
"vle8.v v15, (s6) \n\t" \
"addi s6, s6, 16 \n\t" \
"addi t5, t5, -1 \n\t" \
"vsetvli t0, zero, e8, m1 \n\t" \
"vand.vi v0, v4, 15 \n\t" \
"vand.vi v1, v5, 15 \n\t" \
"vand.vi v2, v6, 15 \n\t" \
"vand.vi v3, v7, 15 \n\t" \
"vsrl.vi v4, v4, 4 \n\t" \
"vsrl.vi v5, v5, 4 \n\t" \
"vsrl.vi v6, v6, 4 \n\t" \
"vsrl.vi v7, v7, 4 \n\t"
#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
"vsetvli t0, zero, e8, mf2 \n\t" \
"vle8.v v1, (s7) \n\t" \
"vsetvli t0, zero, e8, m1 \n\t" \
"vrgather.vv v8, v1, v13 \n\t" \
"vadd.vi v13, v13, 4 \n\t" \
"vrgather.vv v9, v1, v13 \n\t" \
"vadd.vi v13, v13, 4 \n\t" \
"vrgather.vv v10, v1, v13 \n\t" \
"vadd.vi v13, v13, 4 \n\t" \
"vrgather.vv v11, v1, v13 \n\t" \
"vadd.vi v13, v13, -12 \n\t"
// Used by the M4 kernels: load a 16x(8x2) tile of packed 4-bit B and split low/high nibbles.
#define LOAD_B_16x8x2 \
"vsetvli t0, zero, e8, m1 \n\t" \
"vle8.v v6, (s1) \n\t" \
"addi s1, s1, 32*4 \n\t" \
"vle8.v v7, (s2) \n\t" \
"addi s2, s2, 32*4 \n\t" \
"vle8.v v8, (s3) \n\t" \
"addi s3, s3, 32*4 \n\t" \
"vle8.v v9, (s4) \n\t" \
"addi s4, s4, 32*4 \n\t" \
\
"vand.vi v2, v6, 15 \n\t" \
"vand.vi v3, v7, 15 \n\t" \
"vand.vi v4, v8, 15 \n\t" \
"vand.vi v5, v9, 15 \n\t" \
\
"vsrl.vi v6, v6, 4 \n\t" \
"vsrl.vi v7, v7, 4 \n\t" \
"vsrl.vi v8, v8, 4 \n\t" \
"vsrl.vi v9, v9, 4 \n\t"
// Load 16 fp16 B block scales, widen to fp32 and combine with the A row scales f1..f4;
// pointer layout: [s2|s5, s3, s4, s6]
#define LOAD_SCALE_4x16_FP16 \
"addi s2, s5, -8 \n\t" \
"addi s3, s5, 8 \n\t" \
"addi s4, s5, 16 \n\t" \
"addi s6, s5, 24 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e16, mf4 \n\t" \
"vle16.v v9, (s5) \n\t" \
"vle16.v v11, (s3) \n\t" \
"vle16.v v13, (s4) \n\t" \
"vle16.v v15, (s6) \n\t" \
"vsetvli t0, zero, e16, mf2 \n\t" \
"vle16.v v9, (s2), v0.t \n\t" \
"vle16.v v11, (s5), v0.t \n\t" \
"vle16.v v13, (s3), v0.t \n\t" \
"vle16.v v15, (s4), v0.t \n\t" \
"vfwcvt.f.f.v v8, v9 \n\t" \
"vfwcvt.f.f.v v10, v11 \n\t" \
"vfwcvt.f.f.v v12, v13 \n\t" \
"vfwcvt.f.f.v v14, v15 \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vmv.v.v v9, v8 \n\t" \
"vmv.v.v v11, v10 \n\t" \
"vmv.v.v v13, v12 \n\t" \
"vmv.v.v v15, v14 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
"vfmul.vf v8, v8, f1 \n\t" \
"vfmul.vf v10, v10, f1 \n\t" \
"vfmul.vf v12, v12, f1 \n\t" \
"vfmul.vf v14, v14, f1 \n\t" \
"vfmul.vf v9, v9, f3 \n\t" \
"vfmul.vf v11, v11, f3 \n\t" \
"vfmul.vf v13, v13, f3 \n\t" \
"vfmul.vf v15, v15, f3 \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vfmul.vf v8, v8, f2, v0.t \n\t" \
"vfmul.vf v10, v10, f2, v0.t \n\t" \
"vfmul.vf v12, v12, f2, v0.t \n\t" \
"vfmul.vf v14, v14, f2, v0.t \n\t" \
"vfmul.vf v9, v9, f4, v0.t \n\t" \
"vfmul.vf v11, v11, f4, v0.t \n\t" \
"vfmul.vf v13, v13, f4, v0.t \n\t" \
"vfmul.vf v15, v15, f4, v0.t \n\t"
// fp32 variant of the scale load above; pointer layout: [s2|s5, s3, s4, s6]
#define LOAD_SCALE_4x16 \
"addi s2, s5, -16 \n\t" \
"addi s3, s5, 16 \n\t" \
"addi s4, s5, 32 \n\t" \
"addi s6, s5, 48 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
"vle32.v v8, (s5) \n\t" \
"vle32.v v10, (s3) \n\t" \
"vle32.v v12, (s4) \n\t" \
"vle32.v v14, (s6) \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vle32.v v8, (s2), v0.t \n\t" \
"vle32.v v10, (s5), v0.t \n\t" \
"vle32.v v12, (s3), v0.t \n\t" \
"vle32.v v14, (s4), v0.t \n\t" \
"vmv.v.v v9, v8 \n\t" \
"vmv.v.v v11, v10 \n\t" \
"vmv.v.v v13, v12 \n\t" \
"vmv.v.v v15, v14 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
"vfmul.vf v8, v8, f1 \n\t" \
"vfmul.vf v10, v10, f1 \n\t" \
"vfmul.vf v12, v12, f1 \n\t" \
"vfmul.vf v14, v14, f1 \n\t" \
"vfmul.vf v9, v9, f3 \n\t" \
"vfmul.vf v11, v11, f3 \n\t" \
"vfmul.vf v13, v13, f3 \n\t" \
"vfmul.vf v15, v15, f3 \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vfmul.vf v8, v8, f2, v0.t \n\t" \
"vfmul.vf v10, v10, f2, v0.t \n\t" \
"vfmul.vf v12, v12, f2, v0.t \n\t" \
"vfmul.vf v14, v14, f2, v0.t \n\t" \
"vfmul.vf v9, v9, f4, v0.t \n\t" \
"vfmul.vf v11, v11, f4, v0.t \n\t" \
"vfmul.vf v13, v13, f4, v0.t \n\t" \
"vfmul.vf v15, v15, f4, v0.t \n\t"
// Load 16 bias values into v24..v31 (each value duplicated across the paired rows);
// pointer layout: [s1|BIAS, s2, s3, s4]
#define LOAD_BIAS \
"vsetvli t0, zero, e32, mf2 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"addi s1, %[BIAS], -16 \n\t" \
"addi s2, %[BIAS], 16 \n\t" \
"addi s3, %[BIAS], 32 \n\t" \
"addi s4, %[BIAS], 48 \n\t" \
\
"vle32.v v24, (%[BIAS]) \n\t" \
"vle32.v v26, (s2) \n\t" \
"vle32.v v28, (s3) \n\t" \
"vle32.v v30, (s4) \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
"vle32.v v24, (s1), v0.t \n\t" \
"vle32.v v26, (%[BIAS]), v0.t \n\t" \
"vle32.v v28, (s2), v0.t \n\t" \
"vle32.v v30, (s3), v0.t \n\t" \
"vmv.v.v v25, v24 \n\t" \
"vmv.v.v v27, v26 \n\t" \
"vmv.v.v v29, v28 \n\t" \
"vmv.v.v v31, v30 \n\t"
#define SQ4BIT_KERNEL_COMP_4x16x16 \
"vmadot v16, v10, v2 \n\t" \
"vmadot v18, v10, v3 \n\t" \
"vmadot v20, v10, v4 \n\t" \
"vmadot v22, v10, v5 \n\t" \
"vmadot v16, v11, v6 \n\t" \
"vmadot v18, v11, v7 \n\t" \
"vmadot v20, v11, v8 \n\t" \
"vmadot v22, v11, v9 \n\t"
#define SAVE_RESULT_4x16 \
"addi a1, %[C], 0 \n\t" \
"add a2, %[C], %[LDC] \n\t" \
"add a3, a2, %[LDC] \n\t" \
"add a4, a3, %[LDC] \n\t" \
"addi a2, a2, -16 \n\t" \
"addi a4, a4, -16 \n\t" \
"li t1, 0xf0 \n\t" \
"vmv.s.x v0, t1 \n\t" \
"vsetvli t0, zero, e32, mf2 \n\t" \
\
"vse32.v v24, (a1) \n\t" \
"addi a1, a1, 16 \n\t" \
"vse32.v v25, (a3) \n\t" \
"addi a3, a3, 16 \n\t" \
\
"vse32.v v26, (a1) \n\t" \
"addi a1, a1, 16 \n\t" \
"vse32.v v27, (a3) \n\t" \
"addi a3, a3, 16 \n\t" \
\
"vse32.v v28, (a1) \n\t" \
"addi a1, a1, 16 \n\t" \
"vse32.v v29, (a3) \n\t" \
"addi a3, a3, 16 \n\t" \
\
"vse32.v v30, (a1) \n\t" \
"vse32.v v31, (a3) \n\t" \
"vsetvli t0, zero, e32, m1 \n\t" \
\
"vse32.v v24, (a2), v0.t \n\t" \
"addi a2, a2, 16 \n\t" \
"vse32.v v25, (a4), v0.t \n\t" \
"addi a4, a4, 16 \n\t" \
\
"vse32.v v26, (a2), v0.t \n\t" \
"addi a2, a2, 16 \n\t" \
"vse32.v v27, (a4), v0.t \n\t" \
"addi a4, a4, 16 \n\t" \
\
"vse32.v v28, (a2), v0.t \n\t" \
"addi a2, a2, 16 \n\t" \
"vse32.v v29, (a4), v0.t \n\t" \
"addi a4, a4, 16 \n\t" \
\
"vse32.v v30, (a2), v0.t \n\t" \
"vse32.v v31, (a4), v0.t \n\t"
#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
"vsetvli t0, zero, e8, mf2 \n\t" \
"vle8.v v11, (s6) \n\t" \
"vsetvli t0, zero, e8, m1 \n\t" \
"vrgather.vv v12, v11, v1 \n\t" \
"vadd.vi v1, v1, 4 \n\t" \
"vrgather.vv v13, v11, v1 \n\t" \
"vadd.vi v1, v1, 4 \n\t" \
"vrgather.vv v14, v11, v1 \n\t" \
"vadd.vi v1, v1, 4 \n\t" \
"vrgather.vv v15, v11, v1 \n\t" \
"vadd.vi v1, v1, -12 \n\t"
template <bool HasZeroPoint>
void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountN,
size_t BlockCountK,
const float * Bias,
const size_t ldc) {
GGML_UNUSED(QuantBScale);
GGML_UNUSED(QuantBZeroPoint);
size_t LDC = ldc * sizeof(float);
const size_t INNER = BlkLen / 16;
float tmp[4 * 16];
if constexpr (HasZeroPoint) {
for (size_t n = 0; n < CountN; n += 16) {
size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(uint8_t) + // zp
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
if (NBLKS < 16) {
CPtr = tmp;
LDC = 16 * sizeof(float);
}
if (Bias != nullptr) {
const float * bias = Bias + n;
if (NBLKS < 16) {
__asm__ volatile(
"vsetvli t0, %[N], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"vse32.v v0, (%[DST]) \n\t"
:
: [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
: "cc", "t0");
bias = tmp;
}
__asm__ volatile(LOAD_BIAS
"addi t3, %[BlockCountK], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"li s1, 24 \n\t"
"vmv.v.i v1, 3 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v1, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v1, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v1, 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
// scale offset
"addi s5, s1, 0 \n\t"
// zp offset
"addi s6, s1, 32 \n\t"
"addi s1, s6, 16 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vsub.vv v2, v2, v12 \n\t"
"vsub.vv v6, v6, v12 \n\t"
"vsub.vv v3, v3, v13 \n\t"
"vsub.vv v7, v7, v13 \n\t"
"vsub.vv v4, v4, v14 \n\t"
"vsub.vv v8, v8, v14 \n\t"
"vsub.vv v5, v5, v15 \n\t"
"vsub.vv v9, v9, v15 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16_FP16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
"s2", "s3", "s4", "s5", "s6");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t3, %[BlockCountK], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"li s1, 24 \n\t"
"vmv.v.i v1, 3 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v1, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v1, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v1, 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
// scale offset
"addi s5, s1, 0 \n\t"
// zp offset
"addi s6, s1, 32 \n\t"
"addi s1, s6, 16 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vsub.vv v2, v2, v12 \n\t"
"vsub.vv v6, v6, v12 \n\t"
"vsub.vv v3, v3, v13 \n\t"
"vsub.vv v7, v7, v13 \n\t"
"vsub.vv v4, v4, v14 \n\t"
"vsub.vv v8, v8, v14 \n\t"
"vsub.vv v5, v5, v15 \n\t"
"vsub.vv v9, v9, v15 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16_FP16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
"s4", "s5", "s6");
}
}
} else {
for (size_t n = 0; n < CountN; n += 16) {
size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
if (NBLKS < 16) {
CPtr = tmp;
LDC = 16 * sizeof(float);
}
if (Bias != nullptr) {
const float * bias = Bias + n;
if (NBLKS < 16) {
__asm__ volatile(
"vsetvli t0, %[N], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"vse32.v v0, (%[DST]) \n\t"
:
: [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
: "cc", "t0");
bias = tmp;
}
__asm__ volatile(LOAD_BIAS
"addi t3, %[BlockCountK], 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
"addi s5, s1, 0 \n\t"
"addi s1, s5, 32 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vsetvli t0, zero, e8, m1 \n\t"
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
"vadd.vi v8, v8, -8 \n\t"
"vadd.vi v9, v9, -8 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16_FP16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
"s2", "s3", "s4", "s5", "s6");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t3, %[BlockCountK], 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
"addi s5, s1, 0 \n\t"
"addi s1, s5, 32 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vsetvli t0, zero, e8, m1 \n\t"
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
"vadd.vi v8, v8, -8 \n\t"
"vadd.vi v9, v9, -8 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16_FP16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
"s4", "s5", "s6");
}
}
}
if (CountN % 16 != 0) {
// store the output from tmp to C when NBLKS is less than 16.
float * CPtr = C + CountN / 16 * 16;
const size_t N = CountN % 16;
LDC = ldc * sizeof(float);
__asm__ volatile(
"vsetvli t0, %[N], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi s2, %[SRC], 64 \n\t"
"addi s3, %[SRC], 64*2 \n\t"
"addi s4, %[SRC], 64*3 \n\t"
"vle32.v v2, (s2) \n\t"
"vle32.v v4, (s3) \n\t"
"vle32.v v6, (s4) \n\t"
"add t2, %[DST], %[LDC] \n\t"
"add t3, t2, %[LDC] \n\t"
"add t4, t3, %[LDC] \n\t"
"vse32.v v0, (%[DST]) \n\t"
"vse32.v v2, (t2) \n\t"
"vse32.v v4, (t3) \n\t"
"vse32.v v6, (t4) \n\t"
:
: [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
: "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
}
}
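// M=4 kernel, fp32 B scales: same structure as the fp16-scale variant above, differing only in
// the packed scale width and the resulting pointer offsets.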
template <bool HasZeroPoint>
void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountN,
size_t BlockCountK,
const float * Bias,
const size_t ldc) {
GGML_UNUSED(QuantBScale);
GGML_UNUSED(QuantBZeroPoint);
size_t LDC = ldc * sizeof(float);
const size_t INNER = BlkLen / 16;
float tmp[4 * 16];
if constexpr (HasZeroPoint) {
for (size_t n = 0; n < CountN; n += 16) {
size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(uint8_t) + // zp
n * BlockCountK * sizeof(float); // scale
float * CPtr = C + n;
if (NBLKS < 16) {
CPtr = tmp;
LDC = 16 * sizeof(float);
}
if (Bias != nullptr) {
const float * bias = Bias + n;
if (NBLKS < 16) {
__asm__ volatile(
"vsetvli t0, %[N], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"vse32.v v0, (%[DST]) \n\t"
:
: [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
: "cc", "t0");
bias = tmp;
}
__asm__ volatile(LOAD_BIAS
"addi t3, %[BlockCountK], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"li s1, 24 \n\t"
"vmv.v.i v1, 3 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v1, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v1, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v1, 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
// scale offset
"addi s5, s1, 0 \n\t"
// zp offset
"addi s6, s1, 64 \n\t"
"addi s1, s6, 16 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vsub.vv v2, v2, v12 \n\t"
"vsub.vv v6, v6, v12 \n\t"
"vsub.vv v3, v3, v13 \n\t"
"vsub.vv v7, v7, v13 \n\t"
"vsub.vv v4, v4, v14 \n\t"
"vsub.vv v8, v8, v14 \n\t"
"vsub.vv v5, v5, v15 \n\t"
"vsub.vv v9, v9, v15 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
"s2", "s3", "s4", "s5", "s6");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t3, %[BlockCountK], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"li s1, 24 \n\t"
"vmv.v.i v1, 3 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v1, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v1, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v1, 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
// scale offset
"addi s5, s1, 0 \n\t"
// zp offset
"addi s6, s1, 64 \n\t"
"addi s1, s6, 16 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vsub.vv v2, v2, v12 \n\t"
"vsub.vv v6, v6, v12 \n\t"
"vsub.vv v3, v3, v13 \n\t"
"vsub.vv v7, v7, v13 \n\t"
"vsub.vv v4, v4, v14 \n\t"
"vsub.vv v8, v8, v14 \n\t"
"vsub.vv v5, v5, v15 \n\t"
"vsub.vv v9, v9, v15 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
"s4", "s5", "s6");
}
}
} else {
for (size_t n = 0; n < CountN; n += 16) {
size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(float); // scale
float * CPtr = C + n;
if (NBLKS < 16) {
CPtr = tmp;
LDC = 16 * sizeof(float);
}
if (Bias != nullptr) {
const float * bias = Bias + n;
if (NBLKS < 16) {
__asm__ volatile(
"vsetvli t0, %[N], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"vse32.v v0, (%[DST]) \n\t"
:
: [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
: "cc", "t0");
bias = tmp;
}
__asm__ volatile(LOAD_BIAS
"addi t3, %[BlockCountK], 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
"addi s5, s1, 0 \n\t"
"addi s1, s5, 64 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vsetvli t0, zero, e8, m1 \n\t"
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
"vadd.vi v8, v8, -8 \n\t"
"vadd.vi v9, v9, -8 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
"s2", "s3", "s4", "s5", "s6");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v24, v24, v24 \n\t"
"addi t3, %[BlockCountK], 0 \n\t"
"addi a1, %[A], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"BLOCK_COUNTK_LOOP%=: \n\t"
"addi s5, s1, 0 \n\t"
"addi s1, s5, 64 \n\t"
"addi s2, s1, 32 \n\t"
"addi s3, s1, 32*2 \n\t"
"addi s4, s1, 32*3 \n\t"
"vsetvli t0, zero, e32, m8 \n\t"
"vxor.vv v16, v16, v16 \n\t"
// load a scale
"flw f1, (a1) \n\t"
"flw f2, 4(a1) \n\t"
"flw f3, 8(a1) \n\t"
"flw f4, 12(a1) \n\t"
"addi a1, a1, 16 \n\t"
"addi t2, %[INNER], 0 \n\t"
"BLOCK_INNER_LOOP%=: \n\t"
LOAD_B_16x8x2
"vsetvli t0, zero, e8, m1 \n\t"
"vle8.v v10, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vle8.v v11, (a1) \n\t"
"addi a1, a1, 32 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
"vadd.vi v8, v8, -8 \n\t"
"vadd.vi v9, v9, -8 \n\t"
SQ4BIT_KERNEL_COMP_4x16x16
"addi t2, t2, -1 \n\t"
"bnez t2, BLOCK_INNER_LOOP%= \n\t"
LOAD_SCALE_4x16
"vsetvli t0, zero, e32, m8 \n\t"
"vfcvt.f.x.v v16, v16 \n\t"
"vfmacc.vv v24, v16, v8 \n\t"
"addi t3, t3, -1 \n\t"
"bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
"RESULT_SAVE%=: \n\t"
SAVE_RESULT_4x16
:
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
[BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
: "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
"s4", "s5", "s6");
}
}
}
if (CountN % 16 != 0) {
// store the output from tmp to C when NBLKS is less than 16.
float * CPtr = C + CountN / 16 * 16;
const size_t N = CountN % 16;
LDC = ldc * sizeof(float);
__asm__ volatile(
"vsetvli t0, %[N], e32, m2 \n\t"
"vle32.v v0, (%[SRC]) \n\t"
"addi s2, %[SRC], 64 \n\t"
"addi s3, %[SRC], 64*2 \n\t"
"addi s4, %[SRC], 64*3 \n\t"
"vle32.v v2, (s2) \n\t"
"vle32.v v4, (s3) \n\t"
"vle32.v v6, (s4) \n\t"
"add t2, %[DST], %[LDC] \n\t"
"add t3, t2, %[LDC] \n\t"
"add t4, t3, %[LDC] \n\t"
"vse32.v v0, (%[DST]) \n\t"
"vse32.v v2, (t2) \n\t"
"vse32.v v4, (t3) \n\t"
"vse32.v v6, (t4) \n\t"
:
: [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
: "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
}
}
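// M = 1 kernel: 4-bit packed B times int8-quantized A, for the layout where the
// per-block B scales are stored as 16-bit floats (_Float16) and widened to fp32
// inside the loop (vle16 + vfwcvt). Columns are handled 16 at a time; the
// HasZeroPoint variant subtracts per-block zero points loaded by
// SQ4BIT_KERNEL_LOAD_ZP_16X1, while the zero-point-free variant applies the
// fixed -8 offset to the unpacked nibbles. Tiles narrower than 16 columns are
// written back through the masked ST_TAIL stores.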
template <bool HasZeroPoint>
void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountN,
size_t BlockCountK,
const float * Bias) {
GGML_UNUSED(QuantBScale);
GGML_UNUSED(QuantBZeroPoint);
size_t INNER = BlkLen / 16;
if constexpr (HasZeroPoint) {
for (size_t n = 0; n < CountN; n += 16) {
size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(uint8_t) + // zp
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
size_t cnt = BlockCountK;
if (Bias != nullptr) {
const float * bias = Bias + n;
__asm__ volatile(
"addi t3, %[NBLKS], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vmv.v.i v13, 3 \n\t"
"li s1, 24 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v13, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v13, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v13, 0 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 8 \n\t"
"addi s3, %[B], 16 \n\t"
"addi s4, %[B], 24 \n\t"
// zp offset
"addi s7, %[B], 32 \n\t"
// a offset
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v28, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v29, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v30, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v31, (%[BIAS]) \n\t"
"LOOP_K%=: \n\t"
"vsetvli t0, zero, e16, mf4 \n\t"
"vle16.v v4, (s1) \n\t"
"addi s1, s1, 48 \n\t"
"vle16.v v5, (s2) \n\t"
"addi s2, s2, 72 \n\t"
"vle16.v v6, (s3) \n\t"
"addi s3, s3, 96 \n\t"
"vle16.v v7, (s4) \n\t"
"addi s4, s4, 120 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"vfwcvt.f.f.v v8, v4 \n\t"
"vfwcvt.f.f.v v9, v5 \n\t"
"vfwcvt.f.f.v v10, v6 \n\t"
"vfwcvt.f.f.v v11, v7 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vsub.vv v0, v0, v8 \n\t"
"vsub.vv v4, v4, v8 \n\t"
"vsub.vv v1, v1, v9 \n\t"
"vsub.vv v5, v5, v9 \n\t"
"vsub.vv v2, v2, v10 \n\t"
"vsub.vv v6, v6, v10 \n\t"
"vsub.vv v3, v3, v11 \n\t"
"vsub.vv v7, v7, v11 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_F16_1X4X4
"addi s7, s1, 32 \n\t"
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vmv.v.i v13, 3 \n\t"
"li s1, 24 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v13, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v13, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v13, 0 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 8 \n\t"
"addi s3, %[B], 16 \n\t"
"addi s4, %[B], 24 \n\t"
"addi s7, %[B], 32 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"LOOP_K%=: \n\t"
"vsetvli t0, zero, e16, mf4 \n\t"
"vle16.v v4, (s1) \n\t"
"addi s1, s1, 48 \n\t"
"vle16.v v5, (s2) \n\t"
"addi s2, s2, 72 \n\t"
"vle16.v v6, (s3) \n\t"
"addi s3, s3, 96 \n\t"
"vle16.v v7, (s4) \n\t"
"addi s4, s4, 120 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"vfwcvt.f.f.v v8, v4 \n\t"
"vfwcvt.f.f.v v9, v5 \n\t"
"vfwcvt.f.f.v v10, v6 \n\t"
"vfwcvt.f.f.v v11, v7 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vsub.vv v0, v0, v8 \n\t"
"vsub.vv v4, v4, v8 \n\t"
"vsub.vv v1, v1, v9 \n\t"
"vsub.vv v5, v5, v9 \n\t"
"vsub.vv v2, v2, v10 \n\t"
"vsub.vv v6, v6, v10 \n\t"
"vsub.vv v3, v3, v11 \n\t"
"vsub.vv v7, v7, v11 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_F16_1X4X4
"addi s7, s1, 32 \n\t"
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
}
}
} else {
for (size_t n = 0; n < CountN; n += 16) {
size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(_Float16); // scale
float * CPtr = C + n;
size_t cnt = BlockCountK;
if (Bias != nullptr) {
const float * bias = Bias + n;
__asm__ volatile(
"addi t3, %[NBLKS], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 8 \n\t"
"addi s3, %[B], 16 \n\t"
"addi s4, %[B], 24 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v28, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v29, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v30, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v31, (%[BIAS]) \n\t"
"LOOP_K%=: \n\t"
"vsetvli t0, zero, e16, mf4 \n\t"
"vle16.v v4, (s1) \n\t"
"addi s1, s1, 32 \n\t"
"vle16.v v5, (s2) \n\t"
"addi s2, s2, 56 \n\t"
"vle16.v v6, (s3) \n\t"
"addi s3, s3, 80 \n\t"
"vle16.v v7, (s4) \n\t"
"addi s4, s4, 104 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"vfwcvt.f.f.v v8, v4 \n\t"
"vfwcvt.f.f.v v9, v5 \n\t"
"vfwcvt.f.f.v v10, v6 \n\t"
"vfwcvt.f.f.v v11, v7 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vadd.vi v0, v0, -8 \n\t"
"vadd.vi v1, v1, -8 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_F16_1X4X4
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 8 \n\t"
"addi s3, %[B], 16 \n\t"
"addi s4, %[B], 24 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"LOOP_K%=: \n\t"
"vsetvli t0, zero, e16, mf4 \n\t"
"vle16.v v4, (s1) \n\t"
"addi s1, s1, 32 \n\t"
"vle16.v v5, (s2) \n\t"
"addi s2, s2, 56 \n\t"
"vle16.v v6, (s3) \n\t"
"addi s3, s3, 80 \n\t"
"vle16.v v7, (s4) \n\t"
"addi s4, s4, 104 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"vfwcvt.f.f.v v8, v4 \n\t"
"vfwcvt.f.f.v v9, v5 \n\t"
"vfwcvt.f.f.v v10, v6 \n\t"
"vfwcvt.f.f.v v11, v7 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vadd.vi v0, v0, -8 \n\t"
"vadd.vi v1, v1, -8 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_F16_1X4X4
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
}
}
}
}
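// M = 1 kernel for the fp32-scale layout. Control flow mirrors the _ScaleFp16
// variant above; the differences are the per-column packed-B stride (4-byte
// scales instead of 2-byte) and the direct vle32 scale loads, so no widening
// conversion is needed.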
template <bool HasZeroPoint>
void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountN,
size_t BlockCountK,
const float * Bias) {
GGML_UNUSED(QuantBScale);
GGML_UNUSED(QuantBZeroPoint);
const size_t INNER = BlkLen / 16;
if constexpr (HasZeroPoint) {
for (size_t n = 0; n < CountN; n += 16) {
size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(uint8_t) + // zp
n * BlockCountK * sizeof(float); // scale
float * CPtr = C + n;
size_t cnt = BlockCountK;
if (Bias != nullptr) {
const float * bias = Bias + n;
__asm__ volatile(
"addi t3, %[NBLKS], 0 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vmv.v.i v13, 3 \n\t"
"li s1, 24 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v13, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v13, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v13, 0 \n\t"
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
                    // B scale offsets: scale0.0, scale1.0, scale2.0, ..., scale15.0
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 16 \n\t"
"addi s3, %[B], 32 \n\t"
"addi s4, %[B], 48 \n\t"
// zp offset
"addi s7, %[B], 64 \n\t"
// a offset
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v28, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v29, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v30, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v31, (%[BIAS]) \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"LOOP_K%=: \n\t"
                    // load B scales
"vle32.v v8, (s1) \n\t"
"addi s1, s1, 80 \n\t"
"vle32.v v9, (s2) \n\t"
"addi s2, s2, 96 \n\t"
"vle32.v v10, (s3) \n\t"
"addi s3, s3, 112 \n\t"
"vle32.v v11, (s4) \n\t"
"addi s4, s4, 128 \n\t"
                    // load A scale
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
                    // A scale * B scale
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vsub.vv v0, v0, v8 \n\t"
"vsub.vv v4, v4, v8 \n\t"
"vsub.vv v1, v1, v9 \n\t"
"vsub.vv v5, v5, v9 \n\t"
"vsub.vv v2, v2, v10 \n\t"
"vsub.vv v6, v6, v10 \n\t"
"vsub.vv v3, v3, v11 \n\t"
"vsub.vv v7, v7, v11 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_1X4X4
"addi s7, s1, 64 \n\t"
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"vmv.v.i v13, 3 \n\t"
"li s1, 24 \n\t"
"vsetvli t0, s1, e8, m1 \n\t"
"vmv.v.i v13, 2 \n\t"
"vsetvli t0, zero, e8, mf2 \n\t"
"vmv.v.i v13, 1 \n\t"
"vsetvli t0, zero, e8, mf4 \n\t"
"vmv.v.i v13, 0 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 16 \n\t"
"addi s3, %[B], 32 \n\t"
"addi s4, %[B], 48 \n\t"
"addi s7, %[B], 64 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"LOOP_K%=: \n\t"
"vle32.v v8, (s1) \n\t"
"addi s1, s1, 80 \n\t"
"vle32.v v9, (s2) \n\t"
"addi s2, s2, 96 \n\t"
"vle32.v v10, (s3) \n\t"
"addi s3, s3, 112 \n\t"
"vle32.v v11, (s4) \n\t"
"addi s4, s4, 128 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
SQ4BIT_KERNEL_LOAD_ZP_16X1
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vsub.vv v0, v0, v8 \n\t"
"vsub.vv v4, v4, v8 \n\t"
"vsub.vv v1, v1, v9 \n\t"
"vsub.vv v5, v5, v9 \n\t"
"vsub.vv v2, v2, v10 \n\t"
"vsub.vv v6, v6, v10 \n\t"
"vsub.vv v3, v3, v11 \n\t"
"vsub.vv v7, v7, v11 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_1X4X4
"addi s7, s1, 64 \n\t"
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
}
}
} else {
for (size_t n = 0; n < CountN; n += 16) {
size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
n * BlockCountK * BlkLen / 2 + // b data
n * BlockCountK * sizeof(float); // scale
float * CPtr = C + n;
size_t cnt = BlockCountK;
if (Bias != nullptr) {
const float * bias = Bias + n;
__asm__ volatile(
"addi t3, %[NBLKS], 0 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 16 \n\t"
"addi s3, %[B], 32 \n\t"
"addi s4, %[B], 48 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v28, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v29, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v30, (%[BIAS]) \n\t"
"sub t3, t3, t0 \n\t"
"addi %[BIAS], %[BIAS], 16 \n\t"
"vsetvli t0, t3, e32, mf2 \n\t"
"vle32.v v31, (%[BIAS]) \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"LOOP_K%=: \n\t"
"vle32.v v8, (s1) \n\t"
"addi s1, s1, 64 \n\t"
"vle32.v v9, (s2) \n\t"
"addi s2, s2, 80 \n\t"
"vle32.v v10, (s3) \n\t"
"addi s3, s3, 96 \n\t"
"vle32.v v11, (s4) \n\t"
"addi s4, s4, 112 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vadd.vi v0, v0, -8 \n\t"
"vadd.vi v1, v1, -8 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_1X4X4
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
} else {
__asm__ volatile(
"vsetvli t0, zero, e32, m4 \n\t"
"vxor.vv v28, v28, v28 \n\t"
"addi s1, %[B], 0 \n\t"
"addi s2, %[B], 16 \n\t"
"addi s3, %[B], 32 \n\t"
"addi s4, %[B], 48 \n\t"
"addi s5, %[A], 0 \n\t"
"addi s6, %[A], 12 \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
"LOOP_K%=: \n\t"
"vle32.v v8, (s1) \n\t"
"addi s1, s1, 64 \n\t"
"vle32.v v9, (s2) \n\t"
"addi s2, s2, 80 \n\t"
"vle32.v v10, (s3) \n\t"
"addi s3, s3, 96 \n\t"
"vle32.v v11, (s4) \n\t"
"addi s4, s4, 112 \n\t"
"flw f1, (s5) \n\t"
"addi s5, s5, 4 \n\t"
"addi t5, %[INNER], 0 \n\t"
"vxor.vv v16, v16, v16 \n\t"
"vxor.vv v18, v18, v18 \n\t"
"vxor.vv v20, v20, v20 \n\t"
"vxor.vv v22, v22, v22 \n\t"
"vfmul.vf v24, v8, f1 \n\t"
"vfmul.vf v25, v9, f1 \n\t"
"vfmul.vf v26, v10, f1 \n\t"
"vfmul.vf v27, v11, f1 \n\t"
"addi %[CNT], %[CNT], -1 \n\t"
"vsetvli t0, zero, e8, m1 \n\t"
"LOOP_INNER%=: \n\t"
SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
"vadd.vi v0, v0, -8 \n\t"
"vadd.vi v1, v1, -8 \n\t"
"vadd.vi v2, v2, -8 \n\t"
"vadd.vi v3, v3, -8 \n\t"
"vadd.vi v4, v4, -8 \n\t"
"vadd.vi v5, v5, -8 \n\t"
"vadd.vi v6, v6, -8 \n\t"
"vadd.vi v7, v7, -8 \n\t"
SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
"bnez t5, LOOP_INNER%= \n\t"
"vsetvli t0, zero, e32, mf2 \n\t"
SQ4BIT_KERNEL_ACC_1X4X4
"bnez %[CNT], LOOP_K%= \n\t"
"addi t3, zero, 16 \n\t"
"addi s1, %[C], 16 \n\t"
"addi s2, %[C], 32 \n\t"
"addi s3, %[C], 48 \n\t"
"blt %[NBLKS], t3, ST_TAIL%= \n\t"
"vse32.v v28, (%[C]) \n\t"
"vse32.v v29, (s1) \n\t"
"vse32.v v30, (s2) \n\t"
"vse32.v v31, (s3) \n\t"
"jal x0, END%= \n\t"
"ST_TAIL%=: \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v28, (%[C]) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v29, (s1) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v30, (s2) \n\t"
"vsetvli t0, %[NBLKS], e32, mf2 \n\t"
"sub %[NBLKS], %[NBLKS], t0 \n\t"
"vse32.v v31, (s3) \n\t"
"END%=: \n\t"
: [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
: [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
: "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
}
}
}
}
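// Selects the 4-row kernel implementation by B-scale storage width:
// scalestride == 4 means fp32 scales, scalestride == 2 means fp16 scales.
// Any other stride falls through without computing.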
template <bool HasZeroPoint>
inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountM,
size_t CountN,
size_t BlockStrideQuantB,
const float * Bias,
const size_t ldc,
const size_t scalestride) {
if (scalestride == 4) {
SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
CountN, BlockStrideQuantB, Bias, ldc);
} else if (scalestride == 2) {
SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
}
}
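// Same scale-stride dispatch as above, for the single-row (M = 1) kernels.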
template <bool HasZeroPoint>
inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountM,
size_t CountN,
size_t BlockStrideQuantB,
const float * Bias,
const size_t ldc,
const size_t scalestride) {
if (scalestride == 4) {
SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
CountN, BlockStrideQuantB, Bias);
} else if (scalestride == 2) {
SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
}
}
} // namespace
namespace ime1 {
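// Entry point for one tile of the int8 x int4 GEMM: uses the 4-row kernel when
// at least 4 rows of A remain, otherwise the 1-row kernel, and returns the
// number of rows consumed so the caller can advance its A/C pointers.
//
// Illustrative call (a sketch only; the buffer names are hypothetical and must
// already be packed in the layout these kernels expect):
//
//   size_t rows_done = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
//       /*BlkLen=*/32, quant_a, quant_b_data, /*QuantBScale=*/nullptr,
//       /*QuantBZeroPoint=*/nullptr, c_tile, /*CountM=*/rows_left,
//       /*CountN=*/n_tile, /*CountK=*/k, /*BlockCountK=*/k / 32,
//       /*ldc=*/ldc, bias_ptr, /*ScaleStride=*/sizeof(float));
//   // rows_done is 4 or 1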
size_t gemm_kernel_i8i4(size_t BlkLen,
const std::byte * QuantA,
const std::byte * QuantBData,
const float * QuantBScale,
const std::byte * QuantBZeroPoint,
float * C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t BlockCountK,
size_t ldc,
const float * Bias,
const size_t ScaleStride) {
GGML_UNUSED(CountM);
GGML_UNUSED(CountK);
GGML_UNUSED(ldc);
if (CountM >= 4) {
if (QuantBZeroPoint != nullptr) {
SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
} else {
SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
ldc, ScaleStride);
}
return 4;
} else {
if (QuantBZeroPoint != nullptr) {
SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
} else {
SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
ldc, ScaleStride);
}
return 1;
}
}
} // namespace ime1
} // namespace sqnbitgemm_spacemit_ime