diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index b70da8f3b2..d6e9776b87 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2152,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess return true; } +static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + // Only support FP32 for now + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Check IO tensor shapes and dims + if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) { + return false; // src0 should be effectively 3D + } + + const int d_conv = src1->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; + + if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) { + return false; + } + if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) { + return false; + } + if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { + return false; + } + + // TODO: add support for non-contiguous tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + return true; +} + enum dspqbuf_type { DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, DSPQBUF_TYPE_CPU_WRITE_DSP_READ, @@ -2468,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf return n_bufs; } +static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SSM_CONV; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2606,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_SSM_CONV: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -3024,6 +3077,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_argsort(sess, op); break; + case GGML_OP_SSM_CONV: + supp = ggml_hexagon_supported_ssm_conv(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2c23b60da3..02d07a503d 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED get-rows-ops.c cpy-ops.c argsort-ops.c + ssm-conv.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 25403bb112..52dcc36d8f 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -68,6 +68,7 @@ enum htp_op { HTP_OP_SQR, HTP_OP_SQRT, HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 127ab1d665..2ef20936f1 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -41,9 +41,6 @@ struct htp_ops_context { worker_pool_context_t * wpool; // worker pool uint32_t n_threads; // num threads - uint32_t src0_nrows_per_thread; - uint32_t src1_nrows_per_thread; - uint32_t flags; }; @@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index a518ad3733..0834379879 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -15,4 +15,12 @@ #include "hvx-div.h" #include "hvx-base.h" +#ifndef GATHER_TYPE +# if defined(__hexagon__) +# define GATHER_TYPE(_a) (intptr_t) _a +# else +# define GATHER_TYPE(_a) (HVX_Vector *) _a +# endif +#endif + #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 92a1422896..3f99dbb32c 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We've written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup OP context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_ssm_conv(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_argsort_req(ctx, &req, bufs); break; + case HTP_OP_SSM_CONV: + if (n_bufs != 3) { + FARF(ERROR, "Bad ssm-conv-req buffer list"); + continue; + } + proc_ssm_conv_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c new file mode 100644 index 0000000000..b3c1ef9572 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -0,0 +1,339 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "hex-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +#define htp_ssm_conv_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_ssm_conv_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint64_t t_start; +}; + +#define htp_ssm_conv_preamble \ + struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// Scalar FP32 SSM_CONV implementation +static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const uint32_t d_inner_per_thread = scctx->nrows_per_thread; + const uint32_t d_inner_start = d_inner_per_thread * ith; + const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner); + + // No work for this thread + if (d_inner_start >= d_inner_end) { + return; + } + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) { + float sumf = 0.0f; + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; + const uint32_t src1_idx = i0 + i1 * src1_stride_inner; + + sumf += src0_data[src0_idx] * src1_data[src1_idx]; + } + + const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; + dst_data[dst_idx] = sumf; + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const int dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); + const int ir = ir1 - ir0; + + if (ir0 >= ir1) { + return; // No work for this thread + } + + // src0 and src1 gather offsets + uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; + uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + + for (uint32_t i = 0; i < VLEN_FP32; ++i) { + src0_offsets[i] = i * (ncs) * sizeof(float); + src1_offsets[i] = i * (d_conv) * sizeof(float); + } + + const uint32_t src0_gather_len = VLEN * ncs; + const uint32_t src1_gather_len = VLEN * d_conv; + + // gather scratchpads + HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); + HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); + + float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); + float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + + uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + // copy src1 workload to VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); + + // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + + // copy src0 workload to VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + + // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + + dma_queue_flush(dma_queue); + + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + + const uint32_t nvec = ir / VLEN_FP32; + const uint32_t nloe = ir % VLEN_FP32; + uint32_t i1 = 0; + + for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { + HVX_Vector acc_vec = Q6_V_vsplat_R(0); + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), + src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), + src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); + acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + } + + *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); + i1 += VLEN_FP32; + } + + if (nloe) { + HVX_Vector acc_vec = Q6_V_vsplat_R(0); + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), + src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), + src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); + acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + } + + hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_ssm_conv_f32(struct htp_ops_context * octx) { + htp_ssm_conv_tensors_preamble; + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported"); + return HTP_STATUS_NO_SUPPORT; + } + + struct htp_ssm_conv_context scctx = { 0 }; + scctx.octx = octx; + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; // tokens per sequence + const uint32_t n_s = dst->ne[2]; // number of sequences in the batch + + const uint32_t n_threads = MIN(octx->n_threads, d_inner); + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t use_hvx = 0; + if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { + int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && + hex_is_aligned((void *) src1->data, VLEN) && + hex_is_aligned((void *) dst->data, VLEN); + + if (is_aligned) { + use_hvx = 1; + } + } + + if (use_hvx) { + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + + octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); + octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); + octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + + // Compute gather scratchpad size for src0 and src1 + const size_t gather_spad_size = n_threads * VLEN * 2; + + octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", + gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, + octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, + octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + + const size_t total_spad_size = + gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (total_spad_size > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, + octx->ctx->vtcm_size); + use_hvx = 0; + } + } + + FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], use_hvx); + + if (use_hvx) { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads); + } + } + + return HTP_STATUS_OK; +} + +int op_ssm_conv(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + struct htp_tensor * dst = &octx->dst; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_ssm_conv_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +}