hexagon: add fp16 support for binary ops: add,sub,mul,div (#20139)

* hexagon: add fp16 support for binary ops: add,sub,mul,div

* hexagon: fix test-backend-ops failures for fp16 binary ops on older arches (<v79)

* hexagon: decide on n_threads (aka n_jobs) early to avoid overallocating scratchpad

* snapdragon: fix readme link

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
YardenTal44 2026-03-06 04:29:13 +02:00 committed by GitHub
parent a0ed91a442
commit 2b10b62677
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 639 additions and 380 deletions

View File

@ -287,7 +287,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
| [Hexagon [In Progress]](docs/backend/snapdragon/README.md) | Snapdragon |
| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR |
## Obtaining and quantizing models

View File

@ -1865,15 +1865,26 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (dst->type != GGML_TYPE_F32) {
if (src0->type == GGML_TYPE_F32) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (dst->type != GGML_TYPE_F32) {
return false;
}
}
else if (src0->type == GGML_TYPE_F16) {
if (src1->type != GGML_TYPE_F16) {
return false;
}
if (dst->type != GGML_TYPE_F16) {
return false;
}
}
else {
return false;
}
if (!ggml_are_same_shape(src0, dst)) {
return false;
}

View File

@ -693,8 +693,8 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
size_t src0_row_size = src0->nb[1];
size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
@ -748,13 +748,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
uint32_t n_jobs = MIN(n_threads, src0_nrows);
// Prepare context
struct htp_act_context actx;
actx.octx = octx;
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
actx.src0_row_size = src0_row_size;
actx.src1_row_size = src1_row_size;
@ -794,7 +792,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
actx.data_src1 = data_src1;
actx.data_dst = (uint8_t *) dst->data;
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);
return HTP_STATUS_OK;
}

View File

