ggml-hexagon: flash-attention and reduce-sum optimizations (#19141)
* wip * ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation * ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations * wip * ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance * ggml-hexagon: refactor dot product functions to use a common loading function for improved readability * optimize vector dot product functions to use unified reduction for improved performance * wip * ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation * ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations * wip * ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance * ggml-hexagon: refactor dot product functions to use a common loading function for improved readability * optimize vector dot product functions to use unified reduction for improved performance * hexagon: optimize reduce-sum for v75+ * hexagon: always keep row_sums in sf/fp32 * ggml-hexagon: enhance directory checks for HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT * fix compiling error after rebase --------- Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
parent
3dd95914d0
commit
89f10baad5
|
|
@ -1,8 +1,20 @@
|
|||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
|
||||
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
|
||||
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
|
||||
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
|
||||
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.")
|
||||
endif()
|
||||
|
||||
if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
|
||||
message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
|
||||
file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
|
||||
string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
|
||||
message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
|
||||
set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
|
||||
message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@
|
|||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
|
||||
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
|
|
@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
|
|||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
|
@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
|
|||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
|
||||
const void * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
unsigned int n,
|
||||
float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
|
||||
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
|
|
@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
|||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
|
|
@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
|||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
const void * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
unsigned int n,
|
||||
float s) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
|
|
@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
// Inner loop processing the block from VTCM
|
||||
uint32_t ic = 0;
|
||||
|
||||
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
for (int j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
if (q->type == HTP_TYPE_F32) {
|
||||
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
||||
|
||||
if (q->type == HTP_TYPE_F32) {
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
|
|
|
|||
|
|
@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
|
|||
}
|
||||
|
||||
static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
float d[32];
|
||||
} u = { .v = v };
|
||||
HVX_VectorAlias u = { .v = v };
|
||||
|
||||
const uint32_t n0 = n / 16;
|
||||
const uint32_t n1 = n % 16;
|
||||
int i = 0;
|
||||
for (; i < n0; i++) {
|
||||
hex_dump_f32_line(pref, u.d + (16 * i), 16);
|
||||
hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
hex_dump_f32_line(pref, u.d + (16 * i), n1);
|
||||
hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
|
|||
return hvx_vec_reduce_sum_n_qf32(in, 32);
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ > 75
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
|
||||
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
|
||||
HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
|
||||
|
||||
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
|
||||
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
|
||||
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
|
||||
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
|
||||
return sum_sf;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||
sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
|
||||
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
|
||||
HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
|
||||
|
||||
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
|
||||
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
|
||||
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
|
||||
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
|
||||
return Q6_Vsf_equals_Vqf32(sum_qf);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // fp32 nbytes
|
||||
|
|
@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n)
|
|||
return sum;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
|
||||
return hvx_vec_reduce_sum_n_f32(in, 32);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hvx-dump.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
|
|
@ -320,7 +321,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
|
||||
|
||||
// Row sum (qf32)
|
||||
// Row sum (sf)
|
||||
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
|
||||
|
||||
// Multiply and accumulate into int32.
|
||||
|
|
@ -344,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
|
||||
|
|
@ -362,14 +363,14 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
// Zero out unused scales
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
// Reduce and convert into fp32
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
||||
|
||||
hvx_vec_store_u(&s[0], 4, r0_sum);
|
||||
}
|
||||
|
|
@ -402,7 +403,7 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
|
|||
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
|
||||
|
||||
// Row sum (qf32)
|
||||
// Row sum (sf)
|
||||
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
|
||||
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
|
||||
|
||||
|
|
@ -432,8 +433,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
|
|||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
|
||||
|
|
@ -456,20 +457,18 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
|
|||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
||||
hvx_vec_store_u(&s[0], 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
|
|
@ -493,7 +492,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
|
||||
|
||||
// Row sum (qf32)
|
||||
// Row sum (sf)
|
||||
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
|
||||
|
||||
// Multiply and accumulate into int32.
|
||||
|
|
@ -517,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
|
||||
|
|
@ -535,14 +534,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
|
|||
// Zero out unused scales
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
// Reduce and convert into fp32
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
||||
|
||||
hvx_vec_store_u(&s[0], 4, r0_sum);
|
||||
}
|
||||
|
|
@ -605,8 +604,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
|
|||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
|
||||
|
|
@ -629,20 +628,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
|
|||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
||||
hvx_vec_store_u(&s[0], 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
||||
|
|
@ -669,7 +666,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|||
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
|
||||
|
||||
// Row sum (qf32)
|
||||
// Row sum (sf)
|
||||
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
|
||||
|
||||
// Multiply and accumulate into int32.
|
||||
|
|
@ -708,7 +705,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
// Process leftovers
|
||||
|
|
@ -741,14 +738,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
|
|||
// Zero-out unused scales
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
// Reduce and convert into fp32
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
||||
|
||||
hvx_vec_store_u(&s[0], 4, r0_sum);
|
||||
}
|
||||
|
|
@ -781,13 +778,13 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
|
|||
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
|
||||
|
||||
// Row sum (qf32)
|
||||
// Row sum (sf)
|
||||
HVX_Vector r0_sum = Q6_V_vsplat_R(0);
|
||||
HVX_Vector r1_sum = Q6_V_vsplat_R(0);
|
||||
|
||||
// Multiply and accumulate into int32.
|
||||
// Compute combined scale (fp32).
|
||||
// Apply scale to acc and accumulate into the row sum (qf32).
|
||||
// Apply scale to acc and accumulate into the row sum (f32).
|
||||
|
||||
const uint32_t nb = n / qk; // num full blocks
|
||||
int32_t nloe = n % qk; // num leftover elemements (must be signed)
|
||||
|
|
@ -829,8 +826,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
|
|||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
// Process leftovers
|
||||
|
|
@ -867,24 +864,22 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
|
|||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
|
||||
|
||||
// Zero-out unused scales
|
||||
// Zero-out unused values
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
|
||||
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
|
||||
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
||||
hvx_vec_store_u(&s[0], 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
|
|
@ -913,7 +908,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -957,11 +952,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n,
|
|||
rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
|
||||
}
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1));
|
||||
HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
|
||||
|
||||
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
|
||||
hvx_vec_store_u(&s[0], 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
|
|
@ -990,7 +982,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
|
|
@ -1042,7 +1034,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
|
|||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
|
||||
// Convert into fp32 and reduce
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -154,8 +154,8 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
|||
v_pad[i] = v3;
|
||||
}
|
||||
|
||||
v = hvx_vec_reduce_sum_qf32(sum_vec);
|
||||
sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
|
||||
v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
|
||||
sum_vec = hvx_vec_repl4(v);
|
||||
|
||||
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
|
||||
|
|
|
|||
|
|
@ -57,8 +57,8 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v);
|
||||
sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
|
||||
HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
sum_v = hvx_vec_repl4(reduced_sum);
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
|
|
|
|||
Loading…
Reference in New Issue