ggml-hexagon: fix `rope` failure at `test-backend-ops` (#17565)

* fix test failure

* fix: correct scaling calculations in rope_cache_init

* fix: optimize element copying in rope_hex_f32 using memcpy

* fix: optimize loop boundaries in rope_hex_f32 for better performance

* feat: add profiling macros for performance measurement in operations
This commit is contained in:
nullname 2025-12-11 06:45:43 +08:00 committed by GitHub
parent 45e350e3d3
commit 34ce48d97a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 37 additions and 41 deletions

View File

@ -74,14 +74,14 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
} }
static void rope_cache_init(const float theta_base, static void rope_cache_init(const float theta_base,
float freq_scale, const float freq_scale,
const float * freq_factors, const float * freq_factors,
float * corr_dims, float * corr_dims,
uint32_t ne0, const uint32_t ne0,
float ext_factor, const float ext_factor,
float mscale, const float mscale,
float * cache, float * cache,
float theta_scale) { const float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta = theta_base; float theta = theta_base;
@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base,
// Get n-d rotational scaling corrected for extrapolation // Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap; float theta_interp = freq_scale * theta_extrap;
float theta2 = theta_interp; float theta_final = theta_interp;
float mscale_final = mscale;
if (ext_factor != 0.0f) { if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation // Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
} }
cache[i0 + 0] = cosf(theta2) * mscale; cache[i0 + 0] = cosf(theta_final) * mscale_final;
cache[i0 + 1] = sinf(theta2) * mscale; cache[i0 + 1] = sinf(theta_final) * mscale_final;
theta *= theta_scale; theta *= theta_scale;
} }
@ -259,7 +260,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const uint32_t ir1, const uint32_t ir1,
int nth, int nth,
int ith, int ith,
int opt_path) { const int opt_path) {
struct htp_ops_context * octx = rope_ctx->octx; struct htp_ops_context * octx = rope_ctx->octx;
const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src0 = &octx->src0;
@ -281,8 +282,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
freq_factors = (const float *) src2->data; freq_factors = (const float *) src2->data;
} }
int ir = 0; const uint32_t i1_end = MIN(ir1, ne1);
const int32_t half_dims = rope_ctx->n_dims / 2;
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
const int32_t p = pos[i2]; const int32_t p = pos[i2];
@ -290,14 +292,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
if (ir++ < ir0) {
continue;
}
if (ir > ir1) {
break;
}
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
@ -310,6 +305,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
} else { } else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
} }
src_loc += rope_ctx->n_dims;
dst_data_loc += rope_ctx->n_dims;
} else { } else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0]; const float cos_theta = wp0[i0 + 0];
@ -317,10 +315,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
if (is_neox) { if (is_neox) {
const float x0 = src_loc[0]; const float x0 = src_loc[0];
const float x1 = src_loc[rope_ctx->n_dims/2]; const float x1 = src_loc[half_dims];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta; dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
src_loc += 1; src_loc += 1;
dst_data_loc += 1; dst_data_loc += 1;
@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
dst_data_loc += 2; dst_data_loc += 2;
} }
} }
src_loc += (is_neox ? half_dims : 0);
dst_data_loc += (is_neox ? half_dims : 0);
} }
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) { // TODO: use simd to speed up the remaining elements copy
dst_data_loc[0] = src_loc[0]; memcpy(dst_data_loc, src_loc, remain_bytes);
dst_data_loc[1] = src_loc[1];
src_loc += 2;
dst_data_loc += 2;
}
} }
} }
} }