Hexagon add support for f16/f32 flash attention, scale, set-rows and improve f16/32 matmul (#18611)
* hexagon: improve fp16 matmul and add fp32/fp16 flash-attention * hexagon: add support for set-rows fp32 -> fp16 with i32/i64 row-idx * hexagon: add support for SCALE fp32 * hexagon: replace scalar fp32 -> fp16 copy with HVX * hexagon: optimize flash_atten_ext with aligned VTCM buffers and DMA - Implements double-buffered DMA prefetching for K, V, and Mask tensors. - Ensures K and V rows in VTCM are padded to 128 bytes to support aligned HVX operations. - Correctly synchronizes DMA transfers to prevent race conditions. - Uses `FLASH_ATTN_BLOCK_SIZE` of 128 for efficient chunking. * hexagon: use aligned mad_f16 * hexagon: flash_atten more aligned ops * hexagon: optimize scale_f32 hvx helpers * hexagon: unroll fa loops * hexagon: remove unused set-rows log * hexagon: flash_attn_ext add support for DMAing Q - Update `op_flash_attn_ext` to include Q row size in scratchpad allocation. - Pad Q row size to 128 bytes for alignment. - Implement DMA transfer for Q tensor in `flash_attn_ext_f16_thread`. - Update dot product computations to use VTCM-buffered Q data. * hexagon: fix handling of NANs hvx dotproducts * hexagon: cleanup spad allocation in flash-atten * hexagon: improve fp16/fp32 matmul - Introduced `vec_dot_f16_f16` and `vec_dot_f16_f16_rx2` kernels using efficient HVX dot product intrinsics. - Added `quantize_fp32_f16` to copy/convert weights from DDR to VTCM - Updated `op_matmul` to use the optimized path when VTCM capacity allows and broadcasting requirements are compatible. - Implemented fallback logic to the original implementation for complex broadcasting scenarios. * hexagon: fix HVX_ARCH check * hexagon: matmul cleanup and fp16 fixes Use aligned vec_dot_f16 for 2d matmuls and unaligned version for 4d. * hexagon: fix fp16 x fp16 matmuls and some minor refactoring * hexagon: add support for GET_ROWS f32 -> f32 Also optimize SET_ROWS threading a bit when we have just a few rows to process. * hexagon: optimize set-rows threading * hexagon: update adb/run-bench.sh to properly support experimental and verbose options * hexagon: flash_atten use aligned vectors for dot products
This commit is contained in:
parent
ccbc84a537
commit
95ea9e0861
|
|
@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const struct ggml_tensor * src3 = op->src[3];
|
||||
const struct ggml_tensor * src4 = op->src[4];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
// Check for F16 support only as requested
|
||||
if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src3 && src3->type != GGML_TYPE_F16) { // mask
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src4 && src4->type != GGML_TYPE_F32) { // sinks
|
||||
return false;
|
||||
}
|
||||
|
||||
// For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
|
||||
// but the op implementation writes to F16 or F32.
|
||||
// Let's assume dst can be F32 or F16.
|
||||
if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return opt_experimental;
|
||||
}
|
||||
|
||||
static bool hex_supported_src0_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
|
@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-cont tensors
|
||||
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
|||
return false; // typically the lm-head which would be too large for VTCM
|
||||
}
|
||||
|
||||
// if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
|
||||
if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
|||
}
|
||||
break;
|
||||
|
||||
case GGML_TYPE_F16:
|
||||
if (!opt_experimental) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-cont tensors
|
||||
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * src1 = op->src[1]; // indices
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * src1 = op->src[1]; // indices
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const int32_t * op_params = &op->op_params[0];
|
||||
|
||||
|
|
@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t
|
|||
d->offset = (uint8_t *) t->data - buf->base;
|
||||
d->size = ggml_nbytes(t);
|
||||
|
||||
if (!d->size) {
|
||||
// Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
|
||||
d->size = 64;
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
|
||||
// Flush CPU
|
||||
|
|
@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_GET_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
template <bool _is_src0_constant>
|
||||
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
|
|
@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer *
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_SET_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
|
||||
|
|
@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
|||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SCALE:
|
||||
req->op = HTP_OP_SCALE;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
|
|
@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs
|
|||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_FLASH_ATTN_EXT;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
|
||||
return sess->name.c_str();
|
||||
|
|
@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
|
|
@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
|||
ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
|
|
@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
break;
|
||||
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
|
|
@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
|||
supp = ggml_hexagon_supported_rope(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SET_ROWS:
|
||||
supp = ggml_hexagon_supported_set_rows(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_GET_ROWS:
|
||||
supp = ggml_hexagon_supported_get_rows(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED
|
|||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
flash-attn-ops.c
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
|
|
|||
|
|
@ -0,0 +1,566 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
|
||||
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * v (float)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S = hvx_vec_splat_fp16(s);
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
// Multiply x * s -> pair of F32 vectors
|
||||
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
|
||||
ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
|
||||
ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
|
||||
|
||||
HVX_Vector xs = Q6_V_lo_W(xs_p);
|
||||
i = 2 * i; // index for ptr_y
|
||||
|
||||
if (nloe >= 32) {
|
||||
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define FLASH_ATTN_BLOCK_SIZE 128
|
||||
|
||||
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t neq0 = q->ne[0];
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
const uint32_t neq2 = q->ne[2];
|
||||
const uint32_t neq3 = q->ne[3];
|
||||
|
||||
const uint32_t nek0 = k->ne[0];
|
||||
const uint32_t nek1 = k->ne[1];
|
||||
const uint32_t nek2 = k->ne[2];
|
||||
const uint32_t nek3 = k->ne[3];
|
||||
|
||||
const uint32_t nev0 = v->ne[0];
|
||||
const uint32_t nev1 = v->ne[1];
|
||||
const uint32_t nev2 = v->ne[2];
|
||||
const uint32_t nev3 = v->ne[3];
|
||||
|
||||
const uint32_t nbq1 = q->nb[1];
|
||||
const uint32_t nbq2 = q->nb[2];
|
||||
const uint32_t nbq3 = q->nb[3];
|
||||
|
||||
const uint32_t nbk1 = k->nb[1];
|
||||
const uint32_t nbk2 = k->nb[2];
|
||||
const uint32_t nbk3 = k->nb[3];
|
||||
|
||||
const uint32_t nbv1 = v->nb[1];
|
||||
const uint32_t nbv2 = v->nb[2];
|
||||
const uint32_t nbv3 = v->nb[3];
|
||||
|
||||
const uint32_t ne1 = dst->ne[1];
|
||||
const uint32_t ne2 = dst->ne[2];
|
||||
const uint32_t ne3 = dst->ne[3];
|
||||
|
||||
const uint32_t nb1 = dst->nb[1];
|
||||
const uint32_t nb2 = dst->nb[2];
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
// total rows in q
|
||||
const uint32_t nr = neq1*neq2*neq3;
|
||||
|
||||
const uint32_t dr = (nr + nth - 1) / nth;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (ir0 >= ir1) return;
|
||||
|
||||
dma_queue * dma = octx->ctx->dma[ith];
|
||||
|
||||
const uint32_t DK = nek0;
|
||||
const uint32_t DV = nev0;
|
||||
|
||||
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
||||
const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
|
||||
|
||||
const size_t size_k_row = DK * sizeof(__fp16);
|
||||
const size_t size_v_row = DV * sizeof(__fp16);
|
||||
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
|
||||
|
||||
const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
|
||||
|
||||
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
||||
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
||||
uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
|
||||
uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
|
||||
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
|
||||
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
||||
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
|
||||
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
|
||||
|
||||
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
|
||||
|
||||
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
|
||||
|
||||
// Fetch Q row
|
||||
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
|
||||
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
|
||||
// Clear accumulator
|
||||
float * VKQ32 = (float *) spad_a;
|
||||
memset(VKQ32, 0, DV * sizeof(float));
|
||||
|
||||
const __fp16 * mp_base = NULL;
|
||||
if (mask) {
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
|
||||
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
|
||||
}
|
||||
|
||||
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
// Prefetch first two blocks
|
||||
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
|
||||
// Mask is 1D contiguous for this row
|
||||
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
|
||||
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
// Wait for DMA
|
||||
uint8_t * k_base = dma_queue_pop(dma).dst; // K
|
||||
uint8_t * v_base = dma_queue_pop(dma).dst; // V
|
||||
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
|
||||
|
||||
// Inner loop processing the block from VTCM
|
||||
uint32_t ic = 0;
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
if (q->type == HTP_TYPE_F32) {
|
||||
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
|
||||
}
|
||||
}
|
||||
|
||||
HVX_Vector scores = *(HVX_Vector *) scores_arr;
|
||||
|
||||
// 2. Softcap
|
||||
if (logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_fp32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 3. Mask
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
|
||||
|
||||
HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
|
||||
|
||||
HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
|
||||
|
||||
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
|
||||
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
|
||||
float m_block = hvx_vec_get_fp32(v_max);
|
||||
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
|
||||
float ms = expf(M_old - M_new);
|
||||
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
S = S * ms;
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
|
||||
float p_sum = hvx_vec_get_fp32(p_sum_vec);
|
||||
S += p_sum;
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
||||
|
||||
if (q->type == HTP_TYPE_F32) {
|
||||
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s_val = logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
const float m_val = m_base[ic];
|
||||
s_val += slope * m_val;
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
ms = expf(Mold - M);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s_val - M);
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
if (ib + 2 < n_blocks) {
|
||||
const uint32_t next_ib = ib + 2;
|
||||
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
|
||||
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sinks
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
|
||||
|
||||
// Store result
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// dst is permuted
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
||||
|
||||
if (dst->type == HTP_TYPE_F32) {
|
||||
hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
} else if (dst->type == HTP_TYPE_F16) {
|
||||
hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = data;
|
||||
flash_attn_ext_f16_thread(octx, i, n);
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Check support
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
|
||||
k->type != HTP_TYPE_F16 ||
|
||||
v->type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
|
||||
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
|
||||
if (mask) {
|
||||
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
|
||||
size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_q_block = size_q_row_padded * 1; // single row for now
|
||||
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
||||
|
||||
octx->src0_spad.size_per_thread = size_q_block * 1;
|
||||
octx->src1_spad.size_per_thread = size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
|
||||
octx->dst_spad.size_per_thread = size_vkq_acc;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
|
||||
octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
|
||||
octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
||||
|
||||
size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (octx->ctx->vtcm_size < total_spad) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
get_rows_preamble;
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = octx->src1_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
|
||||
const uint32_t rem = i - i12 * ne11 * ne10;
|
||||
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
|
||||
if (i01 >= ne01) {
|
||||
// invalid index, skip for now to avoid crash
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
get_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
@ -11,11 +11,6 @@
|
|||
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
|
||||
// FIXME: move these into matmul-ops
|
||||
#define HTP_SPAD_SRC0_NROWS 16
|
||||
#define HTP_SPAD_SRC1_NROWS 16
|
||||
#define HTP_SPAD_DST_NROWS 2
|
||||
|
||||
// Main context for htp DSP backend
|
||||
struct htp_context {
|
||||
dspqueue_t queue;
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ enum htp_data_type {
|
|||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
|
@ -57,6 +59,10 @@ enum htp_op {
|
|||
HTP_OP_SOFTMAX = 11,
|
||||
HTP_OP_ADD_ID = 12,
|
||||
HTP_OP_ROPE = 13,
|
||||
HTP_OP_FLASH_ATTN_EXT = 14,
|
||||
HTP_OP_SET_ROWS = 15,
|
||||
HTP_OP_SCALE = 16,
|
||||
HTP_OP_GET_ROWS = 17,
|
||||
INVALID
|
||||
};
|
||||
|
||||
|
|
@ -137,6 +143,8 @@ struct htp_general_req {
|
|||
struct htp_tensor src0; // Input0 tensor
|
||||
struct htp_tensor src1; // Input1 tensor
|
||||
struct htp_tensor src2; // Input2 tensor
|
||||
struct htp_tensor src3; // Input3 tensor
|
||||
struct htp_tensor src4; // Input4 tensor
|
||||
struct htp_tensor dst; // Output tensor
|
||||
|
||||
// should be multiple of 64 bytes (cacheline)
|
||||
|
|
@ -152,6 +160,6 @@ struct htp_general_rsp {
|
|||
};
|
||||
|
||||
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
|
||||
#define HTP_MAX_PACKET_BUFFERS 4
|
||||
#define HTP_MAX_PACKET_BUFFERS 8
|
||||
|
||||
#endif /* HTP_MSG_H */
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
struct htp_spad {
|
||||
uint8_t * data;
|
||||
size_t stride;
|
||||
size_t size;
|
||||
size_t size_per_thread;
|
||||
};
|
||||
|
|
@ -26,11 +27,14 @@ struct htp_ops_context {
|
|||
struct htp_tensor src0;
|
||||
struct htp_tensor src1;
|
||||
struct htp_tensor src2;
|
||||
struct htp_tensor src3;
|
||||
struct htp_tensor src4;
|
||||
struct htp_tensor dst;
|
||||
|
||||
struct htp_spad src0_spad;
|
||||
struct htp_spad src1_spad;
|
||||
struct htp_spad src2_spad;
|
||||
struct htp_spad src3_spad;
|
||||
struct htp_spad dst_spad;
|
||||
|
||||
worker_pool_context_t * wpool; // worker pool
|
||||
|
|
@ -49,6 +53,27 @@ struct htp_ops_context {
|
|||
struct fastdiv_values src1_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src3_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src3_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src3_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
|
||||
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
|
||||
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
|
||||
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
|
|
@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx);
|
|||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
int op_rope(struct htp_ops_context * octx);
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
|
|
|||
|
|
@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
|
|||
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
|
||||
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
|
@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
|
|||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
|||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
||||
static inline HVX_Vector hvx_vec_splat_fp32(float v) {
|
||||
union {
|
||||
float f;
|
||||
int32_t i;
|
||||
} fp32 = { .f = i };
|
||||
float f;
|
||||
uint32_t i;
|
||||
} fp32 = { .f = v };
|
||||
|
||||
return Q6_V_vsplat_R(fp32.i);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_fp16(float v) {
|
||||
union {
|
||||
__fp16 f;
|
||||
uint16_t i;
|
||||
} fp16 = { .f = v };
|
||||
|
||||
return Q6_Vh_vsplat_R(fp16.i);
|
||||
}
|
||||
|
||||
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
|
||||
// Rotate as needed.
|
||||
v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
|
||||
|
|
@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest
|
|||
}
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is unaligned, destination unaligned
|
||||
static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
|
||||
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 32;
|
||||
uint32_t nloe = n % 32;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
|
||||
static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
|
||||
}
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
|
||||
static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
|
||||
HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
|
||||
}
|
||||
}
|
||||
|
||||
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
|
||||
static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
|
||||
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
|
||||
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
|
||||
}
|
||||
}
|
||||
|
||||
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
|
||||
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
|
@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
|
|||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
HVX_VectorAlias u = { .v = v };
|
||||
|
||||
|
|
@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
|
|||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
|
||||
#if __HTP_ARCH__ > 75
|
||||
#if __HVX_ARCH__ > 75
|
||||
return Q6_Vsf_vfneg_Vsf(v);
|
||||
#else
|
||||
// neg by setting the fp32 sign bit
|
||||
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
||||
return Q6_V_vxor_VV(v, mask);
|
||||
#endif // __HTP_ARCH__ > 75
|
||||
#endif // __HVX_ARCH__ > 75
|
||||
}
|
||||
|
||||
// ====================================================
|
||||
|
|
@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
|
|||
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
|
||||
// tanh(x) = 2 * sigmoid(2x) - 1
|
||||
HVX_Vector two = hvx_vec_splat_fp32(2.0f);
|
||||
HVX_Vector one = hvx_vec_splat_fp32(1.0f);
|
||||
HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
|
||||
|
||||
static const float kMinExp = -87.f; // 0
|
||||
static const float kMaxExp = 87.f; // 1
|
||||
HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
|
||||
HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
|
||||
|
||||
HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
|
||||
|
||||
HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
|
||||
res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
|
||||
return Q6_Vsf_equals_Vqf32(res);
|
||||
}
|
||||
|
||||
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
int step_of_1 = num_elems >> 5;
|
||||
int remaining = num_elems - step_of_1 * VLEN_FP32;
|
||||
|
|
@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr
|
|||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
|
||||
HVX_Vector * vsrc = (HVX_Vector *) src;
|
||||
HVX_Vector * vdst = (HVX_Vector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
|
||||
HVX_UVector * vsrc = (HVX_UVector *) src;
|
||||
HVX_UVector * vdst = (HVX_UVector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
|
||||
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
|
||||
hvx_scale_f32_aa(dst, src, n, scale);
|
||||
} else {
|
||||
hvx_scale_f32_uu(dst, src, n, scale);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
HVX_Vector vo = hvx_vec_splat_fp32(offset);
|
||||
|
||||
HVX_Vector * vsrc = (HVX_Vector *) src;
|
||||
HVX_Vector * vdst = (HVX_Vector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
int nvec = n / VLEN_FP32;
|
||||
int nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector vs = hvx_vec_splat_fp32(scale);
|
||||
HVX_Vector vo = hvx_vec_splat_fp32(offset);
|
||||
|
||||
HVX_UVector * vsrc = (HVX_UVector *) src;
|
||||
HVX_UVector * vdst = (HVX_UVector *) dst;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
|
||||
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
|
||||
hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
|
||||
} else {
|
||||
hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
|
||||
void hvx_mul_f32(const uint8_t * restrict src0,
|
||||
|
|
@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0,
|
|||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
|
||||
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
|
||||
|
|
|
|||
|
|
@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_get_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
|
|
@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx,
|
|||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
|
||||
int write_idx = (n_bufs == 4) ? 3 : 2;
|
||||
int write_idx = n_bufs - 1;
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[write_idx].fd;
|
||||
|
|
@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx,
|
|||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_set_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_flash_attn_ext_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx;
|
||||
memset(&octx, 0, sizeof(octx));
|
||||
|
||||
octx.ctx = ctx;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.src2 = req->src2;
|
||||
octx.src3 = req->src3;
|
||||
octx.src4 = req->src4;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||
|
||||
int last_buf = 3;
|
||||
|
||||
if (octx.src3.ne[0]) {
|
||||
octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
|
||||
}
|
||||
|
||||
if (octx.src4.ne[0]) {
|
||||
octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
|
||||
}
|
||||
|
||||
octx.dst.data = (uint32_t) bufs[last_buf].ptr;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_flash_attn_ext(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
|
||||
struct dspqueue_buffer rsp_buf = bufs[last_buf];
|
||||
rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
|
||||
}
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
struct htp_context * ctx = (struct htp_context *) context;
|
||||
|
||||
|
|
@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
break;
|
||||
|
||||
case HTP_OP_RMS_NORM:
|
||||
case HTP_OP_SCALE:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
|
|
@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
|||
proc_rope_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_FLASH_ATTN_EXT:
|
||||
if (!(n_bufs >= 4 && n_bufs <= 6)) {
|
||||
FARF(ERROR, "Bad flash-attn-ext-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SET_ROWS:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad set-rows-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_set_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_GET_ROWS:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad get-rows-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_get_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,168 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t ne1 = octx->dst.ne[1]; \
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
// ignore invalid indices
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
// copy row
|
||||
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
// ignore invalid indices
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
set_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
|
||||
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
|
||||
break;
|
||||
case HTP_TYPE_F16:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
|
||||
break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
|
|||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
|
||||
if (mp_f32) {
|
||||
if (softmax_ctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
|
|
@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
|
|||
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
static void scale_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
float scale = 0.f;
|
||||
float bias = 0.f;
|
||||
memcpy(&scale, &op_params[0], sizeof(float));
|
||||
memcpy(&bias, &op_params[1], sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
|
||||
}
|
||||
|
||||
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
|
|
@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src,
|
|||
const float mean = sum / row_elems;
|
||||
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||
|
||||
hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
|
||||
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
|||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
|
|
@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
|
|
|||
|
|
@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
|||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
verbose=""
|
||||
[ "$V" != "" ] && verbose="$V"
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
profile=
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
|
||||
|
||||
opmask=
|
||||
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||
|
|
@ -34,7 +40,7 @@ adb $adbserial shell " \
|
|||
cd $basedir; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
$ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--batch-size 128 -ngl 99 $@ \
|
||||
--batch-size 128 -ngl 99 $cli_opts $@ \
|
||||
"
|
||||
|
|
|
|||
Loading…
Reference in New Issue