hexagon: add fp16 support for binary ops: add,sub,mul,div (#20139)
* hexagon: add fp16 support for binary ops: add,sub,mul,div * hexagon: fix test-backend-ops failures for fp16 binary ops on older arches (<v79) * hexagon: decide on n_threads (aka n_jobs) early to avoid overallocating scratchpad * snapdragon: fix readme link --------- Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
This commit is contained in:
parent
a0ed91a442
commit
2b10b62677
|
|
@ -287,7 +287,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
|
||||
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
|
||||
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
||||
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
|
||||
| [Hexagon [In Progress]](docs/backend/snapdragon/README.md) | Snapdragon |
|
||||
| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR |
|
||||
|
||||
## Obtaining and quantizing models
|
||||
|
|
|
|||
|
|
@ -1865,15 +1865,26 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
|||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (src0->type == GGML_TYPE_F16) {
|
||||
if (src1->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (dst->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -693,8 +693,8 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
size_t src0_row_size = src0->nb[1];
|
||||
size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
|
||||
|
|
@ -748,13 +748,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
// Prepare context
|
||||
struct htp_act_context actx;
|
||||
actx.octx = octx;
|
||||
|
||||
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
|
||||
actx.src0_row_size = src0_row_size;
|
||||
actx.src1_row_size = src1_row_size;
|
||||
|
|
@ -794,7 +792,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
|||
actx.data_src1 = data_src1;
|
||||
actx.data_dst = (uint8_t *) dst->data;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -241,6 +241,9 @@ int op_argsort(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
const uint32_t n_threads = MIN(total_rows, octx->n_threads);
|
||||
|
||||
// Allocate scratchpad
|
||||
// We need 1 row of float + 1 row of int32 per thread.
|
||||
uint32_t ne00 = octx->src0.ne[0];
|
||||
|
|
@ -251,7 +254,7 @@ int op_argsort(struct htp_ops_context * octx) {
|
|||
// Make sure we round up to 256 for alignment requirements
|
||||
spad_per_thread = hex_round_up(spad_per_thread, 256);
|
||||
|
||||
size_t total_spad_size = spad_per_thread * octx->n_threads;
|
||||
size_t total_spad_size = spad_per_thread * n_threads;
|
||||
|
||||
if (octx->ctx->vtcm_size < total_spad_size) {
|
||||
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
|
||||
|
|
@ -267,15 +270,12 @@ int op_argsort(struct htp_ops_context * octx) {
|
|||
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
|
||||
octx->src0.data, octx->dst.data);
|
||||
|
||||
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
|
||||
|
||||
struct htp_argsort_context actx;
|
||||
actx.octx = octx;
|
||||
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
|
||||
actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;
|
||||
|
||||
// Run jobs
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,43 +95,87 @@ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_
|
|||
}
|
||||
|
||||
// Macro for scalar op switch
|
||||
#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
|
||||
case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
|
||||
default: break; \
|
||||
#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
|
||||
if(TYPE == HTP_TYPE_F32) { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
|
||||
case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Macro for vector op switch (All Aligned)
|
||||
#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
|
||||
if(TYPE == HTP_TYPE_F32) { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
|
||||
#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
|
||||
if(TYPE == HTP_TYPE_F32) { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
|
||||
#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
|
||||
if(TYPE == HTP_TYPE_F32) { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
switch (octx->op) { \
|
||||
case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
|
||||
case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
|
||||
default: break; \
|
||||
} \
|
||||
}
|
||||
|
||||
// 1. Scalar src1 (ne10 == 1)
|
||||
|
|
@ -140,6 +184,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
|
|||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
|
||||
|
|
@ -170,7 +216,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
|
|||
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
|
||||
ir_prefetch += current_block_size;
|
||||
spad_idx ^= 1;
|
||||
}
|
||||
|
|
@ -199,13 +245,12 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
|
|||
for (uint32_t r = 0; r < current_block_size; r++) {
|
||||
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
|
||||
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
|
||||
float val = *(float *)src1_ptr;
|
||||
COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
|
||||
src1_ptr += s1_stride;
|
||||
COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
|
||||
}
|
||||
|
||||
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
|
||||
|
||||
if (ir_prefetch < end_row) {
|
||||
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
|
||||
|
|
@ -216,7 +261,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
|
|||
p01 = prem - p02 * ne01;
|
||||
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
|
||||
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
|
||||
ir_prefetch += next_block_size;
|
||||
}
|
||||
ir += current_block_size;
|
||||
|
|
@ -230,6 +275,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
|||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
|
||||
|
|
@ -268,8 +315,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
|||
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
|
||||
ir_prefetch += current_block_size;
|
||||
spad_idx ^= 1;
|
||||
}
|
||||
|
|
@ -284,7 +331,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
|||
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
|
||||
uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
|
||||
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
|
||||
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
|
||||
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
|
||||
}
|
||||
|
||||
uint32_t i03, i02, i01, rem;
|
||||
|
|
@ -293,7 +340,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
|||
i02 = fastdiv(rem, &bctx->dim1_div);
|
||||
i01 = rem - i02 * ne01;
|
||||
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
|
||||
|
||||
if (ir_prefetch < end_row) {
|
||||
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
|
||||
|
|
@ -310,8 +357,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
|||
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
|
||||
uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
|
||||
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
|
||||
|
||||
ir_prefetch += next_block_size;
|
||||
}
|
||||
|
|
@ -326,6 +373,8 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
|||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
|
||||
|
|
@ -359,7 +408,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
|||
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
|
||||
ir_prefetch += current_block_size;
|
||||
spad_idx ^= 1;
|
||||
}
|
||||
|
|
@ -373,7 +422,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
|||
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
|
||||
uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
|
||||
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
|
||||
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
|
||||
COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
|
||||
}
|
||||
|
||||
uint32_t i03, i02, i01, rem;
|
||||
|
|
@ -382,7 +431,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
|||
i02 = fastdiv(rem, &bctx->dim1_div);
|
||||
i01 = rem - i02 * ne01;
|
||||
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
|
||||
|
||||
if (ir_prefetch < end_row) {
|
||||
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
|
||||
|
|
@ -392,7 +441,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
|||
p02 = fastdiv(prem, &bctx->dim1_div);
|
||||
p01 = prem - p02 * ne01;
|
||||
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
|
||||
ir_prefetch += next_block_size;
|
||||
}
|
||||
ir += current_block_size;
|
||||
|
|
@ -406,6 +455,8 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
|
|||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
|
||||
|
|
@ -435,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
|
|||
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
|
||||
ir_prefetch += current_block_size;
|
||||
spad_idx ^= 1;
|
||||
}
|
||||
|
|
@ -462,11 +513,11 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
|
|||
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
|
||||
|
||||
// Read src1 from DDR (unaligned)
|
||||
COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
|
||||
COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
|
||||
}
|
||||
|
||||
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
|
||||
|
||||
if (ir_prefetch < end_row) {
|
||||
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
|
||||
|
|
@ -476,7 +527,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
|
|||
p02 = fastdiv(prem, &bctx->dim1_div);
|
||||
p01 = prem - p02 * ne01;
|
||||
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
|
||||
ir_prefetch += next_block_size;
|
||||
}
|
||||
ir += current_block_size;
|
||||
|
|
@ -490,6 +541,9 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
|
|||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
|
||||
const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
|
||||
|
|
@ -519,7 +573,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
|
|||
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
|
||||
ir_prefetch += current_block_size;
|
||||
spad_idx ^= 1;
|
||||
}
|
||||
|
|
@ -549,12 +603,12 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
|
|||
for (uint32_t c = 0; c < ne00; c += ne10) {
|
||||
uint32_t len = MIN(ne10, ne00 - c);
|
||||
// Use UUU for speed and simplicity
|
||||
COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
|
||||
COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
|
||||
|
||||
if (ir_prefetch < end_row) {
|
||||
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
|
||||
|
|
@ -564,7 +618,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
|
|||
p02 = fastdiv(prem, &bctx->dim1_div);
|
||||
p01 = prem - p02 * ne01;
|
||||
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
|
||||
dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
|
||||
ir_prefetch += next_block_size;
|
||||
}
|
||||
ir += current_block_size;
|
||||
|
|
@ -672,18 +726,20 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
|
|||
dma_queue_flush(q);
|
||||
}
|
||||
|
||||
static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
||||
static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
// Use packed row sizes for VTCM allocation
|
||||
const size_t src0_row_size = src0->ne[0] * sizeof(float);
|
||||
const size_t src1_row_size = src1->ne[0] * sizeof(float);
|
||||
const size_t dst_row_size = dst->ne[0] * sizeof(float);
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
|
||||
const size_t src0_row_size = src0->ne[0] * elem_size;
|
||||
const size_t src1_row_size = src1->ne[0] * elem_size;
|
||||
const size_t dst_row_size = dst->ne[0] * elem_size;
|
||||
|
||||
// Align to VLEN
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
|
|
@ -694,7 +750,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
|||
bool is_scalar = !is_add_id && (src1->ne[0] == 1);
|
||||
|
||||
// Determine which kernel we will use to alloc memory and dispatch
|
||||
bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
|
||||
bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
|
||||
(src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
|
||||
(src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
|
||||
(src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
|
||||
|
|
@ -726,7 +782,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
|||
}
|
||||
|
||||
if (rows_per_buffer < 1) {
|
||||
FARF(ERROR, "binary-f32: VTCM too small\n");
|
||||
FARF(ERROR, "binary: VTCM too small\n");
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
|
|
@ -761,16 +817,14 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
dma_queue * q = octx->ctx->dma[0];
|
||||
if (is_row_bcast) {
|
||||
dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
|
||||
dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
|
||||
}
|
||||
|
||||
struct htp_binary_context bctx;
|
||||
bctx.octx = octx;
|
||||
bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
bctx.block_max = rows_per_buffer;
|
||||
bctx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
bctx.src1_row_size_aligned = src1_row_size_aligned;
|
||||
|
|
@ -814,14 +868,24 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
|||
dma_queue_pop(q);
|
||||
}
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_binary(struct htp_ops_context * octx) {
|
||||
if (octx->src0.type == HTP_TYPE_F32) {
|
||||
return execute_op_binary_f32(octx);
|
||||
|
||||
// Does not support permutations of src1
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
if (src1->nb[1] < src1->nb[0]) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
|
||||
return execute_op_binary(octx);
|
||||
}
|
||||
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -202,6 +202,8 @@ static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
|
|||
int op_cpy(struct htp_ops_context * octx) {
|
||||
cpy_preamble;
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
struct htp_copy_context ct;
|
||||
ct.octx = octx;
|
||||
|
||||
|
|
@ -227,8 +229,7 @@ int op_cpy(struct htp_ops_context * octx) {
|
|||
const bool transposed = (nb00 > nb01) || (nb0 > nb1);
|
||||
const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
if (sametype && sameshape) {
|
||||
ct.copy = cpy_thread_sametype_sameshape;
|
||||
|
|
@ -245,7 +246,7 @@ int op_cpy(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,6 +82,8 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
|||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
get_rows_preamble;
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
|
@ -103,9 +105,8 @@ int op_get_rows(struct htp_ops_context * octx) {
|
|||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,14 +13,15 @@
|
|||
// Binary operations (add, mul, sub)
|
||||
//
|
||||
|
||||
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t epv = 128 / (elem_size); \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
|
|
@ -32,62 +33,74 @@
|
|||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
vec_store((void *) &vdst[i], nloe * (elem_size), v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
|
||||
#else
|
||||
#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
|
||||
#endif
|
||||
|
||||
#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b)
|
||||
#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b)
|
||||
#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b)
|
||||
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
|
||||
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \
|
||||
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float)
|
||||
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16)
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_BINARY_DISPATCHER(OP_NAME) \
|
||||
|
|
@ -115,6 +128,10 @@ HVX_BINARY_DISPATCHER(hvx_add_f32)
|
|||
HVX_BINARY_DISPATCHER(hvx_sub_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_mul_f32)
|
||||
|
||||
HVX_BINARY_DISPATCHER(hvx_add_f16)
|
||||
HVX_BINARY_DISPATCHER(hvx_sub_f16)
|
||||
HVX_BINARY_DISPATCHER(hvx_mul_f16)
|
||||
|
||||
// Mul-Mul Optimized
|
||||
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
|
|
@ -136,26 +153,25 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re
|
|||
|
||||
_Pragma("unroll(4)")
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
|
||||
HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
|
||||
vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
|
||||
HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
|
||||
HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
|
||||
HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]);
|
||||
hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar Operations
|
||||
|
||||
#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
|
||||
#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t epv = 128 / (elem_size); \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
|
|
@ -169,138 +185,88 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re
|
|||
if (nloe) { \
|
||||
HVX_Vector v = vsrc[i]; \
|
||||
v = scalar_op_macro(v); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
vec_store((void *) &vdst[i], nloe * (elem_size), v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define HVX_OP_ADD_SCALAR(v) \
|
||||
#define HVX_OP_ADD_SCALAR_F32(v) \
|
||||
({ \
|
||||
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
|
||||
HVX_Vector out = HVX_OP_ADD(v, val_vec); \
|
||||
HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \
|
||||
Q6_V_vmux_QVV(pred_inf, inf, out); \
|
||||
})
|
||||
|
||||
#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
|
||||
#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
|
||||
#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec)
|
||||
#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec)
|
||||
|
||||
// Add Scalar Variants
|
||||
#define HVX_OP_ADD_SCALAR_F16(v) \
|
||||
({ \
|
||||
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \
|
||||
HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \
|
||||
Q6_V_vmux_QVV(pred_inf, inf, out); \
|
||||
})
|
||||
|
||||
static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
|
||||
#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec)
|
||||
#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec)
|
||||
|
||||
// Scalar Variants
|
||||
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \
|
||||
static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
|
||||
const HVX_Vector val_vec = SPLAT_MACRO(val); \
|
||||
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src % 128 == 0); \
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
|
||||
const HVX_Vector val_vec = SPLAT_MACRO(val); \
|
||||
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
|
||||
const HVX_Vector val_vec = SPLAT_MACRO(val); \
|
||||
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
|
||||
assert((uintptr_t) src % 128 == 0); \
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
|
||||
const HVX_Vector val_vec = SPLAT_MACRO(val); \
|
||||
const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
|
||||
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float)
|
||||
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float)
|
||||
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float)
|
||||
|
||||
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16)
|
||||
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16)
|
||||
DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16)
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \
|
||||
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
|
||||
OP_NAME##_aa(dst, src, val, num_elems); \
|
||||
} else if (hex_is_aligned((void *) dst, 128)) { \
|
||||
OP_NAME##_au(dst, src, val, num_elems); \
|
||||
} else if (hex_is_aligned((void *) src, 128)) { \
|
||||
OP_NAME##_ua(dst, src, val, num_elems); \
|
||||
} else { \
|
||||
OP_NAME##_uu(dst, src, val, num_elems); \
|
||||
} \
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float)
|
||||
HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float)
|
||||
HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float)
|
||||
|
||||
static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
static const float kInf = INFINITY;
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
|
||||
}
|
||||
|
||||
// Sub Scalar Variants
|
||||
|
||||
static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
|
||||
}
|
||||
|
||||
// Mul Scalar Variants
|
||||
|
||||
static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_add_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_add_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_add_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_add_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_mul_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) dst, 128)) {
|
||||
hvx_sub_scalar_f32_au(dst, src, val, num_elems);
|
||||
} else if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
|
||||
} else {
|
||||
hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
|
||||
}
|
||||
}
|
||||
HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16)
|
||||
HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16)
|
||||
HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16)
|
||||
|
||||
// MIN Scalar variants
|
||||
|
||||
|
|
@ -310,24 +276,24 @@ static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t *
|
|||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
|
||||
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
|
||||
|
|
@ -357,27 +323,27 @@ static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t
|
|||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
|
||||
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
|
||||
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
|
||||
hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
|
||||
}
|
||||
|
||||
static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
|
||||
|
|
@ -396,7 +362,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
|||
// Square
|
||||
//
|
||||
|
||||
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
|
||||
#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
|
|
@ -410,10 +376,10 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
|||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
|
@ -421,21 +387,21 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
|||
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
|
||||
|
|
@ -454,17 +420,24 @@ static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict
|
|||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD
|
||||
#undef HVX_OP_SUB
|
||||
#undef HVX_OP_MUL
|
||||
#undef HVX_OP_ADD_F32
|
||||
#undef HVX_OP_SUB_F32
|
||||
#undef HVX_OP_MUL_F32
|
||||
#undef HVX_OP_ADD_F16
|
||||
#undef HVX_OP_SUB_F16
|
||||
#undef HVX_OP_MUL_F16
|
||||
#undef hvx_arith_loop_body
|
||||
#undef HVX_OP_ADD_SCALAR
|
||||
#undef HVX_OP_SUB_SCALAR
|
||||
#undef HVX_OP_MUL_SCALAR
|
||||
#undef HVX_OP_ADD_SCALAR_F32
|
||||
#undef HVX_OP_SUB_SCALAR_F32
|
||||
#undef HVX_OP_MUL_SCALAR_F32
|
||||
#undef HVX_OP_ADD_SCALAR_F16
|
||||
#undef HVX_OP_SUB_SCALAR_F16
|
||||
#undef HVX_OP_MUL_SCALAR_F16
|
||||
#undef hvx_scalar_loop_body
|
||||
#undef HVX_OP_MIN_SCALAR
|
||||
#undef HVX_OP_CLAMP_SCALAR
|
||||
#undef DEFINE_HVX_BINARY_OP_VARIANTS
|
||||
#undef HVX_BINARY_DISPATCHER
|
||||
#undef UNUSED
|
||||
|
||||
#endif // HVX_ARITH_H
|
||||
|
|
|
|||
|
|
@ -189,4 +189,52 @@ static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vect
|
|||
|
||||
#endif
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
|
||||
static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
{
|
||||
const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
|
||||
const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16
|
||||
HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
|
||||
HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
|
||||
HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
|
||||
HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
|
||||
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
{
|
||||
const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
|
||||
const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16
|
||||
HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
|
||||
HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
|
||||
HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
|
||||
HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
|
||||
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
{
|
||||
return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
{
|
||||
return Q6_Vhf_vadd_VhfVhf(a, b);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
{
|
||||
return Q6_Vhf_vsub_VhfVhf(a, b);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
{
|
||||
return Q6_Vhf_vmpy_VhfVhf(a, b);
|
||||
}
|
||||
|
||||
#endif // __HVX_ARCH__ < 79
|
||||
|
||||
#endif /* HVX_BASE_H */
|
||||
|
|
|
|||
|
|
@ -15,11 +15,144 @@
|
|||
#include "hvx-arith.h"
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.
|
||||
static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) {
|
||||
#if __HVX_ARCH__ < 79
|
||||
HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
|
||||
HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32));
|
||||
HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32));
|
||||
#else
|
||||
HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
|
||||
HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32);
|
||||
HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32);
|
||||
#endif
|
||||
|
||||
HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const);
|
||||
HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const);
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
|
||||
#else
|
||||
HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32.
|
||||
static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
|
||||
#if __HVX_ARCH__ < 79
|
||||
// Convert first input to fp32
|
||||
HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0
|
||||
HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32));
|
||||
HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32));
|
||||
|
||||
// Convert second input to fp32
|
||||
HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0
|
||||
HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32));
|
||||
HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32));
|
||||
#else
|
||||
// Convert first input to fp32
|
||||
HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0
|
||||
HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32);
|
||||
HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32);
|
||||
|
||||
// Convert second input to fp32
|
||||
HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0
|
||||
HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32);
|
||||
HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32);
|
||||
#endif
|
||||
|
||||
// Inverse second input in fp32
|
||||
HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask);
|
||||
HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask);
|
||||
|
||||
// Multiply first input by inverse of second, in fp32
|
||||
HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0);
|
||||
HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1);
|
||||
|
||||
// Convert back to fp16
|
||||
#if __HVX_ARCH__ < 79
|
||||
HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
|
||||
#else
|
||||
HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
|
||||
#endif
|
||||
|
||||
return recip;
|
||||
}
|
||||
|
||||
#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
|
|
@ -36,81 +169,83 @@
|
|||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// 3-letter suffix variants
|
||||
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
|
||||
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \
|
||||
} \
|
||||
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \
|
||||
} \
|
||||
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \
|
||||
} \
|
||||
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \
|
||||
} \
|
||||
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \
|
||||
} \
|
||||
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \
|
||||
} \
|
||||
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \
|
||||
} \
|
||||
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \
|
||||
} \
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_DIV_DISPATCHER(OP_NAME) \
|
||||
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
|
||||
if (hex_is_aligned((void *) dst, 128)) { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_aau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_auu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body)
|
||||
DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body)
|
||||
|
||||
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
HVX_DIV_DISPATCHER(hvx_div_f32)
|
||||
HVX_DIV_DISPATCHER(hvx_div_f16)
|
||||
|
||||
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_aau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_auu(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_MUL
|
||||
#undef HVX_OP_MUL_F32
|
||||
|
||||
#endif // HVX_DIV_H
|
||||
|
|
|
|||
|
|
@ -137,40 +137,74 @@ static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector n
|
|||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
|
||||
HVX_Vector out = hvx_vec_inverse_f16(v_sf);
|
||||
|
||||
HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
|
||||
const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out);
|
||||
|
||||
return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
|
||||
static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \
|
||||
} \
|
||||
static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \
|
||||
} \
|
||||
static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
|
||||
assert((uintptr_t) src % 128 == 0); \
|
||||
OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \
|
||||
} \
|
||||
static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
|
||||
OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \
|
||||
} \
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_INV_DISPATCHER(OP_NAME) \
|
||||
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
|
||||
OP_NAME##_aa(dst, src, num_elems); \
|
||||
} else if (hex_is_aligned((void *) dst, 128)) { \
|
||||
OP_NAME##_au(dst, src, num_elems); \
|
||||
} else if (hex_is_aligned((void *) src, 128)) { \
|
||||
OP_NAME##_ua(dst, src, num_elems); \
|
||||
} else { \
|
||||
OP_NAME##_uu(dst, src, num_elems); \
|
||||
} \
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body)
|
||||
DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body)
|
||||
|
||||
static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
|
||||
if ((unsigned long) dst % 128 == 0) {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_inverse_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_inverse_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_inverse_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_inverse_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
HVX_INV_DISPATCHER(hvx_inverse_f32)
|
||||
HVX_INV_DISPATCHER(hvx_inverse_f16)
|
||||
|
||||
#endif // HVX_INVERSE_H
|
||||
|
|
|
|||
|
|
@ -400,7 +400,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
const uint32_t ne0 = dst->ne[0];
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
|
@ -465,17 +467,14 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
rctx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
rctx.theta_cache_offset = theta_cache_size_aligned;
|
||||
|
||||
uint32_t ne0 = dst->ne[0];
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
rctx.src0_nrows = src0_nrows;
|
||||
rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
|
||||
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
|
||||
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
|
|
|||
|
|
@ -128,6 +128,8 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
|||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
set_rows_preamble;
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
|
@ -149,15 +151,14 @@ int op_set_rows(struct htp_ops_context * octx) {
|
|||
srctx.div_ne12 = init_fastdiv_values(ne12);
|
||||
srctx.div_ne11 = init_fastdiv_values(ne11);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
|
||||
break;
|
||||
case HTP_TYPE_F16:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);
|
||||
break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
|
|
|||
|
|
@ -353,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
|
|
@ -393,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
|||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
|
|
|||
|
|
@ -102,11 +102,9 @@ int op_sum_rows(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
|
||||
bool opt_path = false;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
|
|
@ -124,7 +122,7 @@ int op_sum_rows(struct htp_ops_context * octx) {
|
|||
.opt_path = opt_path,
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -301,8 +301,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
|
@ -338,11 +338,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
struct htp_unary_context uctx = {
|
||||
.octx = octx,
|
||||
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
|
||||
.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
|
||||
.src0_nrows = src0_nrows,
|
||||
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
|
|
@ -361,7 +359,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
.nc = src0->ne[0],
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
|
|
|||
Loading…
Reference in New Issue