From fbd441c37933550c1e3365dc84dd73232334c15d Mon Sep 17 00:00:00 2001
From: Todor Boinovski
Date: Wed, 1 Apr 2026 17:44:02 -0700
Subject: [PATCH] hexagon : add cumsum op support (#21246)

* hexagon : add cumsum op support

* hexagon: enable dma for cumsum op

* Fix line-ending

---------

Co-authored-by: Max Krasnyansky
---
Review notes (this section is ignored when the patch is applied):
 - Line breaks and <...> spans lost in extraction were restored. Whitespace
   alignment is reconstructed, and the template argument on the pre-existing
   dispatch line in the graph-compute hunk could not be recovered and is
   left as found.
 - op_cumsum_f32 now returns early for empty tensors, and comments were added.
   Hunk headers and the diffstat were updated; the blob ids on the index
   lines were not.

 ggml/src/ggml-hexagon/ggml-hexagon.cpp   |  38 +++
 ggml/src/ggml-hexagon/htp/CMakeLists.txt |   1 +
 ggml/src/ggml-hexagon/htp/cumsum-ops.c   | 297 +++++++++++++++++++++++
 ggml/src/ggml-hexagon/htp/htp-msg.h      |   1 +
 ggml/src/ggml-hexagon/htp/htp-ops.h      |   1 +
 ggml/src/ggml-hexagon/htp/main.c         |  45 ++++
 6 files changed, 383 insertions(+)
 create mode 100644 ggml/src/ggml-hexagon/htp/cumsum-ops.c

diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index dd604db433..f91bc46552 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2231,6 +2231,24 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session *
     return true;
 }
 
+// CUMSUM is offloaded only when both the source and the result are
+// contiguous F32 tensors; anything else is reported as unsupported.
+static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    GGML_UNUSED(sess);
+    return true;
+}
+
 enum dspqbuf_type {
     DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
     DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@@ -2399,6 +2417,18 @@ static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bu
     return n_bufs;
 }
 
+// Build a CUMSUM request. src0 is queued first and dst second, which is the
+// bufs[0]/bufs[1] order that proc_cumsum_req reads on the HTP side.
+static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_CUMSUM;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     req->op = HTP_OP_GET_ROWS;
 
@@ -2780,6 +2810,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op(sess, node, flags);
                 break;
 
+            case GGML_OP_CUMSUM:
+                ggml_hexagon_dispatch_op<init_cumsum_req>(sess, node, flags);
+                break;
+
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }
@@ -3254,6 +3288,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_ssm_conv(sess, op);
             break;
 
+        case GGML_OP_CUMSUM:
+            supp = ggml_hexagon_supported_cumsum(sess, op);
+            break;
+
         default:
             break;
     }
diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
index 6ddfe4252f..2b60f427ad 100644
--- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -33,6 +33,7 @@ add_library(${HTP_LIB} SHARED
     repeat-ops.c
     argsort-ops.c
     ssm-conv.c
+    cumsum-ops.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c
new file mode 100644
index 0000000000..ce51555a7f
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c
@@ -0,0 +1,297 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-ops.h"
+#include "hvx-types.h"
+#include "hvx-utils.h"
+#include "hex-dma.h"
+
+// Declares src0/dst and their ne*/nb* extents and strides as locals.
+// Requires `struct htp_ops_context * octx` to be in scope.
+#define htp_cumsum_tensors_preamble                  \
+    struct htp_tensor * restrict src0 = &octx->src0; \
+    struct htp_tensor * restrict dst  = &octx->dst;  \
+                                                     \
+    const uint32_t ne00 = src0->ne[0];               \
+    const uint32_t ne01 = src0->ne[1];               \
+    const uint32_t ne02 = src0->ne[2];               \
+    const uint32_t ne03 = src0->ne[3];               \
+                                                     \
+    const uint32_t ne0 = dst->ne[0];                 \
+    const uint32_t ne1 = dst->ne[1];                 \
+    const uint32_t ne2 = dst->ne[2];                 \
+    const uint32_t ne3 = dst->ne[3];                 \
+                                                     \
+    const uint32_t nb00 = src0->nb[0];               \
+    const uint32_t nb01 = src0->nb[1];               \
+    const uint32_t nb02 = src0->nb[2];               \
+    const uint32_t nb03 = src0->nb[3];               \
+                                                     \
+    const uint32_t nb0 = dst->nb[0];                 \
+    const uint32_t nb1 = dst->nb[1];                 \
+    const uint32_t nb2 = dst->nb[2];                 \
+    const uint32_t nb3 = dst->nb[3];
+
+// Per-op state handed to every worker thread.
+struct htp_cumsum_context {
+    struct htp_ops_context * octx;
+    size_t                   src_row_size;          // src0->nb[1]
+    size_t                   dst_row_size;          // dst->nb[1]
+    size_t                   src_row_size_aligned;  // src_row_size rounded up to VLEN
+    size_t                   dst_row_size_aligned;  // dst_row_size rounded up to VLEN
+    uint32_t                 rows_per_thread;       // ceil(total_rows / n_threads)
+    uint32_t                 total_rows;            // src0 ne[1] * ne[2] * ne[3]
+};
+
+// Worker prologue. Requires the worker's `data` and `ith` parameters in scope.
+#define htp_cumsum_preamble                                                \
+    struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \
+    struct htp_ops_context *    octx = cctx->octx;                         \
+    htp_cumsum_tensors_preamble;                                           \
+    dma_queue * dma_queue = octx->ctx->dma[ith];
+
+// ---------------------------------------------------------------------------
+// HVX prefix scan helpers
+// ---------------------------------------------------------------------------
+
+// F32 vector add. Targets with __HVX_ARCH__ <= 75 go through qf32.
+#if __HVX_ARCH__ > 75
+static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
+    return Q6_Vsf_vadd_VsfVsf(a, b);
+}
+#else
+static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) {
+    return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b));
+}
+#endif  // __HVX_ARCH__ > 75
+
+// Inclusive prefix sum over the F32 lanes of v, then offset by carry_in.
+// Each step adds v to a copy of itself moved up by 4/8/16/32/64 bytes
+// (1/2/4/8/16 floats) with the vacated bytes taken from `zero`.
+static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) {
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4));
+    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8));
+    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16));
+    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32));
+    v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64));
+    v = hvx_cumsum_vadd(v, carry_in);
+
+    return v;
+}
+
+// Rotates the F32 at byte offset 124 (the last lane of a 128-byte vector)
+// down to lane 0 and replicates it with hvx_vec_repl4.
+static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) {
+    return hvx_vec_repl4(Q6_V_vror_VR(v, 124));
+}
+
+// Inclusive cumulative sum of n floats from src into dst.
+// Full vectors of VLEN_FP32 floats use the HVX scan; the remaining
+// n % VLEN_FP32 floats are summed in scalar code, seeded from the carry.
+static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) {
+    const uint32_t nvec = n / VLEN_FP32;
+    const uint32_t nloe = n % VLEN_FP32;
+
+    HVX_Vector carry = Q6_V_vsplat_R(0);
+
+    for (uint32_t i = 0; i < nvec; i++) {
+        HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32));
+        v            = hvx_prefix_scan_f32(v, carry);
+        hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v);
+        carry = hvx_splat_last_f32(v);
+    }
+
+    if (nloe) {
+        float         acc      = hvx_vec_get_f32(carry);
+        const float * src_tail = src + nvec * VLEN_FP32;
+        float *       dst_tail = dst + nvec * VLEN_FP32;
+        for (uint32_t i = 0; i < nloe; i++) {
+            acc += src_tail[i];
+            dst_tail[i] = acc;
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Per thread worker: Double-buffered DMA
+// ---------------------------------------------------------------------------
+
+// Processes rows [ir0, ir1) of this thread through two scratchpad slots per
+// direction: rows are DMA'd in, scanned in the scratchpad, and DMA'd back out.
+static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) {
+    htp_cumsum_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t ir0 = cctx->rows_per_thread * ith;
+    const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);
+
+    if (ir0 >= ir1) {
+        return;
+    }
+
+    const size_t src_row_size         = cctx->src_row_size;
+    const size_t dst_row_size         = cctx->dst_row_size;
+    const size_t src_row_size_aligned = cctx->src_row_size_aligned;
+    const size_t dst_row_size_aligned = cctx->dst_row_size_aligned;
+
+    const uint8_t * src_data = (const uint8_t *) src0->data;
+    uint8_t *       dst_data = (uint8_t *) dst->data;
+
+    uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2);
+    uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2);
+
+    // Prime the queue with up to two rows.
+    for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) {
+        // Dummy dst writeback to establish queue ordering
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)),
+            dst_row_size, dst_row_size_aligned, 0);
+
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned),
+                         src_data + (ir * src_row_size)),
+            src_row_size_aligned, src_row_size, 1);
+    }
+
+    for (uint32_t ir = ir0; ir < ir1; ir++) {
+        // Pops pair up as (dst slot, src slot), mirroring the push order.
+        float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src;
+        float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst;
+
+        hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00);
+
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row),
+            dst_row_size, dst_row_size_aligned, 1);
+
+        // Reuse the src slot just consumed for the row two ahead.
+        const uint32_t next_row = ir + 2;
+        if (next_row < ir1) {
+            dma_queue_push_ddr_to_vtcm(dma_queue,
+                dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)),
+                src_row_size_aligned, src_row_size, 1);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// ---------------------------------------------------------------------------
+// Per thread worker: Direct HVX (no DMA)
+// ---------------------------------------------------------------------------
+
+// Scans rows [ir0, ir1) of this thread directly on the src0/dst buffers.
+static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) {
+    htp_cumsum_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint8_t * src_data = (const uint8_t *) src0->data;
+    uint8_t *       dst_data = (uint8_t *) dst->data;
+
+    const uint32_t ir0 = cctx->rows_per_thread * ith;
+    const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows);
+
+    for (uint32_t ir = ir0; ir < ir1; ir++) {
+        const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size);
+        float * restrict       dst_row = (float *) (dst_data + ir * cctx->dst_row_size);
+        hvx_cumsum_row_f32(src_row, dst_row, ne00);
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// F32 cumulative sum along dim 0 of every row of src0, written to dst.
+// Rows are split evenly across worker threads. The DMA worker is used when
+// VTCM can hold two src and two dst row buffers per thread; otherwise the
+// workers operate on the tensor buffers directly.
+int op_cumsum_f32(struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * dst  = &octx->dst;
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    if (src0->ne[0] == 0 || total_rows == 0) {
+        // Empty tensor: nothing to write. Returning here also keeps n_threads
+        // non-zero for the rows_per_thread division below.
+        return HTP_STATUS_OK;
+    }
+
+    const uint32_t n_threads = MIN(octx->n_threads, total_rows);
+
+    const size_t src_row_size         = src0->nb[1];
+    const size_t dst_row_size         = dst->nb[1];
+    const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN);
+    const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+    // 2 ping-pong buffers per thread for src and dst
+    const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned);
+
+    octx->src0_spad.size_per_thread = src_row_size_aligned * 2;
+    octx->dst_spad.size_per_thread  = dst_row_size_aligned * 2;
+    octx->src0_spad.size            = n_threads * octx->src0_spad.size_per_thread;
+    octx->dst_spad.size             = n_threads * octx->dst_spad.size_per_thread;
+    octx->src0_spad.data            = octx->ctx->vtcm_base;
+    octx->dst_spad.data             = octx->src0_spad.data + octx->src0_spad.size;
+
+    struct htp_cumsum_context cctx = {
+        .octx                 = octx,
+        .src_row_size         = src_row_size,
+        .dst_row_size         = dst_row_size,
+        .src_row_size_aligned = src_row_size_aligned,
+        .dst_row_size_aligned = dst_row_size_aligned,
+        .rows_per_thread      = (total_rows + n_threads - 1) / n_threads,
+        .total_rows           = total_rows,
+    };
+
+    if (octx->ctx->vtcm_size < spad_per_thread * n_threads) {
+        worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads);
+    } else {
+        worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads);
+    }
+
+    return HTP_STATUS_OK;
+}
+
+// CUMSUM entry point: dispatches on the dst type; only F32 is implemented.
+int op_cumsum(struct htp_ops_context * octx) {
+    int                 err = HTP_STATUS_OK;
+    struct htp_tensor * dst = &octx->dst;
+
+    switch (dst->type) {
+        case HTP_TYPE_F32:
+            err = op_cumsum_f32(octx);
+            break;
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h
index 391148be0e..df0ea7ccbd 100644
--- a/ggml/src/ggml-hexagon/htp/htp-msg.h
+++ b/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -75,6 +75,7 @@ enum htp_op {
     HTP_OP_SUM_ROWS,
     HTP_OP_SSM_CONV,
     HTP_OP_REPEAT,
+    HTP_OP_CUMSUM,
     INVALID
 };
 
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index f643fdc340..d35decaac2 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -60,5 +60,6 @@ int op_cpy(struct htp_ops_context * octx);
 int op_repeat(struct htp_ops_context * octx);
 int op_argsort(struct htp_ops_context * octx);
 int op_ssm_conv(struct htp_ops_context * octx);
+int op_cumsum(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */
diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index 49f34b5f7d..6f37bf9d4b 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -860,6 +860,43 @@ static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req *
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+// Handle a CUMSUM request: bufs[0] is src0 and bufs[1] is dst. Runs the op
+// between vtcm_acquire/vtcm_release and returns dst as the response buffer.
+static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // The op writes the output buffer, so it also has to be flushed
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+    octx.src0.data              = (uint32_t) bufs[0].ptr;
+    octx.dst.data               = (uint32_t) bufs[1].ptr;
+    octx.n_threads              = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_cumsum(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_activations_req(struct htp_context *     ctx,
                                  struct htp_general_req * req,
                                  struct dspqueue_buffer * bufs,
@@ -1474,6 +1511,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_ssm_conv_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_CUMSUM:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad cumsum-req buffer list");
+                    continue;
+                }
+                proc_cumsum_req(ctx, &req, bufs);
+                break;
+
             default:
                 FARF(ERROR, "Unknown Op %u", req.op);
                 break;