diff --git a/README.md b/README.md index 5c11f38048..125cb3f370 100644 --- a/README.md +++ b/README.md @@ -287,7 +287,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE | | [WebGPU [In Progress]](docs/build.md#webgpu) | All | | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | -| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon | +| [Hexagon [In Progress]](docs/backend/snapdragon/README.md) | Snapdragon | | [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR | ## Obtaining and quantizing models diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3006e21779..b70da8f3b2 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1865,15 +1865,26 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (src0->type != GGML_TYPE_F32) { - return false; - } - if (src1->type != GGML_TYPE_F32) { - return false; - } - if (dst->type != GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32) { + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (dst->type != GGML_TYPE_F32) { + return false; + } + } + else if (src0->type == GGML_TYPE_F16) { + if (src1->type != GGML_TYPE_F16) { + return false; + } + if (dst->type != GGML_TYPE_F16) { + return false; + } + } + else { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 21bd4050a1..d8b924981e 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -693,8 +693,8 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used @@ -748,13 +748,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - uint32_t n_jobs = MIN(n_threads, src0_nrows); - // Prepare context struct htp_act_context actx; actx.octx = octx; - actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; actx.src0_row_size = src0_row_size; actx.src1_row_size = src1_row_size; @@ -794,7 +792,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { actx.data_src1 = data_src1; actx.data_dst = (uint8_t *) dst->data; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index a4cee980be..170220e8f8 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -241,6 +241,9 @@ int op_argsort(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } + const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + const uint32_t n_threads = MIN(total_rows, octx->n_threads); + // Allocate scratchpad // We need 1 row of float + 1 row of int32 per thread. uint32_t ne00 = octx->src0.ne[0]; @@ -251,7 +254,7 @@ int op_argsort(struct htp_ops_context * octx) { // Make sure we round up to 256 for alignment requirements spad_per_thread = hex_round_up(spad_per_thread, 256); - size_t total_spad_size = spad_per_thread * octx->n_threads; + size_t total_spad_size = spad_per_thread * n_threads; if (octx->ctx->vtcm_size < total_spad_size) { FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); @@ -267,15 +270,12 @@ int op_argsort(struct htp_ops_context * octx) { octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], octx->src0.data, octx->dst.data); - uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; - uint32_t n_jobs = MIN(total_rows, octx->n_threads); - struct htp_argsort_context actx; actx.octx = octx; - actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads; // Run jobs - worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 00dbcf8798..ec90f22de5 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -95,43 +95,87 @@ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_ } // Macro for scalar op switch -#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ - default: break; \ +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (All Aligned) -#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) -#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (All Unaligned - generic loop used in element repeat) -#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // 1. Scalar src1 (ne10 == 1) @@ -140,6 +184,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -170,7 +216,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -199,13 +245,12 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { for (uint32_t r = 0; r < current_block_size; r++) { uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - float val = *(float *)src1_ptr; + COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00); src1_ptr += s1_stride; - COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -216,7 +261,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -230,6 +275,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -268,8 +315,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -284,7 +331,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } uint32_t i03, i02, i01, rem; @@ -293,7 +340,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi i02 = fastdiv(rem, &bctx->dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -310,8 +357,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } @@ -326,6 +373,8 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -359,7 +408,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -373,7 +422,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } uint32_t i03, i02, i01, rem; @@ -382,7 +431,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, i02 = fastdiv(rem, &bctx->dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -392,7 +441,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -406,6 +455,8 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -435,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -462,11 +513,11 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; // Read src1 from DDR (unaligned) - COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00); } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -476,7 +527,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -490,6 +541,9 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -519,7 +573,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -549,12 +603,12 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * for (uint32_t c = 0; c < ne00; c += ne10) { uint32_t len = MIN(ne10, ne00 - c); // Use UUU for speed and simplicity - COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len); } } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -564,7 +618,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -672,18 +726,20 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { dma_queue_flush(q); } -static int execute_op_binary_f32(struct htp_ops_context * octx) { +static int execute_op_binary(struct htp_ops_context * octx) { const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); // Use packed row sizes for VTCM allocation - const size_t src0_row_size = src0->ne[0] * sizeof(float); - const size_t src1_row_size = src1->ne[0] * sizeof(float); - const size_t dst_row_size = dst->ne[0] * sizeof(float); + const uint32_t src0_type = octx->src0.type; + const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const size_t src0_row_size = src0->ne[0] * elem_size; + const size_t src1_row_size = src1->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; // Align to VLEN const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); @@ -694,7 +750,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { bool is_scalar = !is_add_id && (src1->ne[0] == 1); // Determine which kernel we will use to alloc memory and dispatch - bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); @@ -726,7 +782,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { } if (rows_per_buffer < 1) { - FARF(ERROR, "binary-f32: VTCM too small\n"); + FARF(ERROR, "binary: VTCM too small\n"); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -761,16 +817,14 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - uint32_t n_jobs = MIN(n_threads, src0_nrows); - dma_queue * q = octx->ctx->dma[0]; if (is_row_bcast) { - dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1); } struct htp_binary_context bctx; bctx.octx = octx; - bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bctx.block_max = rows_per_buffer; bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned; @@ -814,14 +868,24 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { dma_queue_pop(q); } - worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads); return HTP_STATUS_OK; } int op_binary(struct htp_ops_context * octx) { - if (octx->src0.type == HTP_TYPE_F32) { - return execute_op_binary_f32(octx); + + // Does not support permutations of src1 + const struct htp_tensor * src1 = &octx->src1; + if (src1->nb[1] < src1->nb[0]) { + return HTP_STATUS_NO_SUPPORT; } + + const uint32_t src0_type = octx->src0.type; + if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { + return execute_op_binary(octx); + } + return HTP_STATUS_NO_SUPPORT; } + diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 559ca18378..a40d866b9c 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -202,6 +202,8 @@ static void cpy_work_func(unsigned int n, unsigned int i, void *data) { int op_cpy(struct htp_ops_context * octx) { cpy_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + struct htp_copy_context ct; ct.octx = octx; @@ -227,8 +229,7 @@ int op_cpy(struct htp_ops_context * octx) { const bool transposed = (nb00 > nb01) || (nb0 > nb1); const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; if (sametype && sameshape) { ct.copy = cpy_thread_sametype_sameshape; @@ -245,7 +246,7 @@ int op_cpy(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index bf24bbda70..047d2850aa 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -82,6 +82,8 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da int op_get_rows(struct htp_ops_context * octx) { get_rows_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + if (octx->src0.type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -103,9 +105,8 @@ int op_get_rows(struct htp_ops_context * octx) { grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h index 2577cdd041..82e3416970 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-arith.h +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -13,14 +13,15 @@ // Binary operations (add, mul, sub) // -#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \ +#define UNUSED(x) (void)(x) + +#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const uint32_t elem_size = sizeof(float); \ - const uint32_t epv = 128 / elem_size; \ + const uint32_t epv = 128 / (elem_size); \ const uint32_t nvec = n / epv; \ const uint32_t nloe = n % epv; \ \ @@ -32,62 +33,74 @@ } \ if (nloe) { \ HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \ - vec_store((void *) &vdst[i], nloe * elem_size, v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ } \ } while(0) #if __HVX_ARCH__ < 79 -#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) -#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) -#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + #else -#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b) -#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b) -#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + #endif +#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b) +#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b) +#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b) + // Generic macro to define alignment permutations for an op -#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \ static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src0 % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src0 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src0 % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src0 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float) + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16) // Dispatcher logic #define HVX_BINARY_DISPATCHER(OP_NAME) \ @@ -115,6 +128,10 @@ HVX_BINARY_DISPATCHER(hvx_add_f32) HVX_BINARY_DISPATCHER(hvx_sub_f32) HVX_BINARY_DISPATCHER(hvx_mul_f32) +HVX_BINARY_DISPATCHER(hvx_add_f16) +HVX_BINARY_DISPATCHER(hvx_sub_f16) +HVX_BINARY_DISPATCHER(hvx_mul_f16) + // Mul-Mul Optimized static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { assert((unsigned long) dst % 128 == 0); @@ -136,26 +153,25 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re _Pragma("unroll(4)") for (; i < nvec; i++) { - HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); vdst[i] = HVX_OP_MUL(v1, vsrc2[i]); } if (nloe) { - HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); - HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]); + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); + HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]); hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2); } } // Scalar Operations -#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \ +#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ \ - const uint32_t elem_size = sizeof(float); \ - const uint32_t epv = 128 / elem_size; \ + const uint32_t epv = 128 / (elem_size); \ const uint32_t nvec = n / epv; \ const uint32_t nloe = n % epv; \ \ @@ -169,138 +185,88 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re if (nloe) { \ HVX_Vector v = vsrc[i]; \ v = scalar_op_macro(v); \ - vec_store((void *) &vdst[i], nloe * elem_size, v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ } \ } while(0) -#define HVX_OP_ADD_SCALAR(v) \ +#define HVX_OP_ADD_SCALAR_F32(v) \ ({ \ const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \ - HVX_Vector out = HVX_OP_ADD(v, val_vec); \ + HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \ Q6_V_vmux_QVV(pred_inf, inf, out); \ }) -#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec) -#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec) +#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec) +#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec) -// Add Scalar Variants +#define HVX_OP_ADD_SCALAR_F16(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \ + HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) -static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); +#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec) +#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec) + +// Scalar Variants + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ + +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float) + +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16) + +// Dispatcher logic +#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, val, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, val, num_elems); \ + } \ } -static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); -} +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float) -static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - static const float kInf = INFINITY; - const HVX_Vector inf = hvx_vec_splat_f32(kInf); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); -} - -// Sub Scalar Variants - -static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); -} - -// Mul Scalar Variants - -static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); -} - -static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); -} - -static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); -} - -static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); -} - -static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_add_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_add_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_add_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_add_scalar_f32_uu(dst, src, val, num_elems); - } -} - -static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_mul_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_mul_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_mul_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_mul_scalar_f32_uu(dst, src, val, num_elems); - } -} - -static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_sub_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_sub_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_sub_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_sub_scalar_f32_uu(dst, src, val, num_elems); - } -} +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16) // MIN Scalar variants @@ -310,24 +276,24 @@ static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { @@ -357,27 +323,27 @@ static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) { @@ -396,7 +362,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * // Square // -#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \ +#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ @@ -410,10 +376,10 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ } \ if (nloe) { \ - HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ vec_store((void *) &vdst[i], nloe * elem_size, v); \ } \ } while(0) @@ -421,21 +387,21 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); } static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); - hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); } static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) src % 128 == 0); - hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); } static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); } static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { @@ -454,17 +420,24 @@ static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict } } -#undef HVX_OP_ADD -#undef HVX_OP_SUB -#undef HVX_OP_MUL +#undef HVX_OP_ADD_F32 +#undef HVX_OP_SUB_F32 +#undef HVX_OP_MUL_F32 +#undef HVX_OP_ADD_F16 +#undef HVX_OP_SUB_F16 +#undef HVX_OP_MUL_F16 #undef hvx_arith_loop_body -#undef HVX_OP_ADD_SCALAR -#undef HVX_OP_SUB_SCALAR -#undef HVX_OP_MUL_SCALAR +#undef HVX_OP_ADD_SCALAR_F32 +#undef HVX_OP_SUB_SCALAR_F32 +#undef HVX_OP_MUL_SCALAR_F32 +#undef HVX_OP_ADD_SCALAR_F16 +#undef HVX_OP_SUB_SCALAR_F16 +#undef HVX_OP_MUL_SCALAR_F16 #undef hvx_scalar_loop_body #undef HVX_OP_MIN_SCALAR #undef HVX_OP_CLAMP_SCALAR #undef DEFINE_HVX_BINARY_OP_VARIANTS #undef HVX_BINARY_DISPATCHER +#undef UNUSED #endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 701637f22b..578ca288fb 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -189,4 +189,52 @@ static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vect #endif +#if __HVX_ARCH__ < 79 + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); +} + +#else + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vadd_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vsub_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vmpy_VhfVhf(a, b); +} + +#endif // __HVX_ARCH__ < 79 + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 7dae012e0e..05cefea039 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -15,11 +15,144 @@ #include "hvx-arith.h" #if __HVX_ARCH__ < 79 -#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) #else -#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif +// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. +static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32)); + HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32)); +#else + HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32); + HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32); +#endif + + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const); + +#if __HVX_ARCH__ < 79 + HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + return res; +} + +#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + +static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) dst % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} +static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32. +static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32)); + HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32)); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32)); + HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32)); +#else + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32); + HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32); + HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32); +#endif + + // Inverse second input in fp32 + HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask); + HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask); + + // Multiply first input by inverse of second, in fp32 + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1); + + // Convert back to fp16 +#if __HVX_ARCH__ < 79 + HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + + return recip; +} + +#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + #define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ @@ -36,81 +169,83 @@ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ - HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ vdst[i] = res; \ } \ if (nloe) { \ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ - HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ } \ } while(0) -// 3-letter suffix variants -static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src0 % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_DIV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ } -static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src0 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); -} +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body) +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body) -static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); -} +HVX_DIV_DISPATCHER(hvx_div_f32) +HVX_DIV_DISPATCHER(hvx_div_f16) -static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); -} - -static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src0 % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src0 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128)) { - if (hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); - else hvx_div_f32_aau(dst, src0, src1, num_elems); - } else { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); - else hvx_div_f32_auu(dst, src0, src1, num_elems); - } - } else { - if (hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); - else hvx_div_f32_uau(dst, src0, src1, num_elems); - } else { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); - else hvx_div_f32_uuu(dst, src0, src1, num_elems); - } - } -} - -#undef HVX_OP_MUL +#undef HVX_OP_MUL_F32 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/ggml/src/ggml-hexagon/htp/hvx-inverse.h index 53db94aae2..f2054f45ba 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.h +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.h @@ -137,40 +137,74 @@ static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector n } \ } while(0) -static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f16(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); } -static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \ + } \ + } while(0) + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_INV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, num_elems); \ + } \ } -static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) src % 128 == 0); - hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); -} +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body) +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body) -static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) { - if ((unsigned long) dst % 128 == 0) { - if ((unsigned long) src % 128 == 0) { - hvx_inverse_f32_aa(dst, src, num_elems); - } else { - hvx_inverse_f32_au(dst, src, num_elems); - } - } else { - if ((unsigned long) src % 128 == 0) { - hvx_inverse_f32_ua(dst, src, num_elems); - } else { - hvx_inverse_f32_uu(dst, src, num_elems); - } - } -} +HVX_INV_DISPATCHER(hvx_inverse_f32) +HVX_INV_DISPATCHER(hvx_inverse_f16) #endif // HVX_INVERSE_H diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 9aeb80d0b8..be9469538f 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -400,7 +400,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t ne0 = dst->ne[0]; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; @@ -465,17 +467,14 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.dst_row_size_aligned = dst_row_size_aligned; rctx.theta_cache_offset = theta_cache_size_aligned; - uint32_t ne0 = dst->ne[0]; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; rctx.src0_nrows = src0_nrows; + rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 2fd6c90772..4b6967749f 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -128,6 +128,8 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da int op_set_rows(struct htp_ops_context * octx) { set_rows_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + if (octx->src0.type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -149,15 +151,14 @@ int op_set_rows(struct htp_ops_context * octx) { srctx.div_ne12 = init_fastdiv_values(ne12); srctx.div_ne11 = init_fastdiv_values(ne11); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; switch(octx->dst.type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 6e22eb6a63..8dae7f1ed5 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -353,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; @@ -393,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs); + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 04fa72182a..352650b689 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -102,11 +102,9 @@ int op_sum_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = ne01 * ne02 * ne03; - - uint32_t n_jobs = MIN(n_threads, src0_nrows); - uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bool opt_path = false; if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { @@ -124,7 +122,7 @@ int op_sum_rows(struct htp_ops_context * octx) { .opt_path = opt_path, }; - worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 98135c50ab..5bbd5040d3 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -301,8 +301,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; @@ -338,11 +338,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - struct htp_unary_context uctx = { .octx = octx, - .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs, + .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads, .src0_nrows = src0_nrows, .data_src0 = (const uint8_t *)src0->data, @@ -361,7 +359,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .nc = src0->ne[0], }; - worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); } return err;