From 95ea9e0861b28adca740dbc09494f72105c9b92b Mon Sep 17 00:00:00 2001
From: Max Krasnyansky
Date: Tue, 6 Jan 2026 17:38:29 -0800
Subject: [PATCH] Hexagon: add support for f16/f32 flash attention, scale, set-rows and improve f16/f32 matmul (#18611)

* hexagon: improve fp16 matmul and add fp32/fp16 flash-attention

* hexagon: add support for set-rows fp32 -> fp16 with i32/i64 row-idx

* hexagon: add support for SCALE fp32

* hexagon: replace scalar fp32 -> fp16 copy with HVX

* hexagon: optimize flash_attn_ext with aligned VTCM buffers and DMA

  - Implements double-buffered DMA prefetching for K, V, and Mask tensors.
  - Ensures K and V rows in VTCM are padded to 128 bytes to support aligned HVX operations.
  - Correctly synchronizes DMA transfers to prevent race conditions.
  - Uses `FLASH_ATTN_BLOCK_SIZE` of 128 for efficient chunking.

* hexagon: use aligned mad_f16

* hexagon: flash_attn more aligned ops

* hexagon: optimize scale_f32 hvx helpers

* hexagon: unroll fa loops

* hexagon: remove unused set-rows log

* hexagon: flash_attn_ext add support for DMAing Q

  - Update `op_flash_attn_ext` to include Q row size in scratchpad allocation.
  - Pad Q row size to 128 bytes for alignment.
  - Implement DMA transfer for Q tensor in `flash_attn_ext_f16_thread`.
  - Update dot product computations to use VTCM-buffered Q data.

* hexagon: fix handling of NaNs in HVX dot products

* hexagon: cleanup spad allocation in flash-attn

* hexagon: improve fp16/fp32 matmul

  - Introduced `vec_dot_f16_f16` and `vec_dot_f16_f16_rx2` kernels using efficient HVX dot product intrinsics.
  - Added `quantize_fp32_f16` to copy/convert weights from DDR to VTCM.
  - Updated `op_matmul` to use the optimized path when VTCM capacity allows and broadcasting requirements are compatible.
  - Implemented fallback logic to the original implementation for complex broadcasting scenarios.

* hexagon: fix HVX_ARCH check

* hexagon: matmul cleanup and fp16 fixes

  Use aligned vec_dot_f16 for 2D matmuls and the unaligned version for 4D.

* hexagon: fix fp16 x fp16 matmuls and some minor refactoring

* hexagon: add support for GET_ROWS f32 -> f32

  Also optimize SET_ROWS threading a bit when we have just a few rows to process.
* hexagon: optimize set-rows threading * hexagon: update adb/run-bench.sh to properly support experimental and verbose options * hexagon: flash_atten use aligned vectors for dot products --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 161 +++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 3 + ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 566 +++++++++++++ ggml/src/ggml-hexagon/htp/get-rows-ops.c | 112 +++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 5 - ggml/src/ggml-hexagon/htp/htp-msg.h | 10 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 28 + ggml/src/ggml-hexagon/htp/hvx-utils.c | 51 +- ggml/src/ggml-hexagon/htp/hvx-utils.h | 265 +++++- ggml/src/ggml-hexagon/htp/main.c | 162 +++- ggml/src/ggml-hexagon/htp/matmul-ops.c | 912 ++++++++++++--------- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 168 ++++ ggml/src/ggml-hexagon/htp/softmax-ops.c | 4 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 34 +- scripts/snapdragon/adb/run-bench.sh | 14 +- 15 files changed, 2018 insertions(+), 477 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/flash-attn-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/get-rows-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/set-rows-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 13b96d61f8..365a24b496 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_ return true; } +static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * src3 = op->src[3]; + const struct ggml_tensor * src4 = op->src[4]; + const struct ggml_tensor * dst = op; + + // Check for F16 support only as requested + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) { + return false; + } + + if (src3 && src3->type != GGML_TYPE_F16) { // mask + return false; + } + + if (src4 && src4->type != GGML_TYPE_F32) { // sinks + return false; + } + + // For now we support F32 or F16 output as htp backend often converts output on the fly if needed, + // but the op implementation writes to F16 or F32. + // Let's assume dst can be F32 or F16. 
+ if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + return opt_experimental; +} + static bool hex_supported_src0_type(ggml_type t) { return t == GGML_TYPE_F32; } @@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if (dst->type != GGML_TYPE_F32) { return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } @@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; // typically the lm-head which would be too large for VTCM } - // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false; if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { return false; } @@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session } break; - case GGML_TYPE_F16: - if (!opt_experimental) { - return false; - } - break; - default: return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return true; } +static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F16) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F32) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t d->offset = (uint8_t *) t->data - buf->base; d->size = ggml_nbytes(t); + if (!d->size) { + // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty + d->size = 64; + } + switch (type) { case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: // Flush CPU @@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_GET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); 
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); @@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SCALE: + req->op = HTP_OP_SCALE; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs return n_bufs; } +static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_FLASH_ATTN_EXT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: @@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_SET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_GET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); break; @@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, 
cons supp = ggml_hexagon_supported_rope(sess, op); break; + case GGML_OP_FLASH_ATTN_EXT: + supp = ggml_hexagon_supported_flash_attn_ext(sess, op); + break; + + case GGML_OP_SET_ROWS: + supp = ggml_hexagon_supported_set_rows(sess, op); + break; + + case GGML_OP_GET_ROWS: + supp = ggml_hexagon_supported_get_rows(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2cf8aaa42a..6a34a215fa 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED softmax-ops.c act-ops.c rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c new file mode 100644 index 0000000000..04a7b843ce --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -0,0 +1,566 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + + hvx_vec_store_u(r, 4, rsum); +} + +// Dot product of two F16 
vectors, accumulating to float +static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(r, 4, rsum); +} + +// MAD: y (F32) += x (F16) * v (float) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { + const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S = hvx_vec_splat_fp16(s); + + uint32_t i = 0; + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); + ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + } + + if (nloe) { + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + + HVX_Vector xs = Q6_V_lo_W(xs_p); + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_u(&ptr_y[i], nloe * 4, xy); + } + } +} + +#define FLASH_ATTN_BLOCK_SIZE 128 + +static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * sinks = (octx->src4.data) ? 
&octx->src4 : NULL; + struct htp_tensor * dst = &octx->dst; + + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + const uint32_t nek0 = k->ne[0]; + const uint32_t nek1 = k->ne[1]; + const uint32_t nek2 = k->ne[2]; + const uint32_t nek3 = k->ne[3]; + + const uint32_t nev0 = v->ne[0]; + const uint32_t nev1 = v->ne[1]; + const uint32_t nev2 = v->ne[2]; + const uint32_t nev3 = v->ne[3]; + + const uint32_t nbq1 = q->nb[1]; + const uint32_t nbq2 = q->nb[2]; + const uint32_t nbq3 = q->nb[3]; + + const uint32_t nbk1 = k->nb[1]; + const uint32_t nbk2 = k->nb[2]; + const uint32_t nbk3 = k->nb[3]; + + const uint32_t nbv1 = v->nb[1]; + const uint32_t nbv2 = v->nb[2]; + const uint32_t nbv3 = v->nb[3]; + + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // total rows in q + const uint32_t nr = neq1*neq2*neq3; + + const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, nr); + + if (ir0 >= ir1) return; + + dma_queue * dma = octx->ctx->dma[ith]; + + const uint32_t DK = nek0; + const uint32_t DV = nev0; + + const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); + const size_t size_q_row_padded = htp_round_up(size_q_row, 128); + + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask + + const size_t size_k_row_padded = htp_round_up(size_k_row, 128); + const size_t size_v_row_padded = htp_round_up(size_v_row, 128); + + const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator + uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; + uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith; + uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith; + uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; + uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); + + const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + + const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + + // 
Fetch Q row + const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + // Clear accumulator + float * VKQ32 = (float *) spad_a; + memset(VKQ32, 0, DV * sizeof(float)); + + const __fp16 * mp_base = NULL; + if (mask) { + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); + } + + const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + // Prefetch first two blocks + for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); + uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + // Mask is 1D contiguous for this row + dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + } + } + + const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + + for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // Wait for DMA + uint8_t * k_base = dma_queue_pop(dma).dst; // K + uint8_t * v_base = dma_queue_pop(dma).dst; // V + __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + + // Inner loop processing the block from VTCM + uint32_t ic = 0; + + // Process in blocks of 32 (VLEN_FP32) + for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + // 1. Compute scores + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } + } + + HVX_Vector scores = *(HVX_Vector *) scores_arr; + + // 2. Softcap + if (logit_softcap != 0.0f) { + scores = hvx_vec_tanh_fp32(scores); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 3. 
Mask + if (mask) { + const __fp16 * mp = m_base + ic; + HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp; + + HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00); + HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16); + + HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair)); + + HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec); + scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 4. Online Softmax Update + HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores); + float m_block = hvx_vec_get_fp32(v_max); + + float M_old = M; + float M_new = (m_block > M) ? m_block : M; + M = M_new; + + float ms = expf(M_old - M_new); + + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + S = S * ms; + + HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P); + float p_sum = hvx_vec_get_fp32(p_sum_vec); + S += p_sum; + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } + } + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * size_k_row_padded; + + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } + + if (logit_softcap != 0.0f) { + s_val = logit_softcap * tanhf(s_val); + } + + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } + + const float Mold = M; + float ms = 1.0f; + float vs = 1.0f; + + if (s_val > M) { + M = s_val; + ms = expf(Mold - M); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s_val - M); + } + + const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + + S = S * ms + vs; + } + + // Issue DMA for next+1 block (if exists) + if (ib + 2 < n_blocks) { + const uint32_t next_ib = ib + 2; + const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); + dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + } + } + } + + // sinks + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + 
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv); + + // Store result + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // dst is permuted + uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + + if (dst->type == HTP_TYPE_F32) { + hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } else if (dst->type == HTP_TYPE_F16) { + hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } + } +} + +static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + flash_attn_ext_f16_thread(octx, i, n); +} + +int op_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; + struct htp_tensor * dst = &octx->dst; + + // Check support + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || + k->type != HTP_TYPE_F16 || + v->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + octx->src0_div1 = init_fastdiv_values(q->ne[1]); + + octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + + if (mask) { + octx->src3_div2 = init_fastdiv_values(mask->ne[2]); + octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); + size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128); + size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = size_q_row_padded * 1; // single row for now + size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + + octx->src0_spad.size_per_thread = size_q_block * 1; + octx->src1_spad.size_per_thread = size_k_block * 2; + octx->src2_spad.size_per_thread = size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? 
size_m_block * 2 : 0; + octx->dst_spad.size_per_thread = size_vkq_acc; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads; + octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size; + + if (octx->ctx->vtcm_size < total_spad) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c new file mode 100644 index 0000000000..54321421eb --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define get_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t nr = ne10 * ne11 * ne12; + +static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + get_rows_preamble; + + // parallelize by src1 elements (which correspond to dst rows) + const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t rem = i - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i01 = is_i32 ? 
*(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + // invalid index, skip for now to avoid crash + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + + return HTP_STATUS_OK; +} + +static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 5c3d217f1c..4bd0ea7a36 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -11,11 +11,6 @@ #define HTP_MAX_NTHREADS 10 -// FIXME: move these into matmul-ops -#define HTP_SPAD_SRC0_NROWS 16 -#define HTP_SPAD_SRC1_NROWS 16 -#define HTP_SPAD_DST_NROWS 2 - // Main context for htp DSP backend struct htp_context { dspqueue_t queue; diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index a61652304a..846d061784 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -36,6 +36,8 @@ enum htp_data_type { HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, HTP_TYPE_MXFP4 = 39, HTP_TYPE_COUNT }; @@ -57,6 +59,10 @@ enum htp_op { HTP_OP_SOFTMAX = 11, HTP_OP_ADD_ID = 12, HTP_OP_ROPE = 13, + HTP_OP_FLASH_ATTN_EXT = 14, + HTP_OP_SET_ROWS = 15, + HTP_OP_SCALE = 16, + HTP_OP_GET_ROWS = 17, INVALID }; @@ -137,6 +143,8 @@ struct htp_general_req { struct htp_tensor src0; // Input0 tensor struct htp_tensor src1; // Input1 tensor struct htp_tensor src2; // Input2 tensor + struct htp_tensor src3; // Input3 tensor + struct htp_tensor src4; // Input4 tensor struct htp_tensor dst; // Output tensor // should be multiple of 64 bytes (cacheline) @@ -152,6 +160,6 @@ struct htp_general_rsp { }; #define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 4 +#define HTP_MAX_PACKET_BUFFERS 8 #endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index e87657436f..7c828ae636 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -13,6 +13,7 @@ struct htp_spad { uint8_t * data; + size_t stride; size_t size; size_t size_per_thread; }; @@ -26,11 +27,14 @@ struct htp_ops_context { struct htp_tensor src0; struct htp_tensor src1; struct htp_tensor src2; + struct htp_tensor src3; + struct htp_tensor src4; struct htp_tensor dst; struct htp_spad src0_spad; struct htp_spad src1_spad; struct htp_spad src2_spad; + struct htp_spad 
src3_spad; struct htp_spad dst_spad; worker_pool_context_t * wpool; // worker pool @@ -49,6 +53,27 @@ struct htp_ops_context { struct fastdiv_values src1_div3; // fastdiv values for ne3 struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 + struct fastdiv_values src3_div1; // fastdiv values for ne1 + struct fastdiv_values src3_div2; // fastdiv values for ne2 + struct fastdiv_values src3_div3; // fastdiv values for ne3 + struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 + struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 + struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 + struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 + + struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 + struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 + + struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 + struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 + uint32_t flags; }; @@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c index f9e02ab67e..29d73b8622 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c @@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); } -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); - - if (0 == unaligned_loop) { - HVX_Vector * vec_in1 = (HVX_Vector *) src; - HVX_Vector * vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - hvx_vec_store_u((void 
*) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src, hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); } } + + diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index d2d5d23636..22876e6dba 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) } #endif -static inline HVX_Vector hvx_vec_splat_fp32(float i) { +static inline HVX_Vector hvx_vec_splat_fp32(float v) { union { - float f; - int32_t i; - } fp32 = { .f = i }; + float f; + uint32_t i; + } fp32 = { .f = v }; return Q6_V_vsplat_R(fp32.i); } +static inline HVX_Vector hvx_vec_splat_fp16(float v) { + union { + __fp16 f; + uint16_t i; + } fp16 = { .f = v }; + + return Q6_Vh_vsplat_R(fp16.i); +} + static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) addr); @@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest } } +// copy n fp32 elements : source is unaligned, destination unaligned +static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; + HVX_UVector * restrict vsrc = (HVX_UVector *) src; + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = 
n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned +static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + // bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { HVX_Vector * restrict vdst = (HVX_Vector *) dst; @@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3 return right_off <= chunk_size; } - - static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { HVX_VectorAlias u = { .v = v }; @@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { } static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { -#if __HTP_ARCH__ > 75 +#if __HVX_ARCH__ > 75 return Q6_Vsf_vfneg_Vsf(v); #else // neg by setting the fp32 sign bit HVX_Vector mask = Q6_V_vsplat_R(0x80000000); return Q6_V_vxor_VV(v, mask); -#endif // __HTP_ARCH__ > 75 +#endif // __HVX_ARCH__ > 75 } // ==================================================== @@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); } +static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_fp32(2.0f); + HVX_Vector one = hvx_vec_splat_fp32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = 
Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int step_of_1 = num_elems >> 5; int remaining = num_elems - step_of_1 * VLEN_FP32; @@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr } } +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } +} + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void 
hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } +} float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); void hvx_mul_f32(const uint8_t * restrict src0, @@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0, uint8_t * restrict dst, const int num_elems); void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale); void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index fb5508a560..24b3e90e4b 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_get_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_matmul_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx, uint32_t n_bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - int write_idx = (n_bufs == 4) ? 
3 : 2; + int write_idx = n_bufs - 1; // We had written to the output buffer, we'd also need to flush it rsp_bufs[0].fd = bufs[write_idx].fd; @@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_set_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_flash_attn_ext_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + // Setup Op context + struct htp_ops_context octx; + memset(&octx, 0, sizeof(octx)); + + octx.ctx = ctx; + octx.n_threads = ctx->n_threads; + + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.src3 = req->src3; + octx.src4 = req->src4; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + + int last_buf = 3; + + if (octx.src3.ne[0]) { + octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid + } + + if (octx.src4.ne[0]) { + octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + } + + octx.dst.data = (uint32_t) bufs[last_buf].ptr; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_flash_attn_ext(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + + struct dspqueue_buffer rsp_buf = bufs[last_buf]; + rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); +} + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; @@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { break; case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_rope_req(ctx, &req, bufs, n_bufs); break; + case HTP_OP_FLASH_ATTN_EXT: + if (!(n_bufs >= 4 && n_bufs <= 6)) { + FARF(ERROR, "Bad 
flash-attn-ext-req buffer list"); + continue; + } + proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_SET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad set-rows-req buffer list"); + continue; + } + proc_set_rows_req(ctx, &req, bufs); + break; + + case HTP_OP_GET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad get-rows-req buffer list"); + continue; + } + proc_get_rows_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index f14523d485..9bb39db9fc 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -26,14 +26,14 @@ #include "hvx-utils.h" #include "ops-utils.h" +#define MM_SPAD_SRC0_NROWS 16 +#define MM_SPAD_SRC1_NROWS 16 +#define MM_SPAD_DST_NROWS 2 + struct htp_matmul_type { const char * type; void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy); + void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); }; typedef struct { @@ -907,145 +907,174 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); } -#if 1 -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - if (0) { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; +static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - for (uint32_t i = 0; i < n; i++) { - rsum += (float)vx[i] * vy[i]; - } - *s = rsum; - return; - } + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; - const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y; + HVX_Vector rsum = Q6_V_vsplat_R(0); - uint32_t nv0 = n / 64; // num full fp16 hvx vectors - uint32_t nv1 = n % 64; // leftover elements - - // for some reason we need volatile here so that the compiler doesn't try anything funky - volatile HVX_Vector rsum = Q6_V_vsplat_R(0); - float r_sum_scalar = 0.0f; uint32_t i = 0; - for (i = 0; i < nv0; i++) { - HVX_VectorPair yp = vy[i]; - - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - //NOTE: need volatile here to prevent compiler optimization - // Seem compiler cannot guarantee read-after-write?? 
- volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - - HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - if (nv1) { - // HVX_VectorPair yp = vy[i]; + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - // HVX_Vector x = vx[i]; - // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - // if (nv1 >= 32) { - // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); - // nv1 -= 32; - // } - - // rsum = hvx_vec_qf32_reduce_sum(rsum); - - // if (nv1) { - // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); - // } - - //process the remainder using scalar loop - rsum = hvx_vec_qf32_reduce_sum(rsum); - const __fp16 * restrict sx = (const __fp16 * restrict) x; - const float * restrict sy = (const float * restrict) y; - - for (uint32_t i = nv0 * 64; i < n; i++) { - r_sum_scalar += (float) sx[i] * sy[i]; - } - - // hvx_vec_dump_fp16("X", x); - // hvx_vec_dump_fp16("Y", y); - // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum)); - // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum)); - } else { - rsum = hvx_vec_qf32_reduce_sum(rsum); + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar; - -# ifdef HTP_DEBUG - { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; - - for (uint32_t i = 0; i < n; i++) { - rsum += vx[i] * vy[i]; - } - - float diff = fabs(*s - rsum); - if (diff > 0.001) { - FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s); - // htp_dump_f16("x", vx, n); - // htp_dump_f32("y", vy, n); - } - } -# endif + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); } -#else -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - const uint32_t fk = 64; - const uint32_t nb = n / fk; - assert(n % fk == 0); - assert(nb % 4 == 0); +static void vec_dot_f16_f16_aa_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; + const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - const uint32_t x_blk_size = 2 * fk; // fp16 - const uint32_t y_blk_size = 4 * fk; // fp32 + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; - // Row sum (qf32) HVX_Vector rsum0 = Q6_V_vsplat_R(0); HVX_Vector rsum1 = Q6_V_vsplat_R(0); - HVX_Vector rsum2 = Q6_V_vsplat_R(0); - HVX_Vector rsum3 = 
Q6_V_vsplat_R(0); - for (uint32_t i = 0; i < nb; i += 4) { - HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size)); - HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size)); + uint32_t i = 0; - HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]); - HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]); - HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]); - HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]); + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = y[i]; + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1))); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2))); - rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3))); + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - // Reduce and convert into fp32 - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3); - HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2)); - hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum)); -} -#endif + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); -#define htp_matmul_preamble \ + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + } + + rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_UVector * restrict x = (const HVX_UVector *) vx; + const HVX_UVector * restrict y = (const HVX_UVector *) vy; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * 
restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +#define htp_matmul_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict src2 = &octx->src2; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -1056,6 +1085,11 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ const uint32_t ne0 = dst->ne[0]; \ const uint32_t ne1 = dst->ne[1]; \ const uint32_t ne2 = dst->ne[2]; \ @@ -1076,18 +1110,94 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -// q8x4 src1 tensor is already in VTCM spad -static void matmul(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +#define htp_matmul_preamble \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + +// *** matmul with 
support for 4d tensors and full broadcasting + +static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const uint32_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const uint32_t nr1 = ne1 * ne2 * ne3; + + // distribute the thread work across the inner or outer loop based on which one is larger + uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + // The number of elements in each chunk + const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + uint32_t current_chunk = ith; + + const uint32_t ith0 = current_chunk % nchunk0; + const uint32_t ith1 = current_chunk / nchunk0; + + const uint32_t ir0_start = dr0 * ith0; + const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); + + const uint32_t ir1_start = dr1 * ith1; + const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); + + // no work for this thread + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + // block-tiling attempt + const uint32_t blck_0 = 64; + const uint32_t blck_1 = 64; + + for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { + const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + + const uint32_t i1 = i11; + const uint32_t i2 = i12; + const uint32_t i3 = i13; + + const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { + const uint8_t * restrict src0_row = src0_base + ir0 * nb01; + mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// src1 tensor is already in VTCM spad +static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1104,9 +1214,10 @@ static void matmul(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const 
size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1124,11 +1235,11 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows @@ -1137,17 +1248,17 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1155,13 +1266,13 @@ static void matmul(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const int is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); } @@ -1176,17 +1287,7 @@ static void matmul(struct htp_matmul_type * mt, } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, 
uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1202,9 +1303,10 @@ static void matvec(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1226,24 +1328,24 @@ static void matvec(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1251,8 +1353,8 @@ static void matvec(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } @@ -1274,22 +1376,13 @@ struct mmid_row_mapping { uint32_t i2; }; -// q8x4 src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict ids, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct 
htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1340,7 +1433,7 @@ static void matmul_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1365,8 +1458,8 @@ static void matmul_id(struct htp_matmul_type * mt, } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1404,22 +1497,13 @@ static void matmul_id(struct htp_matmul_type * mt, dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4 src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict src2, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1464,7 +1548,7 @@ static void matvec_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1477,8 +1561,8 @@ static void matvec_id(struct htp_matmul_type * mt, mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1504,106 +1588,6 @@ static void matvec_id(struct htp_matmul_type * mt, dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// *** matmul in fp16 - -static void matmul_f16_f32(struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * 
dma_queue) { - htp_matmul_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const uint32_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const uint32_t nr1 = ne1 * ne2 * ne3; - - // distribute the thread work across the inner or outer loop based on which one is larger - uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - // The number of elements in each chunk - const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - uint32_t current_chunk = ith; - - const uint32_t ith0 = current_chunk % nchunk0; - const uint32_t ith1 = current_chunk / nchunk0; - - const uint32_t ir0_start = dr0 * ith0; - const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); - - const uint32_t ir1_start = dr1 * ith1; - const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); - - // broadcast factors - const uint32_t r2 = ne12 / ne02; - const uint32_t r3 = ne13 / ne03; - - // no work for this thread - if (ir0_start >= ir0_end || ir1_start >= ir1_end) { - return; - } - - // block-tiling attempt - const uint32_t blck_0 = 64; - const uint32_t blck_1 = 64; - - __attribute__((aligned(128))) float tmp[64]; - - for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { - for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = (ir1 / (ne12 * ne1)); - const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; - const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); - - // broadcast src0 into src1 - const uint32_t i03 = i13 / r3; - const uint32_t i02 = i12 / r2; - - const uint32_t i1 = i11; - const uint32_t i2 = i12; - const uint32_t i3 = i13; - - const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); - const uint8_t * restrict src1_col = - (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); - float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - - const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { - // Use nb01 stride for non-contiguous src0 support - const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col); - } - - hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); - } - } - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - // *** dynamic quant static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { @@ -1780,20 +1764,14 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u for (uint32_t i = 0; i < nb; i++) { #if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * 
qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #else #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" #endif @@ -1848,14 +1826,95 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_fp32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// TODO just a plain copy that should be done via the DMA during the Op setup +static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const 
uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); } -// ** matmul callbacks for worker_pool +static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} -static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} + +// ** matmul/matvec callbacks for worker_pool + +static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1863,11 +1922,10 @@ static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1875,11 +1933,10 @@ static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1887,11 +1944,10 @@ static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { 
+static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1899,11 +1955,10 @@ static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1911,11 +1966,10 @@ static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1923,14 +1977,49 @@ static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matvec_2d(&mt, octx, n, i); +} + +static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matmul_2d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f32"; + mt.vec_dot = vec_dot_f16_f32_uu; + + matmul_4d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_uu; + + matmul_4d(&mt, octx, n, i); } // ** matmul-id callbacks for worker_pool @@ -1943,8 +2032,7 @@ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, 
i); } static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1955,8 +2043,7 @@ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1967,8 +2054,7 @@ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1979,8 +2065,7 @@ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1991,8 +2076,7 @@ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -2003,18 +2087,17 @@ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } // ** main matmul entry point -int op_matmul(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; +static inline bool htp_is_permuted(const struct htp_tensor * t) { + return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; +} - htp_matmul_preamble; +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; const char * op_type; @@ -2038,9 +2121,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q4x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2048,8 +2131,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other 
tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2067,9 +2150,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q8x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q8x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q8x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2077,8 +2160,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2096,9 +2179,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "mxfp4x4x2-f32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2106,8 +2189,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2122,20 +2205,69 @@ int op_matmul(struct htp_ops_context * octx) { break; case HTP_TYPE_F16: - op_type = "f16-f32"; - quant_job_func = NULL; // htp_quantize_f32_f16; - matmul_job_func = htp_matmul_f16_f32; + { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - // 
For all tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256); + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - need_quant = false; + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + op_type = "f16-f16"; + quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_2d_f16_f16; + } else { + matmul_job_func = htp_matvec_2d_f16_f16; + } + + src1_row_size = f16_src1_row_size; // row size post quantization + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + op_type = "f16-f32"; + matmul_job_func = htp_matmul_4d_f16_f32; + } else { + op_type = "f16-f16"; + matmul_job_func = htp_matmul_4d_f16_f16; + } + + src1_row_size = nb11; // original row size in DDR + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + + need_quant = false; + } + } break; default: @@ -2166,6 +2298,9 @@ int op_matmul(struct htp_ops_context * octx) { octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; + if (need_quant) { // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); 
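[Reviewer note, not part of the patch] As a reference for the fastdiv initialization in the fallback branch above: matmul_4d decomposes the flattened dst row index and applies the src1/src0 broadcast ratios using the mm_div_* values, which stand in for plain integer divisions. A minimal scalar sketch of that index math, with a made-up helper name, purely for illustration:

// Scalar reference for the matmul_4d index math; the fastdiv values
// (mm_div_ne12_ne1, mm_div_ne1, mm_div_r2, mm_div_r3) replace these divisions.
#include <stdint.h>

static inline void mm4d_indices(uint32_t ir1,
                                uint32_t ne1,  uint32_t ne12, uint32_t ne13,
                                uint32_t ne02, uint32_t ne03,
                                uint32_t idx[5]) {
    const uint32_t i13 = ir1 / (ne12 * ne1);             // fastdiv(ir1, &mm_div_ne12_ne1)
    const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; // fastdiv(..., &mm_div_ne1)
    const uint32_t i11 = ir1 - i13 * ne12 * ne1 - i12 * ne1;
    const uint32_t i03 = i13 / (ne13 / ne03);            // fastdiv(i13, &mm_div_r3), r3 = ne13/ne03
    const uint32_t i02 = i12 / (ne12 / ne02);            // fastdiv(i12, &mm_div_r2), r2 = ne12/ne02
    idx[0] = i11; idx[1] = i12; idx[2] = i13; idx[3] = i02; idx[4] = i03;
}
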
@@ -2185,12 +2320,9 @@ int op_matmul(struct htp_ops_context * octx) { // ** main matmul-id entry point int op_matmul_id(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * ids = &octx->src2; - struct htp_tensor * dst = &octx->dst; + htp_matmul_tensors_preamble; - htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; const char * op_type; @@ -2228,8 +2360,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2257,8 +2389,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2286,8 +2418,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c new file mode 100644 index 0000000000..bdd64fcc8f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -0,0 +1,168 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define set_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = 
octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t ne1 = octx->dst.ne[1]; \ + \ + const uint32_t nr = ne01; + +static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + // copy row + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? 
*(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); +} + +static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_set_rows(struct htp_ops_context * octx) { + set_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->set_rows_div_ne12 = init_fastdiv_values(ne12); + octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + switch(octx->dst.type) { + case HTP_TYPE_F32: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + break; + case HTP_TYPE_F16: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 5bf0cbf792..80d249a22c 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, (const uint8_t *) mp_f32, slope); } else { - hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale); + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); if (mp_f32) { if (softmax_ctx->use_f16) { for (int i = 0; i < ne00; ++i) { @@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? 
(1.0 / sum) : 1; - hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum); + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); } } } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index bb7557b025..8ed1e5b661 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void scale_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float scale = 0.f; + float bias = 0.f; + memcpy(&scale, &op_params[0], sizeof(float)); + memcpy(&bias, &op_params[1], sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + } + + hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + } +} + static void rms_norm_htp_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src, const float mean = sum / row_elems; const float scale = 1.0f / sqrtf(mean + epsilon); - hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale); + hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); } } } @@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_RMS_NORM: rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SCALE: + scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { unary_op_func = unary_job_dispatcher_f32; op_type = "rmsnorm-f32"; break; + case HTP_OP_SCALE: + unary_op_func = unary_job_dispatcher_f32; + op_type = "scale-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index b2e651e749..1a7d8c9fd6 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf" device="HTP0" [ "$D" != "" ] && device="$D" -verbose="" -[ "$V" != "" ] && verbose="$V" +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" opmask= [ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" @@ -34,7 +40,7 @@ adb $adbserial shell " \ cd $basedir; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ + $ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --batch-size 128 -ngl 99 $@ \ + --batch-size 128 -ngl 99 $cli_opts $@ \ "
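
The op_matmul_id hunks above only rename the per-thread scratchpad row-count constants (HTP_SPAD_*_NROWS -> MM_SPAD_*_NROWS); the sizing scheme itself is unchanged. A minimal sketch of that scheme follows, with illustrative row sizes and a stand-in for htp_round_up — both are assumptions for illustration, not values or code taken from the patch.

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

/* Stand-in for htp_round_up: round x up to a multiple of m (assumed behavior). */
static uint32_t round_up_u32(uint32_t x, uint32_t m) {
    return ((x + m - 1) / m) * m;
}

int main(void) {
    /* Illustrative sizes only -- not values from the patch. */
    const uint32_t dst_row_size         = 4096 * sizeof(float); /* one fp32 output row      */
    const uint32_t src0_row_size_padded = 2304;                 /* padded quantized src0 row */
    const uint32_t mm_spad_dst_nrows    = 2;                    /* rows cached per thread    */
    const uint32_t mm_spad_src0_nrows   = 2;

    /* Same shape as the hunks above: N rows per thread, rounded up to 256 bytes
     * so HVX loads/stores and DMA transfers stay aligned. */
    printf("dst  spad per thread: %" PRIu32 " bytes\n",
           round_up_u32(mm_spad_dst_nrows * dst_row_size, 256));
    printf("src0 spad per thread: %" PRIu32 " bytes\n",
           round_up_u32(mm_spad_src0_nrows * src0_row_size_padded, 256));
    return 0;
}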
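
op_set_rows caps the number of worker jobs at the row count before computing the per-thread slice, which is the "optimize set-rows threading" change from the commit message: with only a few rows, fewer jobs are launched instead of handing most threads empty ranges. A standalone sketch of the same partitioning arithmetic (the helper and the example values are illustrative assumptions):

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

static uint32_t min_u32(uint32_t a, uint32_t b) { return a < b ? a : b; }

int main(void) {
    const uint32_t nr        = 3;  /* e.g. only a handful of rows to scatter */
    const uint32_t n_threads = 6;  /* worker pool size (illustrative)        */

    /* Same arithmetic as op_set_rows: never launch more jobs than rows,
     * then give each job a ceil(nr / n_jobs) slice, clamped at nr. */
    const uint32_t n_jobs = min_u32(nr, n_threads);
    const uint32_t dr     = (nr + n_jobs - 1) / n_jobs;

    for (uint32_t ith = 0; ith < n_jobs; ith++) {
        const uint32_t ir0 = dr * ith;
        const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
        printf("job %" PRIu32 ": rows [%" PRIu32 ", %" PRIu32 ")\n", ith, ir0, ir1);
    }
    return 0;
}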
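
Inside the per-thread set-rows kernels, the row-index tensor (src1) is broadcast over src0's outer dimensions via fastmodulo, with the divisor state precomputed once in op_set_rows through init_fastdiv_values. In plain integer arithmetic the index selection reduces to the sketch below; the real code avoids per-element hardware division on the DSP, but the result is the same.

#include <stdint.h>

/* Plain-C equivalent of the fastmodulo-based index selection in
 * set_rows_thread_f32_f32 / set_rows_thread_f16_f32 (sketch only). */
static uintptr_t src1_index_addr(uintptr_t src1_data,
                                 uint32_t i, uint32_t i02, uint32_t i03,
                                 uint32_t ne11, uint32_t ne12,
                                 uint32_t nb10, uint32_t nb11, uint32_t nb12) {
    const uint32_t i12 = i03 % ne12;  /* broadcast index tensor over src0 dim 3 */
    const uint32_t i11 = i02 % ne11;  /* broadcast index tensor over src0 dim 2 */
    const uint32_t i10 = i;           /* one index entry per src0 row           */
    return src1_data + i10 * nb10 + i11 * nb11 + i12 * nb12;
}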
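
The softmax and rms-norm hunks swap the hvx_scale_f32 call sites so the destination buffer comes first, matching the reworked hvx-utils helpers. As a reference for what the reordered call computes, here is a scalar stand-in; the semantics (dst[i] = src[i] * scale, with hvx_scale_offset_f32 additionally adding a bias) are an assumption inferred from the call sites, and the real helper is HVX-vectorized.

#include <stdint.h>

/* Scalar reference for the reordered call sites:
 *     hvx_scale_f32((uint8_t *) dst, (const uint8_t *) src, n, scale);
 * i.e. scale n fp32 elements from src into dst. */
static void scale_f32_ref(uint8_t * restrict dst_bytes,
                          const uint8_t * restrict src_bytes,
                          uint32_t n, float scale) {
    float       * dst = (float *) dst_bytes;
    const float * src = (const float *) src_bytes;
    for (uint32_t i = 0; i < n; i++) {
        dst[i] = src[i] * scale;
    }
}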