@ -241,6 +241,9 @@ int op_argsort(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
const uint32_t n_threads = MIN(total_rows, octx->n_threads);
// Allocate scratchpad
// We need 1 row of float + 1 row of int32 per thread.
uint32_t ne00 = octx->src0.ne[0];
@ -251,7 +254,7 @@ int op_argsort(struct htp_ops_context * octx) {
// Make sure we round up to 256 for alignment requirements
spad_per_thread = hex_round_up(spad_per_thread, 256);
size_t total_spad_size = spad_per_thread * octx->n_threads;
size_t total_spad_size = spad_per_thread * n_threads;
if (octx->ctx->vtcm_size < total_spad_size) {
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
@ -267,15 +270,12 @@ int op_argsort(struct htp_ops_context * octx) {
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
octx->src0.data, octx->dst.data);
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
struct htp_argsort_context actx;
actx.octx = octx;
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;
// Run jobs
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);
return HTP_STATUS_OK;
}

View File

@ -95,43 +95,87 @@ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_
}
// Macro for scalar op switch
#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
default: break; \
#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
if(TYPE == HTP_TYPE_F32) { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
default: break; \
} \
} \
else { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
default: break; \
} \
}
// Macro for vector op switch (All Aligned)
#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
default: break; \
#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
if(TYPE == HTP_TYPE_F32) { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
default: break; \
} \
} \
else { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
default: break; \
} \
}
// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
default: break; \
#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
if(TYPE == HTP_TYPE_F32) { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
default: break; \
} \
} \
else { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
default: break; \
} \
}
// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
default: break; \
#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
if(TYPE == HTP_TYPE_F32) { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
default: break; \
} \
} \
else { \
switch (octx->op) { \
case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
default: break; \
} \
}
// 1. Scalar src1 (ne10 == 1)
@ -140,6 +184,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
@ -170,7 +216,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
@ -199,13 +245,12 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t r = 0; r < current_block_size; r++) {
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
float val = *(float *)src1_ptr;
COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
src1_ptr += s1_stride;
COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
}
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@ -216,7 +261,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
@ -230,6 +275,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
@ -268,8 +315,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
@ -284,7 +331,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint32_t i03, i02, i01, rem;
@ -293,7 +340,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@ -310,8 +357,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
@ -326,6 +373,8 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
@ -359,7 +408,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
@ -373,7 +422,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint32_t i03, i02, i01, rem;
@ -382,7 +431,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@ -392,7 +441,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
@ -406,6 +455,8 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
const uint32_t src0_type = octx->src0.type;
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
@ -435,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
@ -462,11 +513,11 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
// Read src1 from DDR (unaligned)
COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@ -476,7 +527,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
@ -490,6 +541,9 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
const uint32_t src0_type = octx->src0.type;
const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
@ -519,7 +573,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
@ -549,12 +603,12 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
for (uint32_t c = 0; c < ne00; c += ne10) {
uint32_t len = MIN(ne10, ne00 - c);
// Use UUU for speed and simplicity
COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
}
}
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@ -564,7 +618,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
@ -672,18 +726,20 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
dma_queue_flush(q);
}
static int execute_op_binary_f32(struct htp_ops_context * octx) {
static int execute_op_binary(struct htp_ops_context * octx) {
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
const uint32_t n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
// Use packed row sizes for VTCM allocation
const size_t src0_row_size = src0->ne[0] * sizeof(float);
const size_t src1_row_size = src1->ne[0] * sizeof(float);
const size_t dst_row_size = dst->ne[0] * sizeof(float);
const uint32_t src0_type = octx->src0.type;
const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
const size_t src0_row_size = src0->ne[0] * elem_size;
const size_t src1_row_size = src1->ne[0] * elem_size;
const size_t dst_row_size = dst->ne[0] * elem_size;
// Align to VLEN
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
@ -694,7 +750,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
bool is_scalar = !is_add_id && (src1->ne[0] == 1);
// Determine which kernel we will use to alloc memory and dispatch
bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
(src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
(src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
(src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
@ -726,7 +782,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
}
if (rows_per_buffer < 1) {
FARF(ERROR, "binary-f32: VTCM too small\n");
FARF(ERROR, "binary: VTCM too small\n");
return HTP_STATUS_VTCM_TOO_SMALL;
}
@ -761,16 +817,14 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
uint32_t n_jobs = MIN(n_threads, src0_nrows);
dma_queue * q = octx->ctx->dma[0];
if (is_row_bcast) {
dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
}
struct htp_binary_context bctx;
bctx.octx = octx;
bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bctx.block_max = rows_per_buffer;
bctx.src0_row_size_aligned = src0_row_size_aligned;
bctx.src1_row_size_aligned = src1_row_size_aligned;
@ -814,14 +868,24 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
dma_queue_pop(q);
}
worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
return HTP_STATUS_OK;
}
int op_binary(struct htp_ops_context * octx) {
if (octx->src0.type == HTP_TYPE_F32) {
return execute_op_binary_f32(octx);
// Does not support permutations of src1
const struct htp_tensor * src1 = &octx->src1;
if (src1->nb[1] < src1->nb[0]) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t src0_type = octx->src0.type;
if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
return execute_op_binary(octx);
}
return HTP_STATUS_NO_SUPPORT;
}

View File

@ -202,6 +202,8 @@ static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
int op_cpy(struct htp_ops_context * octx) {
cpy_preamble;
const uint32_t n_threads = MIN(nr, octx->n_threads);
struct htp_copy_context ct;
ct.octx = octx;
@ -227,8 +229,7 @@ int op_cpy(struct htp_ops_context * octx) {
const bool transposed = (nb00 > nb01) || (nb0 > nb1);
const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
if (sametype && sameshape) {
ct.copy = cpy_thread_sametype_sameshape;
@ -245,7 +246,7 @@ int op_cpy(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);
return HTP_STATUS_OK;
}

View File

@ -82,6 +82,8 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
int op_get_rows(struct htp_ops_context * octx) {
get_rows_preamble;
const uint32_t n_threads = MIN(nr, octx->n_threads);
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
@ -103,9 +105,8 @@ int op_get_rows(struct htp_ops_context * octx) {
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);
return HTP_STATUS_OK;
}

View File

@ -13,14 +13,15 @@
// Binary operations (add, mul, sub)
//
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
#define UNUSED(x) (void)(x)
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t epv = 128 / (elem_size); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
@ -32,62 +33,74 @@
} \
if (nloe) { \
HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
vec_store((void *) &vdst[i], nloe * (elem_size), v); \
} \
} while(0)
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b)
#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b)
#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b)
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16)
// Dispatcher logic
#define HVX_BINARY_DISPATCHER(OP_NAME) \
@ -115,6 +128,10 @@ HVX_BINARY_DISPATCHER(hvx_add_f32)
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
HVX_BINARY_DISPATCHER(hvx_add_f16)
HVX_BINARY_DISPATCHER(hvx_sub_f16)
HVX_BINARY_DISPATCHER(hvx_mul_f16)
// Mul-Mul Optimized
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
@ -136,26 +153,25 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re
_Pragma("unroll(4)")
for (; i < nvec; i++) {
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
}
if (nloe) {
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]);
hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
}
}
// Scalar Operations
#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t epv = 128 / (elem_size); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
@ -169,138 +185,88 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re
if (nloe) { \
HVX_Vector v = vsrc[i]; \
v = scalar_op_macro(v); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
vec_store((void *) &vdst[i], nloe * (elem_size), v); \
} \
} while(0)
#define HVX_OP_ADD_SCALAR(v) \
#define HVX_OP_ADD_SCALAR_F32(v) \
({ \
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
HVX_Vector out = HVX_OP_ADD(v, val_vec); \
HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \
Q6_V_vmux_QVV(pred_inf, inf, out); \
})
#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec)
#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec)
// Add Scalar Variants
#define HVX_OP_ADD_SCALAR_F16(v) \
({ \
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \
HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \
Q6_V_vmux_QVV(pred_inf, inf, out); \
})
static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec)
#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec)
// Scalar Variants
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \
static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
const HVX_Vector val_vec = SPLAT_MACRO(val); \
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src % 128 == 0); \
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
const HVX_Vector val_vec = SPLAT_MACRO(val); \
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
assert((uintptr_t) dst % 128 == 0); \
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
const HVX_Vector val_vec = SPLAT_MACRO(val); \
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
assert((uintptr_t) src % 128 == 0); \
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
const HVX_Vector val_vec = SPLAT_MACRO(val); \
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float)
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float)
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float)
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16)
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16)
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16)
// Dispatcher logic
#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
OP_NAME##_aa(dst, src, val, num_elems); \
} else if (hex_is_aligned((void *) dst, 128)) { \
OP_NAME##_au(dst, src, val, num_elems); \
} else if (hex_is_aligned((void *) src, 128)) { \
OP_NAME##_ua(dst, src, val, num_elems); \
} else { \
OP_NAME##_uu(dst, src, val, num_elems); \
} \
}
static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
}
HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float)
HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float)
HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float)
static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}
static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
static const float kInf = INFINITY;
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
}
// Sub Scalar Variants
static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}
static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
}
static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}
static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
}
// Mul Scalar Variants
static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}
static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
}
static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}
static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
}
static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_add_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_add_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_add_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_add_scalar_f32_uu(dst, src, val, num_elems);
}
}
static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_mul_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
}
}
static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) dst, 128)) {
hvx_sub_scalar_f32_au(dst, src, val, num_elems);
} else if (hex_is_aligned((void *) src, 128)) {
hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
} else {
hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
}
}
HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16)
HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16)
HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16)
// MIN Scalar variants
@ -310,24 +276,24 @@ static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t *
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
@ -357,27 +323,27 @@ static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) dst % 128 == 0);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) src % 128 == 0);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
@ -396,7 +362,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
// Square
//
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
@ -410,10 +376,10 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \
} \
if (nloe) { \
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
@ -421,21 +387,21 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
@ -454,17 +420,24 @@ static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict
}
}
#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
#undef HVX_OP_ADD_F32
#undef HVX_OP_SUB_F32
#undef HVX_OP_MUL_F32
#undef HVX_OP_ADD_F16
#undef HVX_OP_SUB_F16
#undef HVX_OP_MUL_F16
#undef hvx_arith_loop_body
#undef HVX_OP_ADD_SCALAR
#undef HVX_OP_SUB_SCALAR
#undef HVX_OP_MUL_SCALAR
#undef HVX_OP_ADD_SCALAR_F32
#undef HVX_OP_SUB_SCALAR_F32
#undef HVX_OP_MUL_SCALAR_F32
#undef HVX_OP_ADD_SCALAR_F16
#undef HVX_OP_SUB_SCALAR_F16
#undef HVX_OP_MUL_SCALAR_F16
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER
#undef UNUSED
#endif // HVX_ARITH_H

