ggml-hexagon: fix `rope` failure at `test-backend-ops` (#17565)
* fix test failure
* fix: correct scaling calculations in rope_cache_init
* fix: optimize element copying in rope_hex_f32 using memcpy
* fix: optimize loop boundaries in rope_hex_f32 for better performance
* feat: add profiling macros for performance measurement in operations
This commit is contained in:
parent
45e350e3d3
commit
34ce48d97a
|
|
@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||||
return (1 - MIN(1, MAX(0, y)));
|
return (1 - MIN(1, MAX(0, y)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_cache_init(const float theta_base,
|
static void rope_cache_init(const float theta_base,
|
||||||
float freq_scale,
|
const float freq_scale,
|
||||||
const float * freq_factors,
|
const float * freq_factors,
|
||||||
float * corr_dims,
|
float * corr_dims,
|
||||||
uint32_t ne0,
|
const uint32_t ne0,
|
||||||
float ext_factor,
|
const float ext_factor,
|
||||||
float mscale,
|
const float mscale,
|
||||||
float * cache,
|
float * cache,
|
||||||
float theta_scale) {
|
const float theta_scale) {
|
||||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||||
float theta = theta_base;
|
float theta = theta_base;
|
||||||
|
|
||||||
|
|
@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base,
|
||||||
|
|
||||||
// Get n-d rotational scaling corrected for extrapolation
|
// Get n-d rotational scaling corrected for extrapolation
|
||||||
float theta_interp = freq_scale * theta_extrap;
|
float theta_interp = freq_scale * theta_extrap;
|
||||||
float theta2 = theta_interp;
|
float theta_final = theta_interp;
|
||||||
|
float mscale_final = mscale;
|
||||||
|
|
||||||
if (ext_factor != 0.0f) {
|
if (ext_factor != 0.0f) {
|
||||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||||
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
|
|
||||||
// Get n-d magnitude scaling corrected for interpolation
|
// Get n-d magnitude scaling corrected for interpolation
|
||||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
cache[i0 + 0] = cosf(theta2) * mscale;
|
cache[i0 + 0] = cosf(theta_final) * mscale_final;
|
||||||
cache[i0 + 1] = sinf(theta2) * mscale;
|
cache[i0 + 1] = sinf(theta_final) * mscale_final;
|
||||||
|
|
||||||
theta *= theta_scale;
|
theta *= theta_scale;
|
||||||
}
|
}
|
||||||
|
|
@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
|
||||||
}
|
}
|
||||||
|
|
||||||
static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||||
float * restrict dst,
|
float * restrict dst,
|
||||||
const int num_elems,
|
const int num_elems,
|
||||||
const float * restrict theta_cache) {
|
const float * restrict theta_cache) {
|
||||||
// for (int i = 0; i < num_elems; i += 2) {
|
// for (int i = 0; i < num_elems; i += 2) {
|
||||||
//const float cos_theta = theta_cache[i + 0];
|
//const float cos_theta = theta_cache[i + 0];
|
||||||
//const float sin_theta = theta_cache[i + 1];
|
//const float sin_theta = theta_cache[i + 1];
|
||||||
|
|
@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||||
|
|
||||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
|
||||||
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
|
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
|
||||||
|
|
||||||
src0_curr += VLEN;
|
src0_curr += VLEN;
|
||||||
|
|
@ -259,7 +260,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
const uint32_t ir1,
|
const uint32_t ir1,
|
||||||
int nth,
|
int nth,
|
||||||
int ith,
|
int ith,
|
||||||
int opt_path) {
|
const int opt_path) {
|
||||||
struct htp_ops_context * octx = rope_ctx->octx;
|
struct htp_ops_context * octx = rope_ctx->octx;
|
||||||
|
|
||||||
const struct htp_tensor * src0 = &octx->src0;
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
|
@ -267,8 +268,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
const struct htp_tensor * src2 = &octx->src2;
|
const struct htp_tensor * src2 = &octx->src2;
|
||||||
struct htp_tensor * dst = &octx->dst;
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
const int32_t mode = rope_ctx->mode;
|
const int32_t mode = rope_ctx->mode;
|
||||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||||
|
|
||||||
htp_rope_preamble;
|
htp_rope_preamble;
|
||||||
|
|
||||||
|
|
@ -281,8 +282,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
freq_factors = (const float *) src2->data;
|
freq_factors = (const float *) src2->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ir = 0;
|
const uint32_t i1_end = MIN(ir1, ne1);
|
||||||
|
const int32_t half_dims = rope_ctx->n_dims / 2;
|
||||||
|
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
|
||||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||||
const int32_t p = pos[i2];
|
const int32_t p = pos[i2];
|
||||||
|
|
@ -290,14 +292,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
||||||
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
||||||
|
|
||||||
for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
|
||||||
if (ir++ < ir0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ir > ir1) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
||||||
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
||||||
|
|
||||||
|
|
@ -310,6 +305,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
} else {
|
} else {
|
||||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
src_loc += rope_ctx->n_dims;
|
||||||
|
dst_data_loc += rope_ctx->n_dims;
|
||||||
} else {
|
} else {
|
||||||
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||||
const float cos_theta = wp0[i0 + 0];
|
const float cos_theta = wp0[i0 + 0];
|
||||||
|
|
@ -317,10 +315,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
|
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
const float x0 = src_loc[0];
|
const float x0 = src_loc[0];
|
||||||
const float x1 = src_loc[rope_ctx->n_dims/2];
|
const float x1 = src_loc[half_dims];
|
||||||
|
|
||||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||||
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||||
|
|
||||||
src_loc += 1;
|
src_loc += 1;
|
||||||
dst_data_loc += 1;
|
dst_data_loc += 1;
|
||||||
|
|
@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
dst_data_loc += 2;
|
dst_data_loc += 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
src_loc += (is_neox ? half_dims : 0);
|
||||||
|
dst_data_loc += (is_neox ? half_dims : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
|
// TODO: use simd to speed up the remaining elements copy
|
||||||
dst_data_loc[0] = src_loc[0];
|
memcpy(dst_data_loc, src_loc, remain_bytes);
|
||||||
dst_data_loc[1] = src_loc[1];
|
|
||||||
|
|
||||||
src_loc += 2;
|
|
||||||
dst_data_loc += 2;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue