refactoring: improve code formatting and alignment in matmul operations

chraac 2025-12-19 20:55:49 +08:00
parent 7ef467ce20
commit e0b1435b50
1 changed file with 39 additions and 59 deletions
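
Beyond the whitespace realignment in the matmul/matvec loops, this change removes the dead if (0) scalar fallback at the top of vec_dot_f16_f32 and reworks its tail handling: the scalar remainder loop with its separate r_sum_scalar accumulator is dropped, and the last partial vector instead stays in HVX registers, combined with a zero vector via Q6_V_valign before the final reduction so lanes past the remainder do not contribute. A minimal plain-C sketch of the two tail strategies follows; the 4-lane width, float element types, and function names are illustrative stand-ins for the 32 fp32 lanes per HVX vector and the fp16/fp32 inputs used in the real code.

/* Plain-C sketch of the remainder-handling change in vec_dot_f16_f32.
 * LANES stands in for the 32 fp32 lanes of an HVX vector; element types
 * are plain float here (the real code widens __fp16 inputs to fp32). */
#include <stdio.h>

#define LANES 4

/* Old scheme: full vectors in the main loop, then a scalar loop for the
 * tail, accumulated separately (the r_sum_scalar path being removed). */
static float dot_scalar_tail(const float * x, const float * y, int n) {
    float sum = 0.0f;
    int   nv  = n / LANES;
    for (int i = 0; i < nv * LANES; i++) {
        sum += x[i] * y[i]; /* vectorized in the real code */
    }
    for (int i = nv * LANES; i < n; i++) {
        sum += x[i] * y[i]; /* scalar remainder */
    }
    return sum;
}

/* New scheme: the tail is treated as a zero-padded partial vector, so one
 * vector reduction at the end covers both the body and the tail (in the
 * HVX code Q6_V_valign against a zero vector provides the zero padding). */
static float dot_vector_tail(const float * x, const float * y, int n) {
    float acc[LANES] = { 0 };
    int   i          = 0;
    for (; i + LANES <= n; i += LANES) {
        for (int l = 0; l < LANES; l++) {
            acc[l] += x[i + l] * y[i + l];
        }
    }
    for (int l = 0; i + l < n; l++) {
        acc[l] += x[i + l] * y[i + l]; /* lanes past the tail stay untouched */
    }
    float sum = 0.0f;
    for (int l = 0; l < LANES; l++) {
        sum += acc[l]; /* single reduction over the accumulator */
    }
    return sum;
}

int main(void) {
    float x[7], y[7];
    for (int i = 0; i < 7; i++) {
        x[i] = (float) (i + 1);
        y[i] = 0.5f * (float) (i + 1);
    }
    /* both variants agree: 0.5 * (1 + 4 + ... + 49) = 70.0 */
    printf("%.1f %.1f\n", dot_scalar_tail(x, y, 7), dot_vector_tail(x, y, 7));
    return 0;
}
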


@@ -485,8 +485,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
}
// Convert into fp32 and reduce
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -658,8 +658,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
}
// Convert into fp32 and reduce
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -900,8 +900,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
}
// Convert into fp32 and reduce
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -909,18 +909,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
#if 1
static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
if (0) {
float rsum = 0;
const __fp16 * restrict vx = (const __fp16 * restrict) x;
const float * restrict vy = (const float * restrict) y;
for (uint32_t i = 0; i < n; i++) {
rsum += (float)vx[i] * vy[i];
}
*s = rsum;
return;
}
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
@@ -929,12 +917,10 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
// for some reason we need volatile here so that the compiler doesn't apply optimizations that break the accumulation
volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
float r_sum_scalar = 0.0f;
uint32_t i = 0;
uint32_t i = 0;
for (i = 0; i < nv0; i++) {
HVX_VectorPair yp = vy[i];
HVX_Vector x = vx[i];
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
@@ -948,43 +934,37 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
}
if (nv1) {
// HVX_VectorPair yp = vy[i];
HVX_VectorPair yp = vy[i];
HVX_Vector x = vx[i];
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
// HVX_Vector x = vx[i];
// HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
HVX_Vector l_x;
HVX_Vector l_y;
if (nv1 >= 32) {
volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
nv1 -= 32;
l_x = Q6_V_hi_W(xp);
l_y = Q6_V_hi_W(yp);
} else {
l_x = Q6_V_lo_W(xp);
l_y = Q6_V_lo_W(yp);
}
// if (nv1 >= 32) {
// volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
// rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
// nv1 -= 32;
// }
// rsum = hvx_vec_qf32_reduce_sum(rsum);
// if (nv1) {
// volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
// HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
// rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
// }
// process the remainder using a scalar loop
rsum = hvx_vec_qf32_reduce_sum(rsum);
const __fp16 * restrict sx = (const __fp16 * restrict) x;
const float * restrict sy = (const float * restrict) y;
for (uint32_t i = nv0 * 64; i < n; i++) {
r_sum_scalar += (float) sx[i] * sy[i];
if (nv1) {
volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(l_x), l_y);
HVX_Vector sum = Q6_V_valign_VVR(lo, Q6_V_vzero(), nv1 * sizeof(float));
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
}
// hvx_vec_dump_fp16("X", x);
// hvx_vec_dump_fp16("Y", y);
// hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum));
// hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum));
} else {
rsum = hvx_vec_qf32_reduce_sum(rsum);
}
*s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
rsum = hvx_vec_qf32_reduce_sum(rsum);
*s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
# ifdef HTP_DEBUG
{
@@ -1120,8 +1100,8 @@ static void matmul(struct htp_matmul_type * mt,
const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
// Prefill spad with src0 rows
#pragma unroll(4)
// Prefill spad with src0 rows
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
if (is0 >= HTP_SPAD_SRC0_NROWS) {
@@ -1135,7 +1115,7 @@ static void matmul(struct htp_matmul_type * mt,
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
#pragma unroll(2)
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
@@ -1159,7 +1139,7 @@ static void matmul(struct htp_matmul_type * mt,
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
#pragma unroll(2)
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
@@ -1222,8 +1202,8 @@ static void matvec(struct htp_matmul_type * mt,
const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
float * restrict dst_col = (float *) dst->data;
// Prefill spad with 2x src0 rows
#pragma unroll(2)
// Prefill spad with 2x src0 rows
#pragma unroll(2)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint32_t is0 = (ir0 - src0_start_row);
if (is0 >= HTP_SPAD_SRC0_NROWS) {
@@ -1336,8 +1316,8 @@ static void matmul_id(struct htp_matmul_type * mt,
const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
// Prefill spad with src0 rows
#pragma unroll(4)
// Prefill spad with src0 rows
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
if (is0 >= HTP_SPAD_SRC0_NROWS) {
@@ -1460,8 +1440,8 @@ static void matvec_id(struct htp_matmul_type * mt,
const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
float * restrict dst_row = (float *) (dst->data + ie1 * nb1);
// Prefill spad with src0 rows
#pragma unroll(4)
// Prefill spad with src0 rows
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
if (is0 >= HTP_SPAD_SRC0_NROWS) {
@@ -2347,7 +2327,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping){ id, iid1 };
matrix_row_counts[i02] += 1;
}
}
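
A note on the matmul/matvec/matmul_id hunks above: the edits there appear to be pure realignment of the "Prefill spad with src0 rows" comments and the unroll pragmas, but the structure they sit in is a small prefetch pipeline: a few src0 rows are queued into scratchpad through the DMA queue up front, and the compute loop pops the oldest completed transfer while queuing a later row. Below is a minimal synchronous sketch of that pattern, with memcpy standing in for the DMA engine; the slot count, row size, and fake_dma_* helpers are illustrative, not the actual dma_queue API.

/* Illustrative stand-in for the "prefill spad + dma_queue" pattern in matmul
 * and matvec. A real implementation issues asynchronous DMA transfers and
 * queues the next row before processing the current one so copy and compute
 * overlap; here a synchronous memcpy plays the DMA role, so we process the
 * popped row first to avoid overwriting it. */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define SPAD_NROWS 4 /* stand-in for HTP_SPAD_SRC0_NROWS             */
#define ROW_SIZE   8 /* stand-in for the padded src0 row size, bytes */

typedef struct {
    uint8_t  slots[SPAD_NROWS][ROW_SIZE]; /* scratchpad ring   */
    uint32_t head;                        /* next slot to pop  */
    uint32_t tail;                        /* next slot to fill */
} fake_dma_queue;

/* "Issue" a transfer of one source row into the next free scratchpad slot. */
static void fake_dma_push(fake_dma_queue * q, const uint8_t * row) {
    memcpy(q->slots[q->tail++ % SPAD_NROWS], row, ROW_SIZE);
}

/* Wait for the oldest transfer and return its scratchpad destination. */
static const uint8_t * fake_dma_pop(fake_dma_queue * q) {
    return q->slots[q->head++ % SPAD_NROWS];
}

int main(void) {
    enum { NROWS = 10 };
    uint8_t src0[NROWS][ROW_SIZE];
    for (int r = 0; r < NROWS; r++) {
        memset(src0[r], r, ROW_SIZE);
    }

    fake_dma_queue q = { 0 };

    /* Prefill the scratchpad with the first rows ("Prefill spad" loop). */
    for (uint32_t ir0 = 0; ir0 < NROWS && ir0 < SPAD_NROWS; ir0++) {
        fake_dma_push(&q, src0[ir0]);
    }

    /* Main loop: consume the oldest row, then queue a later row so the
     * scratchpad stays full while the current row is being processed.  */
    for (uint32_t ir0 = 0; ir0 < NROWS; ir0++) {
        const uint8_t * ss0 = fake_dma_pop(&q);
        printf("row %u -> first byte %u\n", (unsigned) ir0, (unsigned) ss0[0]);
        if (ir0 + SPAD_NROWS < NROWS) {
            fake_dma_push(&q, src0[ir0 + SPAD_NROWS]);
        }
    }
    return 0;
}
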