Squashed commit of the following:

commit 8e0c6ec42a6436e037a7cc43418fa50baa130ca2
Author: Hongrui Chen <chraac@gmail.com>
Date:   Tue Oct 7 21:56:37 2025 +0800

    wip

commit 30186bb8894c620168797d4d42261d555a27eed6
Author: Hongrui Chen <chraac@gmail.com>
Date:   Tue Oct 7 21:46:34 2025 +0800

    wip

commit 3e75f5dc1dd6e64e0db2624f2e5e894f46317eed
Author: Hongrui Chen <chraac@gmail.com>
Date:   Tue Oct 7 21:45:14 2025 +0800

    fix comment

commit fe1090c8181fac9be935d606325d406a34b78a11
Author: Hongrui Chen <chraac@gmail.com>
Date:   Tue Oct 7 21:36:53 2025 +0800

    revert changes at tester

commit ac0c2a4022e8fb300f66677edc181aec49faf171
Author: Hongrui Chen <chraac@gmail.com>
Date:   Tue Oct 7 20:33:54 2025 +0800

    try enabling multi thread in rope

commit 9f8ca968c28c2e320a03ffbe8adeb54f27811f21
Author: Hongrui Chen <chraac@gmail.com>
Date:   Tue Oct 7 20:13:13 2025 +0800

    disable multi thread at rope

commit eed97ca12a5c35f5697cd2f4796611915acf47d8
Author: chraac <chraac@gmail.com>
Date:   Tue Oct 7 16:35:28 2025 +0800

    add tests

commit c3ad7229bf1e5fb33ec119f6ddc5a63fe342b54f
Author: chraac <chraac@gmail.com>
Date:   Tue Oct 7 16:35:19 2025 +0800

    wip
This commit is contained in:
Hongrui Chen 2025-10-07 22:44:02 +08:00
parent ca4d2778d9
commit 40893e58c6
1 changed file with 12 additions and 37 deletions

View File

@ -187,11 +187,12 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) {
freq_factors = src2->get_read_buffer_as<float>();
}
const int64_t total_rows = out->get_ne(3) * out->get_ne(2) * out->get_ne(1);
const auto start_end_row = params->get_work_slice(total_rows);
const int64_t total_planes = out->get_ne(3) * out->get_ne(2);
const auto start_end_plane =
std::pair<int64_t, int64_t>{ start_end_row.first / out->get_ne(1),
(start_end_row.second + out->get_ne(1) - 1) / out->get_ne(1) };
params->get_work_slice(total_planes); // TODO: figure out how to use row slice for inplace rope
if (start_end_plane.first >= start_end_plane.second) {
return true;
}
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), rope);
@ -206,46 +207,22 @@ bool rope_impl(hexagon::tensor * out, hexagon::compute_params * params) {
if constexpr (!_IsMrope) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
const int64_t p = pos[i2];
rope_cache_init(p,
freq_scale,
freq_factors,
corr_dims,
out->get_ne(0),
ext_factor,
attn_factor,
cache,
sin_sign,
theta_scale);
rope_cache_init(p, freq_scale, freq_factors, corr_dims, out->get_ne(0), ext_factor, attn_factor, cache,
sin_sign, theta_scale);
} else {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 0, cache);
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + out->get_ne(2)];
const int64_t p_w = pos[i2 + out->get_ne(2) * 2];
const int64_t p_e = pos[i2 + out->get_ne(2) * 3];
mrope_cache_init(p_t,
p_h,
p_w,
p_e,
sections,
_IsVision,
freq_scale,
freq_factors,
corr_dims,
out->get_ne(0),
ext_factor,
attn_factor,
cache,
sin_sign,
theta_scale);
mrope_cache_init(p_t, p_h, p_w, p_e, sections, _IsVision, freq_scale, freq_factors, corr_dims,
out->get_ne(0), ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(rope, 1, loop);
const uint8_t * src0_plane = src0_data_ptr + i3 * src0->get_nb(3) + i2 * src0->get_nb(2);
uint8_t * dst_plane = dst_data_ptr + i3 * out->get_nb(3) + i2 * out->get_nb(2);
const int64_t start_row = ip == start_end_plane.first ? (start_end_row.first % out->get_ne(1)) : 0;
const int64_t end_row = ip == start_end_plane.second ? (start_end_row.second % out->get_ne(1)) :
out->get_ne(1); // end row is exclusive
for (int64_t i1 = start_row; i1 < end_row; i1++) { // attn-heads
for (int64_t i1 = 0; i1 < out->get_ne(1); i1++) { // attn-heads
const uint8_t * src0_row = src0_plane + i1 * src0->get_nb(1);
uint8_t * dst_row = dst_plane + i1 * out->get_nb(1);
if constexpr (_IsNeoX || _IsMrope) {
@ -385,10 +362,8 @@ bool is_rope_supported(const npu_device_tensor_op_spec * op_spec,
const auto & src0 = srcs[0];
if (src0.type != dst->type) {
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n",
op_get_name(op),
get_type_name(src0.type),
get_type_name(dst->type));
DEVICE_LOG_DEBUG("[%s]src0 type is not the same as dst type: %s vs %s\n", op_get_name(op),
get_type_name(src0.type), get_type_name(dst->type));
return false; // unsupported src0 type
}