diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 19917cb114..4b8a16c363 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2362,6 +2362,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, return n_bufs; } +static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + // CONT is just a contiguous copy — reuse CPY op + req->op = HTP_OP_CPY; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + +static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_REPEAT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { req->op = HTP_OP_GET_ROWS; @@ -2449,12 +2470,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf break; case GGML_OP_UNARY: - if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { + switch (ggml_get_unary_op(t)) { + case GGML_UNARY_OP_SILU: req->op = HTP_OP_UNARY_SILU; supported = true; - } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { + break; + case GGML_UNARY_OP_GELU: req->op = HTP_OP_UNARY_GELU; supported = true; + break; + case GGML_UNARY_OP_SIGMOID: + req->op = HTP_OP_UNARY_SIGMOID; + supported = true; + break; + case GGML_UNARY_OP_NEG: + req->op = HTP_OP_UNARY_NEG; + supported = true; + break; + case GGML_UNARY_OP_EXP: + req->op = HTP_OP_UNARY_EXP; + supported = true; + break; + case GGML_UNARY_OP_SOFTPLUS: + req->op = HTP_OP_UNARY_SOFTPLUS; + supported = true; + break; + default: + break; } break; @@ -2640,16 +2682,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: - if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || - (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { - ggml_hexagon_dispatch_op(sess, node, flags); + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: + break; } break; case GGML_OP_GLU: - if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || - (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { - ggml_hexagon_dispatch_op(sess, node, flags); + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: + break; } break; case GGML_OP_SOFT_MAX: @@ -2676,6 +2730,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_CONT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_REPEAT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + case GGML_OP_ARGSORT: ggml_hexagon_dispatch_op(sess, node, flags); break; @@ -3006,6 +3068,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, return true; } +static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + + // CONT is same-type only, supports f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + return true; +} + +static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // Support f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + // src and dst must be the same type + if (src0->type != dst->type) return false; + + // dst dims must be multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0) return false; + if (dst->ne[1] % src0->ne[1] != 0) return false; + if (dst->ne[2] % src0->ne[2] != 0) return false; + if (dst->ne[3] % src0->ne[3] != 0) return false; + + // require contiguous tensors (no transposition) + if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false; + + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3063,21 +3158,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_UNARY: - { - const auto unary_op = ggml_get_unary_op(op); - if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; case GGML_OP_GLU: - { - const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3098,6 +3204,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_CONT: + supp = ggml_hexagon_supported_cont(sess, op); + break; + + case GGML_OP_REPEAT: + supp = ggml_hexagon_supported_repeat(sess, op); + break; + case GGML_OP_ARGSORT: supp = ggml_hexagon_supported_argsort(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 02d07a503d..a490a2ce9a 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + repeat-ops.c argsort-ops.c ssm-conv.c ) diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 52dcc36d8f..56bc5b622c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -53,6 +53,10 @@ enum htp_op { HTP_OP_RMS_NORM, HTP_OP_UNARY_SILU, HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, HTP_OP_GLU_SWIGLU, HTP_OP_GLU_SWIGLU_OAI, HTP_OP_GLU_GEGLU, @@ -69,6 +73,7 @@ enum htp_op { HTP_OP_SQRT, HTP_OP_SUM_ROWS, HTP_OP_SSM_CONV, + HTP_OP_REPEAT, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 2ef20936f1..f643fdc340 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 578ca288fb..3e6a8579b1 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include "hex-utils.h" #include "hvx-types.h" diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h index 44dfe232a3..84e4836dc9 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.h +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -3,6 +3,7 @@ #include #include +#include #include "hvx-base.h" #include "hvx-floor.h" @@ -16,8 +17,8 @@ #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x41a00000) // 20.0 -#define EXP_RANGE_L (0xc1a00000) // -20.0 +#define EXP_RANGE_R (0x42B16666) // 88.7 +#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector z_qf32_v; @@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector temp_v = in_vec; - // Clamp inputs to (-20.0, 20.0) + // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec); epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); @@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { // normalize before every QFloat's vmpy x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + // z = x * x; z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); - x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); - // y = E4 + E5 * x; E_const = Q6_V_vsplat_R(EXP_COEFF_5); y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); @@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max return Q6_V_vmux_QVV(pred0, inf, out); } -static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { +static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict HVX_Vector vec_out = Q6_V_vzero(); static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) + static const float kMaxExp = 88.7f; const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector inf = hvx_vec_splat_f32(kInf); diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 095193277e..37f3e7b6fa 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -2,6 +2,7 @@ #define HVX_SIGMOID_H #include "hvx-base.h" +#include "hvx-inverse.h" #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 3f99dbb32c..2a3f9e562b 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -516,6 +516,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = op_repeat(&octx); + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -1090,6 +1123,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_SQR: case HTP_OP_SQRT: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_SOFTPLUS: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -1175,6 +1212,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_REPEAT: + if (n_bufs != 2) { + FARF(ERROR, "Bad repeat-req buffer list"); + continue; + } + proc_repeat_req(ctx, &req, bufs); + break; + case HTP_OP_ARGSORT: if (n_bufs != 2) { FARF(ERROR, "Bad argsort-req buffer list"); diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c new file mode 100644 index 0000000000..5db06c920e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -0,0 +1,148 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +struct htp_repeat_context { + struct htp_ops_context * octx; + + uint32_t nr0; + uint32_t nr1; + uint32_t nr2; + uint32_t nr3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; // ne1 * ne2 * ne3 + + size_t type_size; +}; + +static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; + struct htp_ops_context * octx = rctx->octx; + const struct htp_tensor * src = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + const uint32_t ne00 = src->ne[0]; + const uint32_t ne01 = src->ne[1]; + const uint32_t ne02 = src->ne[2]; + const uint32_t ne03 = src->ne[3]; + + const uint32_t nb00 = src->nb[0]; + const uint32_t nb01 = src->nb[1]; + const uint32_t nb02 = src->nb[2]; + const uint32_t nb03 = src->nb[3]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb0 = dst->nb[0]; + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t nr0 = rctx->nr0; + const uint32_t nr1 = rctx->nr1; + const uint32_t nr2 = rctx->nr2; + const uint32_t nr3 = rctx->nr3; + + const size_t row_bytes = ne00 * rctx->type_size; + + const uint32_t row_start = rctx->nrows_per_thread * ith; + const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + // Decompose flat dst row index into (i1, i2, i3) + const uint32_t i1 = dst_row % ne1; + const uint32_t i2 = (dst_row / ne1) % ne2; + const uint32_t i3 = dst_row / (ne1 * ne2); + + // Map to source indices (tiling) + const uint32_t k1 = i1 % ne01; + const uint32_t k2 = i2 % ne02; + const uint32_t k3 = i3 % ne03; + + const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03; + uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + // Tile along dimension 0 + for (uint32_t i0 = 0; i0 < nr0; i0++) { + uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0; + memcpy(dst_ptr, src_row, row_bytes); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_repeat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + struct htp_tensor * dst = &octx->dst; + + // Validate that dst dims are multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0 || + dst->ne[1] % src0->ne[1] != 0 || + dst->ne[2] % src0->ne[2] != 0 || + dst->ne[3] % src0->ne[3] != 0) { + FARF(ERROR, "repeat: dst dims must be multiples of src dims\n"); + return HTP_STATUS_INVAL_PARAMS; + } + + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + case HTP_TYPE_F16: type_size = 2; break; + default: + FARF(ERROR, "repeat: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_repeat_context rctx = { + .octx = octx, + .nr0 = dst->ne[0] / src0->ne[0], + .nr1 = dst->ne[1] / src0->ne[1], + .nr2 = dst->ne[2] / src0->ne[2], + .nr3 = dst->ne[3] / src0->ne[3], + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + }; + + FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3); + + worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 8dae7f1ed5..d6356b9506 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src, const float max) { hvx_sub_scalar_f32(spad, src, max, num_elems); - hvx_exp_f32(spad, dst, num_elems, false); + hvx_exp_f32(dst, spad, num_elems, false); float sum = hvx_reduce_sum_f32(dst, num_elems); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 5bbd5040d3..3d0928d4dc 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -9,6 +9,8 @@ #include #include "hex-dma.h" +#include "hvx-exp.h" +#include "hvx-sigmoid.h" #include "hvx-utils.h" #define GGML_COMMON_DECL_C @@ -166,6 +168,75 @@ static void sqrt_f32(const float * restrict src, } } +static void neg_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f); + } +} + +static void exp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_exp_f32(dst_local, src_local, row_elems, false); + } +} + +static void sigmoid_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_sigmoid_f32_aa(dst_local, src_local, row_elems); + } +} + +static void softplus_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + // softplus(x) = log(1 + exp(x)) + // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + for (uint32_t i = 0; i < row_elems; i++) { + float x = src_f[i]; + // For x > 20: softplus(x) ≈ x (avoids exp overflow) + dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x)); + } + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -247,6 +318,18 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_SQRT: sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_UNARY_NEG: + neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_EXP: + exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SIGMOID: + sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SOFTPLUS: + softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; default: break; } @@ -295,6 +378,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_SQRT: op_type = "sqrt-f32"; break; + case HTP_OP_UNARY_NEG: + op_type = "neg-f32"; + break; + case HTP_OP_UNARY_EXP: + op_type = "exp-f32"; + break; + case HTP_OP_UNARY_SIGMOID: + op_type = "sigmoid-f32"; + break; + case HTP_OP_UNARY_SOFTPLUS: + op_type = "softplus-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op);