This commit is contained in:
Shouyu 2025-12-17 12:24:01 +08:00 committed by GitHub
commit 9c27ea639d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 37 deletions

View File

@ -3297,7 +3297,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
break; break;
case GGML_OP_GLU: case GGML_OP_GLU:
if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) /* || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) */) { if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) ) {
supp = ggml_hexagon_supported_activations(sess, op); supp = ggml_hexagon_supported_activations(sess, op);
} }
break; break;

View File

@ -231,7 +231,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
// x (src0_spad_data) = std::min(src0_p[k], limit); // x (src0_spad_data) = std::min(src0_p[k], limit);
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc); hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc); hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
// y (src1_spad_data) = y1 + 1.f // y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc); hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
// x1 (dst_spad_data) = alpha * (x) // x1 (dst_spad_data) = alpha * (x)

View File

@ -875,35 +875,46 @@ float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
size_t left_over = num_elems & (VLEN_FP32 - 1); size_t left_over = num_elems & (VLEN_FP32 - 1);
size_t num_elems_whole = num_elems - left_over; size_t num_elems_whole = num_elems - left_over;
int unalign_address = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
unalign_address = 1;
} }
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
const float * src_f = (const float *) src; const float * src_f = (const float *) src;
HVX_Vector vec_min = Q6_V_vsplat_R(val); HVX_Vector vec_min =hvx_vec_splat_fp32(val);
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
HVX_Vector * restrict vec_out = (HVX_Vector *) dst; if(unalign_address == 0){
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
#pragma unroll(4) #pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
*vec_out++ = Q6_Vsf_equals_Vqf32(vec_min); *vec_out++ = (min_clamp);
}
}else{
HVX_UVector * restrict vec_in = (HVX_Vector *) src;
HVX_UVector * restrict vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
*vec_out++ = (min_clamp);
}
} }
if (left_over > 0) { if (left_over > 0 ) {
const float * srcf = (const float *) src + num_elems_whole; const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole; float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf; HVX_UVector in = *(HVX_UVector *) srcf;
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in); HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in);
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min)); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp));
} }
} }
@ -914,47 +925,72 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
const int num_elems) { const int num_elems) {
size_t left_over = num_elems & (VLEN_FP32 - 1); size_t left_over = num_elems & (VLEN_FP32 - 1);
size_t num_elems_whole = num_elems - left_over; size_t num_elems_whole = num_elems - left_over;
int unalign_address = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
unalign_address = 1;
} }
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
HVX_Vector range_left = hvx_vec_splat_fp32(limit_left); HVX_Vector range_left = hvx_vec_splat_fp32(limit_left);
HVX_Vector range_right = hvx_vec_splat_fp32(limit_right); HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
#pragma unroll(4) if(unalign_address == 0){
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { HVX_Vector * restrict vec_in = (HVX_Vector *) src;
HVX_Vector in_vec = *vec_in++; HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
HVX_Vector temp_v = in_vec;
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
*vec_out++ = Q6_Vsf_equals_Vqf32(in_vec); #pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in_vec = *vec_in++;
HVX_Vector temp_v = in_vec;
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
*vec_out++ = in_vec;
}
}else{
HVX_UVector * restrict vec_in = (HVX_UVector *) src;
HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in_vec = *vec_in++;
HVX_Vector temp_v = in_vec;
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
*vec_out++ = in_vec;
}
} }
if (left_over > 0) { if (left_over > 0) {
const float * srcf = (const float *) src + num_elems_whole; const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole; float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf; HVX_Vector in_vec = *(HVX_UVector *) srcf;
HVX_Vector temp_v = in; HVX_Vector temp_v = in_vec;
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right); HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in); HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v); in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in)); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
} }
} }

View File

@ -806,6 +806,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
break; break;
case HTP_OP_GLU_SWIGLU: case HTP_OP_GLU_SWIGLU:
case HTP_OP_GLU_SWIGLU_OAI:
case HTP_OP_SOFTMAX: case HTP_OP_SOFTMAX:
if ((n_bufs != 2) && (n_bufs != 3)) { if ((n_bufs != 2) && (n_bufs != 3)) {
FARF(ERROR, "Bad act-req buffer list"); FARF(ERROR, "Bad act-req buffer list");