hexagon: support for IQ4_NL and MXFP4 (#21018)
* ggml-hexagon: add IQ4_NL and MXFP4 HMX matmul support - Add IQ4_NL quantization type support to Hexagon backend (buffer set/get tensor repack, mul_mat, mul_mat_id dispatch) - Implement HVX IQ4_NL vec_dot kernels (1x1, 2x1, 2x2) with LUT-based 4-bit index to int8 kvalue dequantization - Add MXFP4 HMX dequantization path with E8M0 scale conversion, including batch-4 fast path and single-tile fallback - Unify quantized row size / scale offset logic to handle Q4_0, Q8_0, IQ4_NL, and MXFP4 in the DMA fetch path * ggml-hexagon: fix SKIP_QUANTIZE src1 address mismatch in mixed-quant models * Fix the pragma indent
This commit is contained in:
parent
e6f6770515
commit
ee051c1e4e
|
|
@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
repack_q8_0_q8x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
// IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16])
|
||||
repack_q4_0_q4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
|
|
@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|||
repack_q8x4x2_q8_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4x4x2_q4_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
|
|
@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
if (src0->ne[0] % 32) {
|
||||
return false;
|
||||
|
|
@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
|||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
if ((src0->ne[0] % 32)) {
|
||||
return false;
|
||||
|
|
@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
|
|||
delete backend;
|
||||
}
|
||||
|
||||
// Map weight type to its activation quantization family.
|
||||
// Types in the same family produce identical Q8 formats in VTCM and can
|
||||
// safely share quantized activation data via SKIP_QUANTIZE.
|
||||
// When adding a new quantized type, assign it the correct family here.
|
||||
static inline int act_quant_family(enum ggml_type wtype) {
|
||||
switch (wtype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
return 1; // Q8x4x2
|
||||
default:
|
||||
return 0; // unknown / not quantized
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
|
||||
return (op0 && op0->src[1] == op1->src[1] &&
|
||||
act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) &&
|
||||
act_quant_family(op0->src[0]->type) != 0);
|
||||
}
|
||||
|
||||
static inline bool is_compute_op(ggml_tensor *node)
|
||||
|
|
@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
|||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
|
||||
const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
|
||||
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
|||
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
|
||||
};
|
||||
|
||||
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
|
||||
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
|
||||
};
|
||||
|
||||
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
-127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
|
||||
1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
|
||||
|
|
@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned
|
|||
|
||||
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
|
||||
#define HMX_X4X2_SCALES_PER_BLK 8
|
||||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes
|
||||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
|
||||
#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
|
||||
|
||||
static inline void swap_ptr(void **p1, void **p2) {
|
||||
void *t = *p1;
|
||||
|
|
@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
|
|||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
case HTP_TYPE_Q8_0:
|
||||
return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
case HTP_TYPE_MXFP4:
|
||||
return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
|
|||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
// --- MXFP4 E8M0 scale conversion and dequantization ---
|
||||
//
|
||||
// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
|
||||
// Scalar loads from the stack array execute on the scalar pipeline, in parallel
|
||||
// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
|
||||
// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
|
||||
// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
|
||||
|
||||
typedef struct {
|
||||
__fp16 v[8] __attribute__((aligned(16)));
|
||||
} mxfp4_scales_t;
|
||||
|
||||
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
|
||||
mxfp4_scales_t s;
|
||||
HVX_Vector v = hvx_vmemu(e8m0_8);
|
||||
HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
|
||||
vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
|
||||
vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
|
||||
vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
|
||||
vh = Q6_Vh_vasl_VhR(vh, 10);
|
||||
hvx_vec_store_u(s.v, 16, vh);
|
||||
return s;
|
||||
}
|
||||
|
||||
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
|
||||
return hvx_vec_splat_f16(scales.v[idx]);
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
|
||||
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
|
||||
bool upper_nibbles,
|
||||
int sub_blk,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
|
||||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
|
||||
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
bool upper_nibbles,
|
||||
int sub_blk_base,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales,
|
||||
HVX_Vector out[4]) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
|
||||
HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
|
||||
HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
|
||||
mxfp4_extract_splat(scales, sub_blk_base + 1));
|
||||
HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
|
||||
mxfp4_extract_splat(scales, sub_blk_base + 3));
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
out[0] = v_lo;
|
||||
out[1] = Q6_V_vror_VR(v_lo, 64);
|
||||
out[2] = v_hi;
|
||||
out[3] = Q6_V_vror_VR(v_hi, 64);
|
||||
}
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
|
||||
// Output: vtcm_dst in tile-major FP16 layout.
|
||||
|
|
@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int qrow_size = is_q4 ? (k_block / 2) : k_block;
|
||||
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL)
|
||||
? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut);
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
hvx_vmem(q4_0_to_fp16_lut);
|
||||
|
||||
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
|
||||
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
|
||||
|
|
@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
int ct = t / n_k_tiles; // column tile index
|
||||
int kt = t % n_k_tiles; // K tile index
|
||||
|
||||
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
|
||||
((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
|
|
@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
continue;
|
||||
}
|
||||
|
||||
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
|
||||
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
|
||||
|
||||
__fp16 * tile_bases[4];
|
||||
for (int g = 0; g < 4; g++) {
|
||||
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0[4], v1[4];
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
|
||||
} else {
|
||||
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
(void) *(volatile HVX_Vector *) (tile_bases[g]);
|
||||
}
|
||||
|
||||
t += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Single-tile fallback ---
|
||||
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
if (is_q4) {
|
||||
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
|
|
@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
|||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
} else if (weight_type == HTP_TYPE_MXFP4) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
|
||||
HVX_Vector v1;
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
v1 = Q6_V_vzero();
|
||||
}
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *) (tile_base);
|
||||
} else {
|
||||
// Q8_0
|
||||
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
|
||||
|
|
@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
|||
{
|
||||
qweight_fetch_task_state_t s;
|
||||
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int blk_start = kk / QK_Q4_0x4x2;
|
||||
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
|
||||
const int full_qrow = is_q4 ? (k / 2) : k;
|
||||
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
|
||||
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
|
||||
const int scale_blk_size =
|
||||
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
|
||||
|
||||
s.dst = vtcm_scratch0;
|
||||
s.src = w + nc * row_stride;
|
||||
s.n_rows = n_blk_sz;
|
||||
s.src_stride = row_stride;
|
||||
s.dst_stride = sub_row_stride;
|
||||
s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2);
|
||||
s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2);
|
||||
s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE;
|
||||
s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
|
||||
s.quant_off =
|
||||
(weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
|
||||
s.quant_width =
|
||||
(weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
|
||||
s.scale_off = full_qrow + blk_start * scale_blk_size;
|
||||
s.scale_width = nb_sub * scale_blk_size;
|
||||
|
||||
// 2D DMA: quants sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
|
||||
|
|
|
|||
|
|
@ -31,6 +31,12 @@ struct htp_context {
|
|||
|
||||
uint32_t opmask;
|
||||
|
||||
// Cached src1 spad position from the last quantize pass.
|
||||
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
|
||||
// at this address; the matmul must read from here instead of recomputing
|
||||
// the offset (which depends on the current op's src0 size).
|
||||
uint8_t * prev_src1_spad;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
|
|
|
|||
|
|
@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx,
|
|||
return;
|
||||
}
|
||||
|
||||
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
|
||||
// Other types (e.g. MXFP4) fall back to HVX.
|
||||
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
|
||||
// Other types fall back to HVX.
|
||||
{
|
||||
uint32_t wtype = req->src0.type;
|
||||
if (wtype != HTP_TYPE_F16 &&
|
||||
wtype != HTP_TYPE_Q4_0 &&
|
||||
wtype != HTP_TYPE_Q8_0 &&
|
||||
wtype != HTP_TYPE_IQ4_NL) {
|
||||
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL &&
|
||||
wtype != HTP_TYPE_MXFP4) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,6 +60,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
|||
0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
|
||||
};
|
||||
|
||||
// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue
|
||||
// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
|
||||
static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = {
|
||||
0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0,
|
||||
0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
||||
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
|
||||
0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
|
|
@ -68,6 +78,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
|||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) {
|
||||
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
||||
|
||||
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
||||
HVX_Vector v2_3 = vptr[1]; // ...
|
||||
HVX_Vector v4_5 = vptr[2]; // ...
|
||||
HVX_Vector v6_7 = vptr[3]; // ...
|
||||
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
||||
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
|
||||
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
|
||||
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
|
||||
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
|
||||
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
|
||||
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
|
||||
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
||||
|
||||
v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
||||
v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
||||
v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
|
||||
v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
|
||||
v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
|
||||
v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
|
||||
v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
|
||||
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
|
||||
|
||||
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
||||
return r;
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
||||
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2; // 256
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
||||
|
||||
HVX_Vector_x8 r;
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nb; i++) {
|
||||
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
||||
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
||||
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
||||
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
||||
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
|
||||
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
|
||||
|
||||
static inline size_t q8x4x2_row_size(uint32_t ne) {
|
||||
|
|
@ -921,6 +998,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
|||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
||||
}
|
||||
|
||||
// ======== IQ4_NL x Q8_0 vec_dot kernels ========
|
||||
// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue).
|
||||
// Scale format is identical to Q4_0 (fp16 scales).
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n,
|
||||
float * restrict s0,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vy0) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
||||
|
||||
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
||||
|
||||
HVX_Vector r0_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
||||
|
||||
hvx_vec_store_u(s0, 4, r0_sum);
|
||||
}
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n,
|
||||
float * restrict s0,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vx1 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
||||
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
||||
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
||||
|
||||
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
||||
|
||||
HVX_Vector r0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
||||
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
||||
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n,
|
||||
float * restrict s0,
|
||||
float * restrict s1,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vx1,
|
||||
const void * restrict vy0,
|
||||
const void * restrict vy1) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vx1 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
assert((unsigned long) vy1 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
|
||||
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
|
||||
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
|
||||
|
||||
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;
|
||||
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;
|
||||
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;
|
||||
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
||||
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
||||
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
||||
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
||||
|
||||
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
||||
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
||||
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
||||
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
||||
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
||||
|
||||
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
||||
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
||||
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
||||
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
||||
|
||||
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
||||
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
||||
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
||||
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
||||
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
||||
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
||||
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
||||
|
||||
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
||||
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
||||
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
||||
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
||||
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
||||
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
||||
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
||||
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
||||
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
||||
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
||||
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
||||
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
||||
|
||||
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
||||
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
||||
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
||||
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
||||
|
||||
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
||||
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
||||
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
||||
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
||||
}
|
||||
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);
|
||||
}
|
||||
|
||||
static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
||||
assert(n % 32 == 0); // min sub-block size
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
|
|
@ -2393,6 +2757,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t
|
|||
mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
|
||||
mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
|
||||
return 0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
mmctx->type = "iq4nlx4x2-f32";
|
||||
mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1;
|
||||
mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1;
|
||||
mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2;
|
||||
return 0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
mmctx->type = "mxfp4x4x2-f32";
|
||||
mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
|
||||
|
|
@ -2556,6 +2926,13 @@ int op_matmul(struct htp_ops_context * octx) {
|
|||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
// Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
// SKIP_QUANTIZE: Q8 data lives at the address written by the previous
|
||||
// quantize pass. The current op may have a different src0 size (e.g.
|
||||
// IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong.
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
|
|
@ -2659,6 +3036,9 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
|||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue