diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index 170220e8f8..3ec26a4c1a 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -164,6 +164,12 @@ static void quicksort_values_indices_desc(float * values, int32_t * indices, int if (i < right) quicksort_values_indices_desc(values, indices, i, right); } +// LUT for ramp initialization of argsort output (first 32 members) +int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 +}; + static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { struct htp_argsort_context * actx = (struct htp_argsort_context *)data; struct htp_ops_context * octx = actx->octx; @@ -205,8 +211,12 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { // Padded to 128 bytes. size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t))); float * values_buf = (float *) spad; int32_t * indices_buf = (int32_t *) (spad + values_size); + HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size); + const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut; + const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32); for (uint32_t r = start_row; r < end_row; r++) { uint32_t src_offset = r * nb01; @@ -218,9 +228,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); - // Initialize indices - for (uint32_t j = 0; j < ne00; j++) { - indices_buf[j] = j; + // Initialize indices - Start with values 0..31, add 32 for additional vec iterations + HVX_Vector curr_ind_vec = ind_init_vec; + for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) { + indices_buf_vec[j_vec] = curr_ind_vec; + curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec); } // Sort values and mirror swaps to indices