diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 72a82a8911..c45b292a52 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2164,8 +2164,14 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } // src0, src1 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, dst)) { - return false; + if(src1){ + if (!hex_supported_buffer(sess, src0, src1, dst)) { + return false; + } + }else{ + if (!hex_supported_buffer(sess, src0, dst)) { + return false; + } } return true; @@ -2665,6 +2671,10 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { req.op = HTP_OP_UNARY_SILU; supported = true; } + else if (ggml_get_unary_op(dst) == GGML_UNARY_OP_GELU){ + req.op = HTP_OP_UNARY_GELU; + supported = true; + } break; case GGML_OP_GLU: @@ -2680,6 +2690,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { case GGML_OP_SOFT_MAX: req.op = HTP_OP_SOFTMAX; supported = true; + break; default: break; @@ -2959,6 +2970,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_UNARY: if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) { ggml_hexagon_unary(node, flags); + } else if (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU) { + ggml_hexagon_unary(node, flags); } break; case GGML_OP_GLU: @@ -3257,7 +3270,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons auto sess = static_cast(dev->context); bool supp = false; - switch (op->op) { case GGML_OP_NONE: case GGML_OP_RESHAPE: @@ -3297,6 +3309,9 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) { supp = ggml_hexagon_supported_activations(sess, op); } + else if (ggml_get_unary_op(op) == GGML_UNARY_OP_GELU){ + supp = ggml_hexagon_supported_activations(sess, op); + } break; case GGML_OP_GLU: diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 87b09cca3a..2db4a2a35b 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -255,6 +255,90 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } + +static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread) { + htp_act_preamble2; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + int is_aligned = 1; + int opt_path = 0; + if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + is_aligned = 0; + FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { + const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); + float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); + + if (ir + 1 < src0_end_row) { + htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); + } + + + // gelu = 0.5 * x * (1.0 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) )) // gelu_tanh + // gelu = x * sigmoid(1.702 * x) // current implementation + if (1 == opt_path) { + hvx_mul_scalar_f32( (const uint8_t *) src0, (float)1.702, (uint8_t *) src0_spad_data, ne0); + hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0); + + hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); + } else { + hvx_mul_scalar_f32( (const uint8_t *) src0, (float)1.702, (uint8_t *) src0_spad_data, ne0); + // sigmoid + hvx_exp_f32((const uint8_t *) src0_spad_data, src0_spad_data, ne0, true); + hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0); + hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0); + + hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "gelu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02, + ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread); +} + + + static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, struct htp_tensor * dst, const int32_t * op_params, @@ -371,7 +455,10 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { act_op_func = glu_swiglu_oai_fp32; op_type = "swiglu-oai-f32"; break; - + case HTP_OP_UNARY_GELU: + act_op_func = unary_gelu_fp32; + op_type = "gelu-f32"; + break; default: FARF(ERROR, "Unsupported activations Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 9278f41f4e..a61652304a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -51,11 +51,12 @@ enum htp_op { HTP_OP_MUL_MAT_ID = 5, HTP_OP_RMS_NORM = 6, HTP_OP_UNARY_SILU = 7, - HTP_OP_GLU_SWIGLU = 8, - HTP_OP_GLU_SWIGLU_OAI = 9, - HTP_OP_SOFTMAX = 10, - HTP_OP_ADD_ID = 11, - HTP_OP_ROPE = 12, + HTP_OP_UNARY_GELU = 8, + HTP_OP_GLU_SWIGLU = 9, + HTP_OP_GLU_SWIGLU_OAI = 10, + HTP_OP_SOFTMAX = 11, + HTP_OP_ADD_ID = 12, + HTP_OP_ROPE = 13, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index b60b352a7b..e30ae69502 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -798,6 +798,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { break; case HTP_OP_UNARY_SILU: + case HTP_OP_UNARY_GELU: if (n_bufs != 2) { FARF(ERROR, "Bad act-req buffer list"); continue;