View File

@ -189,4 +189,52 @@ static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vect
#endif
#if __HVX_ARCH__ < 79
static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
{
const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16
HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
}
static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
{
const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16
HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
}
static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
{
return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
}
#else
static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
{
return Q6_Vhf_vadd_VhfVhf(a, b);
}
static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
{
return Q6_Vhf_vsub_VhfVhf(a, b);
}
static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
{
return Q6_Vhf_vmpy_VhfVhf(a, b);
}
#endif // __HVX_ARCH__ < 79
#endif /* HVX_BASE_H */

View File

@ -15,11 +15,144 @@
#include "hvx-arith.h"
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.
static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) {
#if __HVX_ARCH__ < 79
HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32));
HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32));
#else
HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32);
HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32);
#endif
HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const);
HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const);
#if __HVX_ARCH__ < 79
HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
#else
HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
#endif
return res;
}
#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
\
const uint32_t nvec = n / VLEN_FP16; \
const uint32_t nloe = n % VLEN_FP16; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
} \
} while(0)
static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src % 128 == 0);
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
assert((uintptr_t) dst % 128 == 0);
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
assert((uintptr_t) src % 128 == 0);
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32.
static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
#if __HVX_ARCH__ < 79
// Convert first input to fp32
HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0
HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32));
HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32));
// Convert second input to fp32
HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0
HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32));
HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32));
#else
// Convert first input to fp32
HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0
HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32);
HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32);
// Convert second input to fp32
HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0
HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32);
HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32);
#endif
// Inverse second input in fp32
HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask);
HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask);
// Multiply first input by inverse of second, in fp32
HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0);
HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1);
// Convert back to fp16
#if __HVX_ARCH__ < 79
HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
#else
HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
#endif
return recip;
}
#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
\
const uint32_t nvec = n / VLEN_FP16; \
const uint32_t nloe = n % VLEN_FP16; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
} \
} while(0)
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
@ -36,81 +169,83 @@
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
} \
} while(0)
// 3-letter suffix variants
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src1 % 128 == 0); \
OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \
} \
// Dispatcher logic
#define HVX_DIV_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
if (hex_is_aligned((void *) dst, 128)) { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
else OP_NAME##_aau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
else OP_NAME##_auu(dst, src0, src1, num_elems); \
} \
} else { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
else OP_NAME##_uau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
} \
} \
}
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body)
DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body)
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
}
HVX_DIV_DISPATCHER(hvx_div_f32)
HVX_DIV_DISPATCHER(hvx_div_f16)
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
else hvx_div_f32_aau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
else hvx_div_f32_auu(dst, src0, src1, num_elems);
}
} else {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
else hvx_div_f32_uau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
}
}
}
#undef HVX_OP_MUL
#undef HVX_OP_MUL_F32
#endif // HVX_DIV_H

View File

@ -137,40 +137,74 @@ static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector n
} \
} while(0)
static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
HVX_Vector out = hvx_vec_inverse_f16(v_sf);
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out);
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
}
static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
\
const uint32_t nvec = n / VLEN_FP16; \
const uint32_t nloe = n % VLEN_FP16; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
} \
if (nloe) { \
HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \
} \
} while(0)
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src % 128 == 0); \
OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \
} \
static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \
} \
static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
assert((uintptr_t) src % 128 == 0); \
OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \
} \
static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \
} \
// Dispatcher logic
#define HVX_INV_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
OP_NAME##_aa(dst, src, num_elems); \
} else if (hex_is_aligned((void *) dst, 128)) { \
OP_NAME##_au(dst, src, num_elems); \
} else if (hex_is_aligned((void *) src, 128)) { \
OP_NAME##_ua(dst, src, num_elems); \
} else { \
OP_NAME##_uu(dst, src, num_elems); \
} \
}
static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body)
DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body)
static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
if ((unsigned long) dst % 128 == 0) {
if ((unsigned long) src % 128 == 0) {
hvx_inverse_f32_aa(dst, src, num_elems);
} else {
hvx_inverse_f32_au(dst, src, num_elems);
}
} else {
if ((unsigned long) src % 128 == 0) {
hvx_inverse_f32_ua(dst, src, num_elems);
} else {
hvx_inverse_f32_uu(dst, src, num_elems);
}
}
}
HVX_INV_DISPATCHER(hvx_inverse_f32)
HVX_INV_DISPATCHER(hvx_inverse_f16)
#endif // HVX_INVERSE_H

View File

@ -400,7 +400,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t n_threads = octx->n_threads;
const uint32_t ne0 = dst->ne[0];
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
@ -465,17 +467,14 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
rctx.dst_row_size_aligned = dst_row_size_aligned;
rctx.theta_cache_offset = theta_cache_size_aligned;
uint32_t ne0 = dst->ne[0];
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
rctx.src0_nrows = src0_nrows;
rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads);
}
return err;

View File

@ -128,6 +128,8 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
int op_set_rows(struct htp_ops_context * octx) {
set_rows_preamble;
const uint32_t n_threads = MIN(nr, octx->n_threads);
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
@ -149,15 +151,14 @@ int op_set_rows(struct htp_ops_context * octx) {
srctx.div_ne12 = init_fastdiv_values(ne12);
srctx.div_ne11 = init_fastdiv_values(ne11);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
switch(octx->dst.type) {
case HTP_TYPE_F32:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
break;
case HTP_TYPE_F16:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);
break;
default:
return HTP_STATUS_NO_SUPPORT;

View File

@ -353,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
const uint32_t n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t src1_row_size = src0_row_size;
@ -393,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
}
return err;

View File

@ -102,11 +102,9 @@ int op_sum_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
const int n_threads = octx->n_threads;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bool opt_path = false;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
@ -124,7 +122,7 @@ int op_sum_rows(struct htp_ops_context * octx) {
.opt_path = opt_path,
};
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads);
return HTP_STATUS_OK;
}

View File

@ -301,8 +301,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
const int n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
@ -338,11 +338,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
struct htp_unary_context uctx = {
.octx = octx,
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
.src0_nrows = src0_nrows,
.data_src0 = (const uint8_t *)src0->data,
@ -361,7 +359,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
.nc = src0->ne[0],
};
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
}
return err;