cpu: skip redudant ROPE cache updates (#20149)
This commit is contained in:
parent
d48e876467
commit
ba2fd11cdf
|
|
@ -5803,28 +5803,33 @@ static void ggml_compute_forward_rope_flt(
|
|||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
int64_t last_i2 = -1;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!mrope_used) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir++ < ir0) continue; // skip rows mapped to other threads
|
||||
if (ir > ir1) break;
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (last_i2 != i2) {
|
||||
if (!mrope_used) {
|
||||
const int64_t p = pos[i2];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i2];
|
||||
const int64_t p_h = pos[i2 + ne2];
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
last_i2 = i2;
|
||||
}
|
||||
|
||||
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
||||
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue